Files
cloudlysis/aggregate/src/placement.rs
Vlad Durnea 1298d9a3df
Some checks failed
ci / rust (push) Failing after 2m34s
ci / ui (push) Failing after 30s
Monorepo consolidation: workspace, shared types, transport plans, docker/swarm assets
2026-03-30 11:40:42 +03:00

268 lines
8.0 KiB
Rust

use crate::observability::Observability;
use crate::types::{AggregateError, TenantId};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Point-in-time view of one tenant's placement state, as produced by
/// `TenantPlacementManager::tenant_status` and
/// `TenantPlacementManager::all_statuses`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TenantStatus {
/// Tenant this status describes.
pub tenant_id: TenantId,
/// Whether the tenant is hosted by the reporting manager.
pub hosted: bool,
/// Whether new commands would be admitted: hosted and not draining.
pub accepting: bool,
/// Whether the tenant is currently marked as draining.
pub draining: bool,
/// Number of commands recorded as started but not yet finished.
pub in_flight: u64,
}
/// Tracks which tenants are hosted here, which of them are draining, and how
/// many commands each tenant has in flight.
///
/// Each piece of state sits behind its own async `RwLock`, so a caller that
/// reads several of them observes each one at a slightly different moment.
pub struct TenantPlacementManager {
/// Ids of the tenants currently hosted.
hosted: RwLock<HashSet<String>>,
/// Ids of the tenants that refuse new commands.
draining: RwLock<HashSet<String>>,
/// Per-tenant count of commands started but not yet finished.
in_flight: RwLock<HashMap<String, u64>>,
/// Handle used to publish the in-flight count whenever it changes.
observability: Arc<Observability>,
}
impl TenantPlacementManager {
/// Creates a manager with no hosted tenants, no draining tenants and no
/// in-flight commands, publishing metrics through `observability`.
pub fn new(observability: Arc<Observability>) -> Self {
    let hosted = RwLock::new(HashSet::new());
    let draining = RwLock::new(HashSet::new());
    let in_flight = RwLock::new(HashMap::new());
    Self {
        hosted,
        draining,
        in_flight,
        observability,
    }
}
/// Replaces the hosted set with exactly the ids yielded by `tenant_ids`.
///
/// The write lock is taken before the iterator is consumed, so readers see
/// either the old set or the complete new one.
pub async fn set_hosted_tenants(&self, tenant_ids: impl IntoIterator<Item = String>) {
    let mut current = self.hosted.write().await;
    *current = tenant_ids.into_iter().collect();
}
/// Hosts exactly the tenants that `placement` assigns to `shard_id`.
///
/// `placement` maps a tenant id to the id of the shard it is assigned to.
/// The hosted set is replaced, not merged, so tenants assigned to any other
/// shard stop being hosted here.
pub async fn apply_placement_map(&self, shard_id: &str, placement: &HashMap<String, String>) {
    let mut ours = Vec::new();
    for (tenant_id, assigned) in placement {
        if assigned == shard_id {
            ours.push(tenant_id.clone());
        }
    }
    self.set_hosted_tenants(ours).await;
}
/// Reports whether `tenant_id` is hosted here.
///
/// An empty tenant id is always reported as hosted, without consulting the
/// hosted set.
pub async fn is_hosted(&self, tenant_id: &TenantId) -> bool {
    let key = tenant_id.as_str();
    key.is_empty() || self.hosted.read().await.contains(key)
}
/// Reports whether `tenant_id` is currently marked as draining.
pub async fn is_draining(&self, tenant_id: &TenantId) -> bool {
    let draining = self.draining.read().await;
    draining.contains(tenant_id.as_str())
}
/// Admits a new command for `tenant_id` and returns a guard that tracks it.
///
/// On success the tenant's in-flight count is incremented and published to
/// metrics; dropping the returned guard schedules the matching decrement.
///
/// # Errors
///
/// * `AggregateError::TenantNotHosted` if the tenant is not hosted here.
/// * `AggregateError::TenantDraining` if the tenant is marked as draining.
pub async fn begin_command(
    self: &Arc<Self>,
    tenant_id: &TenantId,
) -> Result<TenantCommandGuard, AggregateError> {
    if !self.is_hosted(tenant_id).await {
        return Err(AggregateError::TenantNotHosted {
            tenant_id: tenant_id.clone(),
        });
    }

    // Hold the draining read lock until the increment below is recorded.
    // `drain_tenant` needs the write lock, so once it returns every command
    // that passed this check is already counted and `wait_drained` cannot
    // report "drained" while such a command is still about to start.
    let draining = self.draining.read().await;
    if draining.contains(tenant_id.as_str()) {
        return Err(AggregateError::TenantDraining {
            tenant_id: tenant_id.clone(),
        });
    }
    let value = {
        let mut map = self.in_flight.write().await;
        let counter = map.entry(tenant_id.as_str().to_string()).or_insert(0);
        *counter += 1;
        *counter
    };
    drop(draining);

    // Publish after both locks are released so metrics never run under a lock.
    self.observability
        .metrics()
        .set_in_flight(tenant_id.as_str(), value);
    Ok(TenantCommandGuard {
        tenant_id: tenant_id.clone(),
        manager: Arc::clone(self),
    })
}
/// Marks `tenant_id` as draining so that `begin_command` rejects it.
///
/// An empty tenant id is ignored and is never marked as draining.
pub async fn drain_tenant(&self, tenant_id: &TenantId) {
    let key = tenant_id.as_str();
    if !key.is_empty() {
        self.draining.write().await.insert(key.to_string());
    }
}
/// Clears the draining mark for `tenant_id`; a no-op if it was not draining.
pub async fn undrain_tenant(&self, tenant_id: &TenantId) {
    self.draining.write().await.remove(tenant_id.as_str());
}
/// Waits until `tenant_id` has no commands in flight.
///
/// The count is polled every 10 ms; a tenant with no recorded count is
/// treated as having zero in flight. There is no upper bound on the wait.
pub async fn wait_drained(&self, tenant_id: &TenantId) {
    let key = tenant_id.as_str();
    // The read guard is a temporary of the loop condition, so it is released
    // before the sleep in the body.
    while self.in_flight.read().await.get(key).copied().unwrap_or(0) != 0 {
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    }
}
/// Waits until `tenant_id` has no commands in flight, giving up after
/// `timeout`.
///
/// Returns `true` if the tenant drained in time and `false` if the timeout
/// elapsed first. The drain check runs before the deadline is consulted, so
/// an already-drained tenant yields `true` even with a zero timeout.
///
/// Delegating to `tokio::time::timeout` avoids the panic that
/// `Instant::now() + timeout` raises on overflow for very large durations,
/// and returns at the deadline instead of up to one poll interval later.
pub async fn wait_drained_with_timeout(
    &self,
    tenant_id: &TenantId,
    timeout: std::time::Duration,
) -> bool {
    // `wait_drained` keeps no state between polls, so cancelling it when the
    // timeout fires is safe.
    tokio::time::timeout(timeout, self.wait_drained(tenant_id))
        .await
        .is_ok()
}
/// Builds the current status of a single tenant.
///
/// Hosted, draining and in-flight state are read one after another under
/// separate locks, so the result is not an atomic snapshot.
pub async fn tenant_status(&self, tenant_id: &TenantId) -> TenantStatus {
    let hosted = self.is_hosted(tenant_id).await;
    let draining = self.is_draining(tenant_id).await;
    let in_flight = match self.in_flight.read().await.get(tenant_id.as_str()) {
        Some(count) => *count,
        None => 0,
    };
    let accepting = hosted && !draining;
    TenantStatus {
        tenant_id: tenant_id.clone(),
        hosted,
        accepting,
        draining,
        in_flight,
    }
}
/// Returns the ids of all hosted tenants, in the hosted set's iteration order.
pub async fn hosted_tenants(&self) -> Vec<TenantId> {
    let hosted = self.hosted.read().await;
    let mut tenants = Vec::with_capacity(hosted.len());
    for id in hosted.iter() {
        tenants.push(TenantId::new(id));
    }
    tenants
}
/// Returns one status per hosted tenant.
///
/// Each of the three state sets is copied under its own lock and the locks
/// are released before the statuses are assembled. Tenants that are not
/// hosted are not listed, even if they are draining or have in-flight
/// commands.
pub async fn all_statuses(&self) -> Vec<TenantStatus> {
    let hosted_ids = self.hosted.read().await.clone();
    let draining_ids = self.draining.read().await.clone();
    let counts = self.in_flight.read().await.clone();

    let mut statuses = Vec::with_capacity(hosted_ids.len());
    for id in hosted_ids {
        let draining = draining_ids.contains(&id);
        let in_flight = counts.get(&id).copied().unwrap_or(0);
        statuses.push(TenantStatus {
            tenant_id: TenantId::new(id),
            hosted: true,
            accepting: !draining,
            draining,
            in_flight,
        });
    }
    statuses
}
/// Records that one command for `tenant_id` has finished and publishes the
/// new in-flight count to metrics.
///
/// The count saturates at zero. A tenant whose count reaches zero is removed
/// from the map instead of keeping a zero entry; readers treat a missing
/// entry as zero, so this only stops the map from growing with every tenant
/// ever served. Looking up by `&str` also avoids allocating a key on every
/// completion.
async fn finish_command(&self, tenant_id: &TenantId) {
    let key = tenant_id.as_str();
    let remaining = {
        let mut counts = self.in_flight.write().await;
        let remaining = counts.get(key).copied().unwrap_or(0).saturating_sub(1);
        if remaining == 0 {
            counts.remove(key);
        } else if let Some(count) = counts.get_mut(key) {
            *count = remaining;
        }
        remaining
    };
    // The write lock is released before metrics are touched.
    self.observability.metrics().set_in_flight(key, remaining);
}
}
/// Guard returned by `TenantPlacementManager::begin_command`.
///
/// While the guard is alive the command counts as in flight for its tenant;
/// dropping it schedules the matching decrement on the manager.
#[must_use = "dropping the guard immediately marks the command as finished"]
pub struct TenantCommandGuard {
    /// Tenant whose in-flight count this guard holds open.
    tenant_id: TenantId,
    /// Manager that is notified when the guard is dropped.
    manager: Arc<TenantPlacementManager>,
}
impl Drop for TenantCommandGuard {
fn drop(&mut self) {
let tenant_id = self.tenant_id.clone();
let manager = self.manager.clone();
tokio::spawn(async move {
manager.finish_command(&tenant_id).await;
});
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::observability::Observability;
/// A command for a tenant outside the hosted set is rejected with
/// `TenantNotHosted`.
#[tokio::test]
async fn placement_rejects_unhosted_tenant() {
    let manager = Arc::new(TenantPlacementManager::new(Arc::new(
        Observability::default(),
    )));
    manager
        .set_hosted_tenants(vec!["tenant-a".to_string()])
        .await;

    // `expect_err` is unavailable because the guard type is not `Debug`.
    let outcome = manager.begin_command(&TenantId::new("tenant-b")).await;
    let err = outcome.err().expect("expected error");
    assert!(matches!(err, AggregateError::TenantNotHosted { .. }));
}
/// Draining rejects new commands both while an earlier command is still in
/// flight and after the tenant has fully drained.
#[tokio::test]
async fn drain_blocks_new_commands_until_in_flight_zero() {
    let manager = Arc::new(TenantPlacementManager::new(Arc::new(
        Observability::default(),
    )));
    manager
        .set_hosted_tenants(vec!["tenant-a".to_string()])
        .await;

    // One command is admitted before the drain starts.
    let in_progress = manager
        .begin_command(&TenantId::new("tenant-a"))
        .await
        .unwrap();
    manager.drain_tenant(&TenantId::new("tenant-a")).await;

    // `expect_err` is unavailable because the guard type is not `Debug`.
    let during_drain = manager.begin_command(&TenantId::new("tenant-a")).await;
    let err = during_drain.err().expect("expected error");
    assert!(matches!(err, AggregateError::TenantDraining { .. }));

    // Finishing the admitted command lets the drain complete.
    drop(in_progress);
    manager.wait_drained(&TenantId::new("tenant-a")).await;

    // The tenant stays closed until it is explicitly undrained.
    let after_drain = manager.begin_command(&TenantId::new("tenant-a")).await;
    let err = after_drain.err().expect("expected error");
    assert!(matches!(err, AggregateError::TenantDraining { .. }));
}
}