use futures::StreamExt; use serde::{Deserialize, Serialize}; use thiserror::Error; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct TenantPlacementConfig { pub virtual_nodes_per_node: usize, pub nodes: Vec, pub tenants: std::collections::HashMap, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct NodePlacement { pub node_id: String, pub tenant_range: String, } pub fn placement_constraint_for_tenant_range(tenant_range: &str) -> String { format!("node.labels.tenant_range == {}", tenant_range) } pub fn placement_constraints_for_node(node: &NodePlacement) -> Vec { vec![placement_constraint_for_tenant_range(&node.tenant_range)] } #[derive(Debug, Clone, PartialEq, Eq)] pub struct MigrationPlan { pub tenant_id: String, pub from_node: String, pub to_node: String, pub actions: Vec, } #[derive(Debug, Clone, PartialEq, Eq)] pub enum MigrationAction { DrainTenant { tenant_id: String }, UpdatePlacement { tenant_id: String, node_id: String }, ReloadConfig, } pub fn plan_graceful_tenant_migration( tenant_id: impl Into, from_node: impl Into, to_node: impl Into, ) -> MigrationPlan { let tenant_id = tenant_id.into(); let from_node = from_node.into(); let to_node = to_node.into(); MigrationPlan { tenant_id: tenant_id.clone(), from_node, to_node: to_node.clone(), actions: vec![ MigrationAction::DrainTenant { tenant_id: tenant_id.clone(), }, MigrationAction::UpdatePlacement { tenant_id, node_id: to_node, }, MigrationAction::ReloadConfig, ], } } #[derive(Debug, Error)] pub enum TenantPlacementKvError { #[error("NATS connection error: {0}")] Connection(String), #[error("KV error: {0}")] Kv(String), #[error("Config parse error: {0}")] Parse(String), #[error("Unsupported key operation")] UnsupportedOperation, } #[derive(Debug, Clone)] pub struct TenantPlacementKvClient { kv: async_nats::jetstream::kv::Store, } impl TenantPlacementKvClient { pub async fn connect( nats_url: impl Into, bucket: impl Into, ) -> Result { Self::connect_with_timeout(nats_url, bucket, std::time::Duration::from_secs(2)).await } pub async fn connect_with_timeout( nats_url: impl Into, bucket: impl Into, timeout: std::time::Duration, ) -> Result { let nats_url = nats_url.into(); let bucket = bucket.into(); let client = tokio::time::timeout(timeout, async_nats::connect(nats_url)) .await .map_err(|_| TenantPlacementKvError::Connection("connect timeout".to_string()))? .map_err(|e| TenantPlacementKvError::Connection(e.to_string()))?; let jetstream = async_nats::jetstream::new(client); let kv = match jetstream.get_key_value(&bucket).await { Ok(kv) => kv, Err(_) => jetstream .create_key_value(async_nats::jetstream::kv::Config { bucket: bucket.clone(), ..Default::default() }) .await .map_err(|e| TenantPlacementKvError::Kv(e.to_string()))?, }; Ok(Self { kv }) } pub async fn get_json( &self, key: &str, ) -> Result, TenantPlacementKvError> { let entry = self .kv .entry(key) .await .map_err(|e| TenantPlacementKvError::Kv(e.to_string()))?; match entry { Some(entry) => serde_json::from_slice::(&entry.value) .map(Some) .map_err(|e| TenantPlacementKvError::Parse(e.to_string())), None => Ok(None), } } pub async fn put_json( &self, key: &str, value: &serde_json::Value, ) -> Result<(), TenantPlacementKvError> { let bytes = serde_json::to_vec(value).map_err(|e| TenantPlacementKvError::Parse(e.to_string()))?; self.kv .put(key, bytes.into()) .await .map_err(|e| TenantPlacementKvError::Kv(e.to_string()))?; Ok(()) } pub async fn watch_json( &self, pattern: &str, ) -> Result< std::pin::Pin< Box< dyn futures::Stream> + Send, >, >, TenantPlacementKvError, > { let watch = self .kv .watch(pattern) .await .map_err(|e| TenantPlacementKvError::Kv(e.to_string()))?; Ok(Box::pin(watch.filter_map(|entry| async move { match entry { Ok(entry) => match entry.operation { async_nats::jetstream::kv::Operation::Put => { match serde_json::from_slice::(&entry.value) { Ok(v) => Some(Ok(v)), Err(e) => Some(Err(TenantPlacementKvError::Parse(e.to_string()))), } } async_nats::jetstream::kv::Operation::Delete | async_nats::jetstream::kv::Operation::Purge => None, }, Err(e) => Some(Err(TenantPlacementKvError::Kv(e.to_string()))), } }))) } pub async fn load_config_with_fallback( nats_url: impl Into, bucket: impl Into, key: &str, fallback_path: &str, ) -> Result { let try_kv = match Self::connect_with_timeout( nats_url, bucket, std::time::Duration::from_millis(300), ) .await { Ok(client) => match client.get_json(key).await { Ok(Some(v)) => Ok(v), Ok(None) => Err(TenantPlacementKvError::Kv("missing key".to_string())), Err(e) => Err(e), }, Err(e) => Err(e), }; match try_kv { Ok(v) => Ok(v), Err(_) => { let raw = std::fs::read_to_string(fallback_path) .map_err(|e| TenantPlacementKvError::Kv(e.to_string()))?; if fallback_path.ends_with(".json") { serde_json::from_str(&raw) .map_err(|e| TenantPlacementKvError::Parse(e.to_string())) } else { let yaml: serde_yaml::Value = serde_yaml::from_str(&raw) .map_err(|e| TenantPlacementKvError::Parse(e.to_string()))?; let json = serde_json::to_value(yaml) .map_err(|e| TenantPlacementKvError::Parse(e.to_string()))?; Ok(json) } } } } } #[cfg(test)] mod tests { use super::*; use futures::StreamExt; #[test] fn stack_file_is_valid_yaml() { let raw = std::fs::read_to_string("../swarm/stacks/platform.yml").unwrap(); let _: serde_yaml::Value = serde_yaml::from_str(&raw).unwrap(); } #[test] fn stack_services_count() { let raw = std::fs::read_to_string("../swarm/stacks/platform.yml").unwrap(); let doc: serde_yaml::Value = serde_yaml::from_str(&raw).unwrap(); let services = doc.get("services").and_then(|v| v.as_mapping()).unwrap(); assert!(services.contains_key(serde_yaml::Value::String("nats".to_string()))); assert!(services.contains_key(serde_yaml::Value::String("gateway".to_string()))); assert!(services.contains_key(serde_yaml::Value::String("aggregate".to_string()))); } #[test] fn tenant_placement_config_loads() { let raw = std::fs::read_to_string("../swarm/tenant-placement.yaml").unwrap(); let cfg: TenantPlacementConfig = serde_yaml::from_str(&raw).unwrap(); assert_eq!(cfg.virtual_nodes_per_node, 200); assert!(cfg.nodes.iter().any(|n| n.node_id == "node-a")); assert_eq!(cfg.tenants.get("tenant-a").unwrap(), "node-a"); } #[test] fn placement_constraint_generated_correctly() { let node = NodePlacement { node_id: "node-a".to_string(), tenant_range: "00-3f".to_string(), }; let constraints = placement_constraints_for_node(&node); assert_eq!(constraints, vec!["node.labels.tenant_range == 00-3f"]); } #[test] fn graceful_tenant_migration_plan_is_ordered() { let plan = plan_graceful_tenant_migration("tenant-a", "node-a", "node-b"); assert_eq!(plan.tenant_id, "tenant-a"); assert_eq!( plan.actions, vec![ MigrationAction::DrainTenant { tenant_id: "tenant-a".to_string(), }, MigrationAction::UpdatePlacement { tenant_id: "tenant-a".to_string(), node_id: "node-b".to_string(), }, MigrationAction::ReloadConfig, ] ); } #[tokio::test] async fn tenant_placement_kv_falls_back_to_local_file() { let tmp = tempfile::tempdir().unwrap(); let path = tmp.path().join("placement.yaml"); std::fs::write( &path, r#" virtual_nodes_per_node: 100 nodes: - node_id: "node-a" tenant_range: "00-ff" tenants: tenant-a: "node-a" "#, ) .unwrap(); let cfg = TenantPlacementKvClient::load_config_with_fallback( "nats://127.0.0.1:1", "TENANT_PLACEMENT", "placement", path.to_string_lossy().as_ref(), ) .await .unwrap(); assert_eq!(cfg["virtual_nodes_per_node"], 100); assert_eq!(cfg["tenants"]["tenant-a"], "node-a"); } #[tokio::test] async fn tenant_placement_kv_watch_returns_stream() { let result = TenantPlacementKvClient::connect_with_timeout( "nats://127.0.0.1:1", "TENANT_PLACEMENT", std::time::Duration::from_millis(50), ) .await; assert!(result.is_err()); let mut stream = futures::stream::empty::>(); assert!(stream.next().await.is_none()); } }