268 lines
8.0 KiB
Rust
268 lines
8.0 KiB
Rust
use crate::observability::Observability;
|
|
use crate::types::{AggregateError, TenantId};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::{HashMap, HashSet};
|
|
use std::sync::Arc;
|
|
use tokio::sync::RwLock;
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
pub struct TenantStatus {
|
|
pub tenant_id: TenantId,
|
|
pub hosted: bool,
|
|
pub accepting: bool,
|
|
pub draining: bool,
|
|
pub in_flight: u64,
|
|
}
|
|
|
|
pub struct TenantPlacementManager {
|
|
hosted: RwLock<HashSet<String>>,
|
|
draining: RwLock<HashSet<String>>,
|
|
in_flight: RwLock<HashMap<String, u64>>,
|
|
observability: Arc<Observability>,
|
|
}
|
|
|
|
impl TenantPlacementManager {
|
|
pub fn new(observability: Arc<Observability>) -> Self {
|
|
Self {
|
|
hosted: RwLock::new(HashSet::new()),
|
|
draining: RwLock::new(HashSet::new()),
|
|
in_flight: RwLock::new(HashMap::new()),
|
|
observability,
|
|
}
|
|
}
|
|
|
|
pub async fn set_hosted_tenants(&self, tenant_ids: impl IntoIterator<Item = String>) {
|
|
let mut hosted = self.hosted.write().await;
|
|
hosted.clear();
|
|
hosted.extend(tenant_ids);
|
|
}
|
|
|
|
pub async fn apply_placement_map(&self, shard_id: &str, placement: &HashMap<String, String>) {
|
|
let tenants = placement
|
|
.iter()
|
|
.filter_map(|(tenant_id, assigned)| {
|
|
if assigned == shard_id {
|
|
Some(tenant_id.clone())
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
self.set_hosted_tenants(tenants).await;
|
|
}
|
|
|
|
/// Reports whether `tenant_id` is hosted here.
///
/// An empty tenant id is always reported as hosted, without consulting the
/// hosted set.
pub async fn is_hosted(&self, tenant_id: &TenantId) -> bool {
    let id = tenant_id.as_str();
    // Short-circuit keeps the lock untouched for the empty id.
    id.is_empty() || self.hosted.read().await.contains(id)
}
|
|
|
|
/// Reports whether `tenant_id` is currently marked as draining.
pub async fn is_draining(&self, tenant_id: &TenantId) -> bool {
    let draining = self.draining.read().await;
    draining.contains(tenant_id.as_str())
}
|
|
|
|
pub async fn begin_command(
|
|
self: &Arc<Self>,
|
|
tenant_id: &TenantId,
|
|
) -> Result<TenantCommandGuard, AggregateError> {
|
|
if !self.is_hosted(tenant_id).await {
|
|
return Err(AggregateError::TenantNotHosted {
|
|
tenant_id: tenant_id.clone(),
|
|
});
|
|
}
|
|
|
|
if self.is_draining(tenant_id).await {
|
|
return Err(AggregateError::TenantDraining {
|
|
tenant_id: tenant_id.clone(),
|
|
});
|
|
}
|
|
|
|
let mut map = self.in_flight.write().await;
|
|
let counter = map.entry(tenant_id.as_str().to_string()).or_insert(0);
|
|
*counter += 1;
|
|
let value = *counter;
|
|
drop(map);
|
|
|
|
self.observability
|
|
.metrics()
|
|
.set_in_flight(tenant_id.as_str(), value);
|
|
|
|
Ok(TenantCommandGuard {
|
|
tenant_id: tenant_id.clone(),
|
|
manager: self.clone(),
|
|
})
|
|
}
|
|
|
|
/// Marks `tenant_id` as draining so that `begin_command` rejects it.
///
/// An empty tenant id is ignored. Commands already in flight are not
/// affected.
pub async fn drain_tenant(&self, tenant_id: &TenantId) {
    let id = tenant_id.as_str();
    if !id.is_empty() {
        self.draining.write().await.insert(id.to_string());
    }
}
|
|
|
|
/// Clears the draining mark for `tenant_id`; a no-op if it was not draining.
pub async fn undrain_tenant(&self, tenant_id: &TenantId) {
    self.draining.write().await.remove(tenant_id.as_str());
}
|
|
|
|
pub async fn wait_drained(&self, tenant_id: &TenantId) {
|
|
loop {
|
|
let in_flight = self
|
|
.in_flight
|
|
.read()
|
|
.await
|
|
.get(tenant_id.as_str())
|
|
.copied()
|
|
.unwrap_or(0);
|
|
if in_flight == 0 {
|
|
break;
|
|
}
|
|
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
|
}
|
|
}
|
|
|
|
pub async fn wait_drained_with_timeout(
|
|
&self,
|
|
tenant_id: &TenantId,
|
|
timeout: std::time::Duration,
|
|
) -> bool {
|
|
let deadline = tokio::time::Instant::now() + timeout;
|
|
loop {
|
|
let in_flight = self
|
|
.in_flight
|
|
.read()
|
|
.await
|
|
.get(tenant_id.as_str())
|
|
.copied()
|
|
.unwrap_or(0);
|
|
if in_flight == 0 {
|
|
return true;
|
|
}
|
|
if tokio::time::Instant::now() >= deadline {
|
|
return false;
|
|
}
|
|
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
|
}
|
|
}
|
|
|
|
pub async fn tenant_status(&self, tenant_id: &TenantId) -> TenantStatus {
|
|
let hosted = self.is_hosted(tenant_id).await;
|
|
let draining = self.is_draining(tenant_id).await;
|
|
let in_flight = self
|
|
.in_flight
|
|
.read()
|
|
.await
|
|
.get(tenant_id.as_str())
|
|
.copied()
|
|
.unwrap_or(0);
|
|
TenantStatus {
|
|
tenant_id: tenant_id.clone(),
|
|
hosted,
|
|
accepting: hosted && !draining,
|
|
draining,
|
|
in_flight,
|
|
}
|
|
}
|
|
|
|
pub async fn hosted_tenants(&self) -> Vec<TenantId> {
|
|
let hosted = self.hosted.read().await;
|
|
hosted.iter().map(TenantId::new).collect()
|
|
}
|
|
|
|
pub async fn all_statuses(&self) -> Vec<TenantStatus> {
|
|
let hosted = self.hosted.read().await.clone();
|
|
let draining = self.draining.read().await.clone();
|
|
let in_flight = self.in_flight.read().await.clone();
|
|
|
|
hosted
|
|
.into_iter()
|
|
.map(|id| {
|
|
let tenant_id = TenantId::new(id.clone());
|
|
let d = draining.contains(&id);
|
|
let f = in_flight.get(&id).copied().unwrap_or(0);
|
|
TenantStatus {
|
|
tenant_id,
|
|
hosted: true,
|
|
accepting: !d,
|
|
draining: d,
|
|
in_flight: f,
|
|
}
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
/// Records that one command for `tenant_id` has finished.
///
/// Decrements the tenant's in-flight count, never going below zero, and
/// publishes the new value through the metrics sink. When the count reaches
/// zero the counter entry is removed, and no entry is created for a tenant
/// that has none. Every reader treats a missing entry as zero, so this only
/// stops the map from accumulating dead entries for tenants that are no
/// longer active.
async fn finish_command(&self, tenant_id: &TenantId) {
    let value = {
        let mut counts = self.in_flight.write().await;

        match counts.entry(tenant_id.as_str().to_string()) {
            std::collections::hash_map::Entry::Occupied(mut slot) => {
                if *slot.get() > 1 {
                    *slot.get_mut() -= 1;
                    *slot.get()
                } else {
                    // Count was 0 or 1: the tenant is idle, drop the entry.
                    slot.remove();
                    0
                }
            }
            std::collections::hash_map::Entry::Vacant(_) => 0,
        }
    };

    // Publish outside the lock so the metrics call never extends the
    // critical section.
    self.observability
        .metrics()
        .set_in_flight(tenant_id.as_str(), value);
}
|
|
}
|
|
|
|
pub struct TenantCommandGuard {
|
|
tenant_id: TenantId,
|
|
manager: Arc<TenantPlacementManager>,
|
|
}
|
|
|
|
impl Drop for TenantCommandGuard {
|
|
fn drop(&mut self) {
|
|
let tenant_id = self.tenant_id.clone();
|
|
let manager = self.manager.clone();
|
|
tokio::spawn(async move {
|
|
manager.finish_command(&tenant_id).await;
|
|
});
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::observability::Observability;

    /// Builds a manager that hosts exactly the given tenants.
    async fn manager_hosting(ids: &[&str]) -> Arc<TenantPlacementManager> {
        let observability = Arc::new(Observability::default());
        let manager = Arc::new(TenantPlacementManager::new(observability));
        manager
            .set_hosted_tenants(ids.iter().map(|id| (*id).to_owned()))
            .await;
        manager
    }

    /// A tenant outside the hosted set must be rejected as not hosted.
    #[tokio::test]
    async fn placement_rejects_unhosted_tenant() {
        let manager = manager_hosting(&["tenant-a"]).await;
        let stranger = TenantId::new("tenant-b");

        let rejection = manager
            .begin_command(&stranger)
            .await
            .err()
            .expect("expected error");

        assert!(matches!(rejection, AggregateError::TenantNotHosted { .. }));
    }

    /// Draining rejects new commands while one is in flight, and keeps
    /// rejecting them after the in-flight count has dropped to zero.
    #[tokio::test]
    async fn drain_blocks_new_commands_until_in_flight_zero() {
        let manager = manager_hosting(&["tenant-a"]).await;
        let tenant = TenantId::new("tenant-a");

        let running = manager.begin_command(&tenant).await.unwrap();
        manager.drain_tenant(&tenant).await;

        let while_busy = manager
            .begin_command(&tenant)
            .await
            .err()
            .expect("expected error");
        assert!(matches!(while_busy, AggregateError::TenantDraining { .. }));

        // Finishing the running command lets the drain complete.
        drop(running);
        manager.wait_drained(&tenant).await;

        let after_drain = manager
            .begin_command(&tenant)
            .await
            .err()
            .expect("expected error");
        assert!(matches!(after_drain, AggregateError::TenantDraining { .. }));
    }
}
|