use std::collections::HashMap; use std::sync::Arc; use futures::StreamExt; use serde::Deserialize; use serde::Serialize; use thiserror::Error; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ServiceKind { Aggregate, Projection, Runner, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct RoutingConfig { pub revision: u64, pub aggregate_placement: HashMap, pub projection_placement: HashMap, pub runner_placement: HashMap, pub aggregate_shards: HashMap>, pub projection_shards: HashMap>, pub runner_shards: HashMap>, } impl RoutingConfig { pub fn empty() -> Self { Self { revision: 0, aggregate_placement: HashMap::new(), projection_placement: HashMap::new(), runner_placement: HashMap::new(), aggregate_shards: HashMap::new(), projection_shards: HashMap::new(), runner_shards: HashMap::new(), } } } #[derive(Debug, Clone, Serialize)] pub struct RoutingTable { pub revision: u64, aggregate_placement: HashMap, projection_placement: HashMap, runner_placement: HashMap, aggregate_shards: HashMap>, projection_shards: HashMap>, runner_shards: HashMap>, } impl From for RoutingTable { fn from(value: RoutingConfig) -> Self { Self { revision: value.revision, aggregate_placement: value.aggregate_placement, projection_placement: value.projection_placement, runner_placement: value.runner_placement, aggregate_shards: value.aggregate_shards, projection_shards: value.projection_shards, runner_shards: value.runner_shards, } } } #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum RoutingError { #[error("unknown tenant")] UnknownTenant, #[error("missing shard directory entry")] MissingShard, #[error("no endpoints for shard")] EmptyShard, } #[derive(Clone)] pub struct RouterState { table: Arc>>, source: Arc, } impl std::fmt::Debug for RouterState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RouterState").finish_non_exhaustive() } } impl RouterState { pub async fn new(source: Arc) -> Result { let cfg = source.load().await?; Ok(Self { table: Arc::new(tokio::sync::RwLock::new(Arc::new(cfg.into()))), source, }) } pub async fn snapshot(&self) -> Arc { self.table.read().await.clone() } pub async fn reload(&self) -> Result<(), RoutingSourceError> { let cfg = self.source.load().await?; let next = Arc::new(RoutingTable::from(cfg)); *self.table.write().await = next; Ok(()) } pub fn start_watcher(&self) -> tokio::task::JoinHandle<()> { let this = self.clone(); tokio::spawn(async move { let mut stream = match this.source.watch().await { Ok(s) => s, Err(_) => return, }; while let Some(msg) = stream.next().await { if msg.is_err() { continue; } let _ = this.reload().await; } }) } pub async fn resolve( &self, tenant_id: &str, kind: ServiceKind, ) -> Result { let table = self.snapshot().await; let result = table.resolve(tenant_id, kind); metrics::counter!( "gateway_routing_resolutions_total", "kind" => kind_label(kind), "result" => if result.is_ok() { "ok" } else { "err" } ) .increment(1); result } } fn kind_label(kind: ServiceKind) -> &'static str { match kind { ServiceKind::Aggregate => "aggregate", ServiceKind::Projection => "projection", ServiceKind::Runner => "runner", } } impl RoutingTable { pub fn resolve(&self, tenant_id: &str, kind: ServiceKind) -> Result { let shard_id = match kind { ServiceKind::Aggregate => self.aggregate_placement.get(tenant_id), ServiceKind::Projection => self.projection_placement.get(tenant_id), ServiceKind::Runner => self.runner_placement.get(tenant_id), } .ok_or(RoutingError::UnknownTenant)?; let endpoints = match kind { ServiceKind::Aggregate => self.aggregate_shards.get(shard_id), ServiceKind::Projection => self.projection_shards.get(shard_id), ServiceKind::Runner => self.runner_shards.get(shard_id), } .ok_or(RoutingError::MissingShard)?; endpoints.first().cloned().ok_or(RoutingError::EmptyShard) } } #[derive(Debug, Error)] pub enum RoutingSourceError { #[error("source error: {0}")] Source(String), #[error("decode error: {0}")] Decode(String), } #[async_trait::async_trait] pub trait RoutingSource: Send + Sync { async fn load(&self) -> Result; async fn watch( &self, ) -> Result< std::pin::Pin> + Send>>, RoutingSourceError, >; } #[derive(Clone)] pub struct FixedSource { cfg: RoutingConfig, } impl FixedSource { pub fn new(cfg: RoutingConfig) -> Self { Self { cfg } } } #[async_trait::async_trait] impl RoutingSource for FixedSource { async fn load(&self) -> Result { Ok(self.cfg.clone()) } async fn watch( &self, ) -> Result< std::pin::Pin> + Send>>, RoutingSourceError, > { Ok(Box::pin(futures::stream::empty())) } } #[derive(Clone)] pub struct StaticFileSource { path: String, } impl StaticFileSource { pub fn new(path: impl Into) -> Self { Self { path: path.into() } } } #[async_trait::async_trait] impl RoutingSource for StaticFileSource { async fn load(&self) -> Result { let raw = tokio::fs::read_to_string(&self.path) .await .map_err(|e| RoutingSourceError::Source(e.to_string()))?; if self.path.ends_with(".json") { serde_json::from_str::(&raw) .map_err(|e| RoutingSourceError::Decode(e.to_string())) } else { let yaml: serde_yaml::Value = serde_yaml::from_str(&raw) .map_err(|e| RoutingSourceError::Decode(e.to_string()))?; let json = serde_json::to_value(yaml) .map_err(|e| RoutingSourceError::Decode(e.to_string()))?; serde_json::from_value::(json) .map_err(|e| RoutingSourceError::Decode(e.to_string())) } } async fn watch( &self, ) -> Result< std::pin::Pin> + Send>>, RoutingSourceError, > { Ok(Box::pin(futures::stream::empty())) } } #[derive(Clone)] pub struct NatsKvSource { kv: async_nats::jetstream::kv::Store, key: String, } impl NatsKvSource { pub async fn connect( nats_url: impl Into, bucket: impl Into, key: impl Into, ) -> Result { let nats_url = nats_url.into(); let bucket = bucket.into(); let key = key.into(); let client = async_nats::connect(nats_url) .await .map_err(|e| RoutingSourceError::Source(e.to_string()))?; let jetstream = async_nats::jetstream::new(client); let kv = match jetstream.get_key_value(&bucket).await { Ok(kv) => kv, Err(_) => jetstream .create_key_value(async_nats::jetstream::kv::Config { bucket: bucket.clone(), ..Default::default() }) .await .map_err(|e| RoutingSourceError::Source(e.to_string()))?, }; Ok(Self { kv, key }) } } #[async_trait::async_trait] impl RoutingSource for NatsKvSource { async fn load(&self) -> Result { let entry = self .kv .entry(&self.key) .await .map_err(|e| RoutingSourceError::Source(e.to_string()))?; let Some(entry) = entry else { return Ok(RoutingConfig::empty()); }; serde_json::from_slice::(&entry.value) .map_err(|e| RoutingSourceError::Decode(e.to_string())) } async fn watch( &self, ) -> Result< std::pin::Pin> + Send>>, RoutingSourceError, > { let key = self.key.clone(); let watch = self .kv .watch(&key) .await .map_err(|e| RoutingSourceError::Source(e.to_string()))?; Ok(Box::pin(watch.filter_map(|entry| async move { match entry { Ok(entry) => match entry.operation { async_nats::jetstream::kv::Operation::Put => Some(Ok(())), async_nats::jetstream::kv::Operation::Delete | async_nats::jetstream::kv::Operation::Purge => None, }, Err(e) => Some(Err(RoutingSourceError::Source(e.to_string()))), } }))) } } #[cfg(test)] mod tests { use super::*; fn assert_send_sync() {} #[test] fn router_state_is_send_sync() { assert_send_sync::(); } #[tokio::test] async fn resolves_endpoints_for_tenant_service_kind() { let cfg = RoutingConfig { revision: 1, aggregate_placement: HashMap::from([("t1".to_string(), "a".to_string())]), projection_placement: HashMap::from([("t1".to_string(), "p".to_string())]), runner_placement: HashMap::from([("t1".to_string(), "r".to_string())]), aggregate_shards: HashMap::from([("a".to_string(), vec!["http://a".to_string()])]), projection_shards: HashMap::from([("p".to_string(), vec!["http://p".to_string()])]), runner_shards: HashMap::from([("r".to_string(), vec!["http://r".to_string()])]), }; let source: Arc = Arc::new(TestSource::new(cfg)); let router = RouterState::new(source).await.unwrap(); assert_eq!( router.resolve("t1", ServiceKind::Aggregate).await.unwrap(), "http://a" ); assert_eq!( router.resolve("t1", ServiceKind::Projection).await.unwrap(), "http://p" ); assert_eq!( router.resolve("t1", ServiceKind::Runner).await.unwrap(), "http://r" ); } #[tokio::test] async fn unknown_tenant_is_typed_error() { let source: Arc = Arc::new(TestSource::new(RoutingConfig::empty())); let router = RouterState::new(source).await.unwrap(); let err = router .resolve("missing", ServiceKind::Aggregate) .await .unwrap_err(); assert_eq!(err, RoutingError::UnknownTenant); } #[tokio::test] async fn hot_reload_swaps_table_atomically() { let cfg1 = RoutingConfig { revision: 1, aggregate_placement: HashMap::from([("t1".to_string(), "a".to_string())]), projection_placement: HashMap::new(), runner_placement: HashMap::new(), aggregate_shards: HashMap::from([("a".to_string(), vec!["http://a1".to_string()])]), projection_shards: HashMap::new(), runner_shards: HashMap::new(), }; let cfg2 = RoutingConfig { revision: 2, aggregate_placement: HashMap::from([("t1".to_string(), "a".to_string())]), projection_placement: HashMap::new(), runner_placement: HashMap::new(), aggregate_shards: HashMap::from([("a".to_string(), vec!["http://a2".to_string()])]), projection_shards: HashMap::new(), runner_shards: HashMap::new(), }; let test_source = Arc::new(TestSource::new(cfg1)); let router = RouterState::new(test_source.clone()).await.unwrap(); let before = router.resolve("t1", ServiceKind::Aggregate).await.unwrap(); assert_eq!(before, "http://a1"); test_source.set(cfg2).await; router.reload().await.unwrap(); let after = router.resolve("t1", ServiceKind::Aggregate).await.unwrap(); assert_eq!(after, "http://a2"); } #[derive(Clone)] struct TestSource { cfg: Arc>, } impl TestSource { fn new(cfg: RoutingConfig) -> Self { Self { cfg: Arc::new(tokio::sync::RwLock::new(cfg)), } } async fn set(&self, cfg: RoutingConfig) { *self.cfg.write().await = cfg; } } #[async_trait::async_trait] impl RoutingSource for TestSource { async fn load(&self) -> Result { Ok(self.cfg.read().await.clone()) } async fn watch( &self, ) -> Result< std::pin::Pin> + Send>>, RoutingSourceError, > { Ok(Box::pin(futures::stream::empty())) } } }