use crate::observability::Observability; use crate::types::{AggregateError, TenantId}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tokio::sync::RwLock; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct TenantStatus { pub tenant_id: TenantId, pub hosted: bool, pub accepting: bool, pub draining: bool, pub in_flight: u64, } pub struct TenantPlacementManager { hosted: RwLock>, draining: RwLock>, in_flight: RwLock>, observability: Arc, } impl TenantPlacementManager { pub fn new(observability: Arc) -> Self { Self { hosted: RwLock::new(HashSet::new()), draining: RwLock::new(HashSet::new()), in_flight: RwLock::new(HashMap::new()), observability, } } pub async fn set_hosted_tenants(&self, tenant_ids: impl IntoIterator) { let mut hosted = self.hosted.write().await; hosted.clear(); hosted.extend(tenant_ids); } pub async fn apply_placement_map(&self, shard_id: &str, placement: &HashMap) { let tenants = placement .iter() .filter_map(|(tenant_id, assigned)| { if assigned == shard_id { Some(tenant_id.clone()) } else { None } }) .collect::>(); self.set_hosted_tenants(tenants).await; } pub async fn is_hosted(&self, tenant_id: &TenantId) -> bool { if tenant_id.as_str().is_empty() { return true; } self.hosted.read().await.contains(tenant_id.as_str()) } pub async fn is_draining(&self, tenant_id: &TenantId) -> bool { self.draining.read().await.contains(tenant_id.as_str()) } pub async fn begin_command( self: &Arc, tenant_id: &TenantId, ) -> Result { if !self.is_hosted(tenant_id).await { return Err(AggregateError::TenantNotHosted { tenant_id: tenant_id.clone(), }); } if self.is_draining(tenant_id).await { return Err(AggregateError::TenantDraining { tenant_id: tenant_id.clone(), }); } let mut map = self.in_flight.write().await; let counter = map.entry(tenant_id.as_str().to_string()).or_insert(0); *counter += 1; let value = *counter; drop(map); self.observability .metrics() .set_in_flight(tenant_id.as_str(), value); Ok(TenantCommandGuard { tenant_id: tenant_id.clone(), manager: self.clone(), }) } pub async fn drain_tenant(&self, tenant_id: &TenantId) { if tenant_id.as_str().is_empty() { return; } let mut draining = self.draining.write().await; draining.insert(tenant_id.as_str().to_string()); } pub async fn undrain_tenant(&self, tenant_id: &TenantId) { let mut draining = self.draining.write().await; draining.remove(tenant_id.as_str()); } pub async fn wait_drained(&self, tenant_id: &TenantId) { loop { let in_flight = self .in_flight .read() .await .get(tenant_id.as_str()) .copied() .unwrap_or(0); if in_flight == 0 { break; } tokio::time::sleep(std::time::Duration::from_millis(10)).await; } } pub async fn wait_drained_with_timeout( &self, tenant_id: &TenantId, timeout: std::time::Duration, ) -> bool { let deadline = tokio::time::Instant::now() + timeout; loop { let in_flight = self .in_flight .read() .await .get(tenant_id.as_str()) .copied() .unwrap_or(0); if in_flight == 0 { return true; } if tokio::time::Instant::now() >= deadline { return false; } tokio::time::sleep(std::time::Duration::from_millis(10)).await; } } pub async fn tenant_status(&self, tenant_id: &TenantId) -> TenantStatus { let hosted = self.is_hosted(tenant_id).await; let draining = self.is_draining(tenant_id).await; let in_flight = self .in_flight .read() .await .get(tenant_id.as_str()) .copied() .unwrap_or(0); TenantStatus { tenant_id: tenant_id.clone(), hosted, accepting: hosted && !draining, draining, in_flight, } } pub async fn hosted_tenants(&self) -> Vec { let hosted = self.hosted.read().await; hosted.iter().map(TenantId::new).collect() } pub async fn all_statuses(&self) -> Vec { let hosted = self.hosted.read().await.clone(); let draining = self.draining.read().await.clone(); let in_flight = self.in_flight.read().await.clone(); hosted .into_iter() .map(|id| { let tenant_id = TenantId::new(id.clone()); let d = draining.contains(&id); let f = in_flight.get(&id).copied().unwrap_or(0); TenantStatus { tenant_id, hosted: true, accepting: !d, draining: d, in_flight: f, } }) .collect() } async fn finish_command(&self, tenant_id: &TenantId) { let mut map = self.in_flight.write().await; let counter = map.entry(tenant_id.as_str().to_string()).or_insert(0); if *counter > 0 { *counter -= 1; } let value = *counter; drop(map); self.observability .metrics() .set_in_flight(tenant_id.as_str(), value); } } pub struct TenantCommandGuard { tenant_id: TenantId, manager: Arc, } impl Drop for TenantCommandGuard { fn drop(&mut self) { let tenant_id = self.tenant_id.clone(); let manager = self.manager.clone(); tokio::spawn(async move { manager.finish_command(&tenant_id).await; }); } } #[cfg(test)] mod tests { use super::*; use crate::observability::Observability; #[tokio::test] async fn placement_rejects_unhosted_tenant() { let obs = Arc::new(Observability::default()); let mgr = Arc::new(TenantPlacementManager::new(obs)); mgr.set_hosted_tenants(vec!["tenant-a".to_string()]).await; let err = match mgr.begin_command(&TenantId::new("tenant-b")).await { Ok(_) => panic!("expected error"), Err(e) => e, }; assert!(matches!(err, AggregateError::TenantNotHosted { .. })); } #[tokio::test] async fn drain_blocks_new_commands_until_in_flight_zero() { let obs = Arc::new(Observability::default()); let mgr = Arc::new(TenantPlacementManager::new(obs)); mgr.set_hosted_tenants(vec!["tenant-a".to_string()]).await; let guard = mgr.begin_command(&TenantId::new("tenant-a")).await.unwrap(); mgr.drain_tenant(&TenantId::new("tenant-a")).await; let err = match mgr.begin_command(&TenantId::new("tenant-a")).await { Ok(_) => panic!("expected error"), Err(e) => e, }; assert!(matches!(err, AggregateError::TenantDraining { .. })); drop(guard); mgr.wait_drained(&TenantId::new("tenant-a")).await; let err = match mgr.begin_command(&TenantId::new("tenant-a")).await { Ok(_) => panic!("expected error"), Err(e) => e, }; assert!(matches!(err, AggregateError::TenantDraining { .. })); } }