feat(billing): implement tenant subscription entitlements system (milestones 0-6)
Some checks failed
ci / ui (push) Failing after 28s
ci / rust (push) Failing after 2m40s
images / build-and-push (push) Failing after 19s

This commit is contained in:
2026-03-30 18:41:23 +03:00
parent 5992044b7e
commit 2595e7f1c5
63 changed files with 8448 additions and 321 deletions

View File

@@ -1,7 +1,9 @@
use crate::{
AppState, RequestIds,
auth::{Principal, has_permission},
fleet,
config_registry::{ConfigDomain, ConfigRegistryError},
config_schemas::RoutingConfig,
drift, fleet,
job_engine::{JobEngine, StartJobError},
jobs::{Job, JobStatus, JobStep},
placement::{PlacementResponse, ServiceKind},
@@ -15,7 +17,9 @@ use axum::{
routing::{get, post},
};
use serde::Deserialize;
use sha2::Digest;
use std::time::{SystemTime, UNIX_EPOCH};
use url::Url;
use uuid::Uuid;
const HEADER_IDEMPOTENCY_KEY: &str = "idempotency-key";
@@ -25,21 +29,125 @@ pub fn admin_router() -> Router<AppState> {
Router::new()
.route("/whoami", get(whoami))
.route("/platform/info", get(platform_info))
.route("/platform/drift", get(platform_drift))
.route("/fleet/snapshot", get(fleet_snapshot))
.route("/tenants", get(list_tenants))
.route("/placement/{kind}", get(get_placement))
.route("/config", get(list_config))
.route("/config/{domain}", get(get_config))
.route("/config/{domain}/history", get(get_config_history))
.route("/jobs/platform/verify", post(start_platform_verify))
.route("/jobs/config/validate", post(start_config_validate))
.route("/jobs/config/apply", post(start_config_apply))
.route("/jobs/config/rollback", post(start_config_rollback))
.route("/tenants/echo", get(tenant_echo))
.route(
"/tenants/{tenant_id}/billing",
get(crate::billing::get_billing),
)
.route(
"/tenants/{tenant_id}/billing/checkout",
post(crate::billing::checkout),
)
.route(
"/tenants/{tenant_id}/billing/portal",
post(crate::billing::portal),
)
.route("/jobs/echo", post(create_echo_job))
.route("/jobs/{job_id}", get(get_job))
.route("/jobs/{job_id}/cancel", post(cancel_job))
.route("/jobs/tenant/drain", post(start_tenant_drain))
.route("/jobs/tenant/migrate", post(start_tenant_migrate))
.route("/plan/tenant/migrate", post(plan_tenant_migrate))
.route("/plan/config/apply", post(plan_config_apply))
.route("/audit", get(list_audit))
.route("/swarm/services", get(list_swarm_services))
.route("/swarm/services/{name}/tasks", get(list_swarm_tasks))
}
/// Request body for `POST /jobs/platform/verify`.
#[derive(Debug, Deserialize)]
struct PlatformVerifyRequest {
    // Free-form operator-supplied reason; passed through to the job engine.
    reason: String,
}
/// `POST /jobs/platform/verify` — start an async platform verification job.
///
/// Requires the `control:write` permission and a non-empty `idempotency-key`
/// header. Responds `200 {"job_id": ...}` on success, `400` for a missing or
/// empty idempotency key, `403` without permission, and `409` when the job
/// engine reports a lock conflict.
async fn start_platform_verify(
    State(state): State<AppState>,
    headers: HeaderMap,
    Extension(principal): Extension<Principal>,
    Json(body): Json<PlatformVerifyRequest>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    // One Option chain replaces the previous `ok_or(Result)` + match detour.
    let key = match headers
        .get(HEADER_IDEMPOTENCY_KEY)
        .and_then(|v| v.to_str().ok())
        .filter(|k| !k.is_empty())
    {
        Some(k) => k,
        None => return StatusCode::BAD_REQUEST.into_response(),
    };
    let engine = JobEngine::new(
        state.jobs.clone(),
        state.audit.clone(),
        state.tenant_locks.clone(),
        state.config_locks.clone(),
    );
    let job_id = match engine.start_platform_verify(state.clone(), &principal, body.reason, key) {
        Ok(id) => id,
        Err(StartJobError::TenantLocked) => return StatusCode::CONFLICT.into_response(),
    };
    (
        StatusCode::OK,
        Json(serde_json::json!({ "job_id": job_id })),
    )
        .into_response()
}
/// `GET /config/{domain}/history` — up to the last 50 revisions of a domain.
///
/// Entries whose stored payload is not valid JSON are silently skipped.
async fn get_config_history(
    State(state): State<AppState>,
    Path(domain): Path<String>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    let domain = match domain.as_str() {
        "routing" => ConfigDomain::Routing,
        "placement" => ConfigDomain::Placement,
        _ => return StatusCode::NOT_FOUND.into_response(),
    };
    let Some(source) = state.config.source(domain) else {
        return StatusCode::NOT_FOUND.into_response();
    };
    let rows = match source.history_bytes(50).await {
        Ok(items) => {
            let mut rows = Vec::with_capacity(items.len());
            for (rev, bytes) in items {
                // Skip entries that fail to parse instead of failing the call.
                if let Ok(v) = serde_json::from_slice::<serde_json::Value>(&bytes) {
                    rows.push(serde_json::json!({
                        "revision": rev,
                        "sha256": sha256_hex(&bytes),
                        "value": v
                    }));
                }
            }
            rows
        }
        Err(ConfigRegistryError::Source(_)) => return StatusCode::BAD_GATEWAY.into_response(),
        Err(_) => return StatusCode::NOT_IMPLEMENTED.into_response(),
    };
    (
        StatusCode::OK,
        Json(serde_json::json!({ "domain": domain.as_str(), "items": rows })),
    )
        .into_response()
}
async fn whoami(Extension(principal): Extension<Principal>) -> impl IntoResponse {
if !has_permission(&principal, "control:read") {
return StatusCode::FORBIDDEN.into_response();
@@ -70,6 +178,18 @@ async fn platform_info(Extension(principal): Extension<Principal>) -> impl IntoR
.into_response()
}
/// `GET /platform/drift` — compute and return the platform drift report.
async fn platform_drift(
    State(state): State<AppState>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    // Read-only endpoint, gated on control:read like the other GET routes.
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    let report = drift::compute(&state).await;
    (StatusCode::OK, Json(report)).into_response()
}
async fn fleet_snapshot(
State(state): State<AppState>,
Extension(principal): Extension<Principal>,
@@ -109,6 +229,434 @@ async fn get_placement(
(StatusCode::OK, Json(resp)).into_response()
}
/// `GET /config` — list the config domains that have a configured source.
async fn list_config(
    State(state): State<AppState>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    // Only advertise domains whose backing source is actually wired up.
    let mut domains: Vec<&'static str> = Vec::new();
    for d in [ConfigDomain::Routing, ConfigDomain::Placement] {
        if state.config.source(d).is_some() {
            domains.push(d.as_str());
        }
    }
    (
        StatusCode::OK,
        Json(serde_json::json!({ "domains": domains })),
    )
        .into_response()
}
/// `GET /config/{domain}` — return one config domain's current value plus its
/// revision, SHA-256, and source metadata.
///
/// Responses: `404` for an unknown or unconfigured domain, `502` when the
/// backing source fails, `400` when the stored bytes are not valid JSON,
/// `403` without `control:read`.
async fn get_config(
    State(state): State<AppState>,
    Path(domain): Path<String>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    // An unknown URL segment is a 404 (path does not exist), not a 400.
    let domain = match domain.as_str() {
        "routing" => ConfigDomain::Routing,
        "placement" => ConfigDomain::Placement,
        _ => return StatusCode::NOT_FOUND.into_response(),
    };
    let Some(source) = state.config.source(domain) else {
        return StatusCode::NOT_FOUND.into_response();
    };
    let loaded = source.load_bytes().await;
    // Each registry error variant maps to a distinct HTTP status.
    let (bytes, revision) = match loaded {
        Ok(x) => x,
        Err(ConfigRegistryError::Source(_)) => return StatusCode::BAD_GATEWAY.into_response(),
        Err(ConfigRegistryError::Decode(_)) => return StatusCode::BAD_REQUEST.into_response(),
        Err(ConfigRegistryError::NotConfigured) => return StatusCode::NOT_FOUND.into_response(),
    };
    // `bytes` is None when the source exists but holds no value yet; surface
    // that as JSON null rather than an error.
    let json_value = match bytes {
        Some(ref b) => match serde_json::from_slice::<serde_json::Value>(b) {
            Ok(v) => v,
            Err(e) => {
                return (
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({ "error": format!("invalid json: {e}") })),
                )
                    .into_response();
            }
        },
        None => serde_json::Value::Null,
    };
    let sha256 = bytes.as_deref().map(sha256_hex);
    (
        StatusCode::OK,
        Json(serde_json::json!({
            "domain": domain.as_str(),
            "revision": revision,
            "sha256": sha256,
            "source": source.info(),
            "value": json_value,
        })),
    )
        .into_response()
}
/// Request body for `POST /jobs/config/apply`.
#[derive(Debug, Deserialize)]
struct ConfigApplyRequest {
    // Domain name string; resolved via `parse_domain`.
    domain: String,
    // Optional optimistic-concurrency guard against the current revision.
    expected_revision: Option<u64>,
    // Operator-supplied reason, passed through to the job engine.
    reason: String,
    // The proposed config value, applied as-is.
    value: serde_json::Value,
}
/// Request body for `POST /jobs/config/validate`.
#[derive(Debug, Deserialize)]
struct ConfigValidateRequest {
    domain: String,
    reason: String,
    value: serde_json::Value,
}
/// Request body for `POST /jobs/config/rollback`.
#[derive(Debug, Deserialize)]
struct ConfigRollbackRequest {
    domain: String,
    reason: String,
}
/// Maps a request's domain string to its `ConfigDomain`; `None` if unknown.
fn parse_domain(domain: &str) -> Option<ConfigDomain> {
    if domain == "routing" {
        Some(ConfigDomain::Routing)
    } else if domain == "placement" {
        Some(ConfigDomain::Placement)
    } else {
        None
    }
}
/// `POST /jobs/config/validate` — start an async validation job for a
/// proposed config value.
///
/// Requires `control:write` and a non-empty `idempotency-key` header.
/// `400` for a bad key or unknown domain, `409` on a lock conflict.
async fn start_config_validate(
    State(state): State<AppState>,
    headers: HeaderMap,
    Extension(principal): Extension<Principal>,
    Json(body): Json<ConfigValidateRequest>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    // One Option chain replaces the previous `ok_or(Result)` + match detour.
    let key = match headers
        .get(HEADER_IDEMPOTENCY_KEY)
        .and_then(|v| v.to_str().ok())
        .filter(|k| !k.is_empty())
    {
        Some(k) => k,
        None => return StatusCode::BAD_REQUEST.into_response(),
    };
    let Some(domain) = parse_domain(body.domain.as_str()) else {
        return StatusCode::BAD_REQUEST.into_response();
    };
    let engine = JobEngine::new(
        state.jobs.clone(),
        state.audit.clone(),
        state.tenant_locks.clone(),
        state.config_locks.clone(),
    );
    let job_id = match engine.start_config_validate(
        state.clone(),
        &principal,
        domain,
        body.reason,
        body.value,
        key,
    ) {
        Ok(id) => id,
        Err(StartJobError::TenantLocked) => return StatusCode::CONFLICT.into_response(),
    };
    (
        StatusCode::OK,
        Json(serde_json::json!({ "job_id": job_id })),
    )
        .into_response()
}
/// `POST /jobs/config/apply` — start an async job that applies a new config
/// value, optionally guarded by `expected_revision`.
///
/// Requires `control:write` and a non-empty `idempotency-key` header.
/// `400` for a bad key or unknown domain, `409` on a lock conflict.
async fn start_config_apply(
    State(state): State<AppState>,
    headers: HeaderMap,
    Extension(principal): Extension<Principal>,
    Json(body): Json<ConfigApplyRequest>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    // One Option chain replaces the previous `ok_or(Result)` + match detour.
    let key = match headers
        .get(HEADER_IDEMPOTENCY_KEY)
        .and_then(|v| v.to_str().ok())
        .filter(|k| !k.is_empty())
    {
        Some(k) => k,
        None => return StatusCode::BAD_REQUEST.into_response(),
    };
    let Some(domain) = parse_domain(body.domain.as_str()) else {
        return StatusCode::BAD_REQUEST.into_response();
    };
    let engine = JobEngine::new(
        state.jobs.clone(),
        state.audit.clone(),
        state.tenant_locks.clone(),
        state.config_locks.clone(),
    );
    let job_id = match engine.start_config_apply(
        state.clone(),
        &principal,
        domain,
        body.reason,
        body.expected_revision,
        body.value,
        key,
    ) {
        Ok(id) => id,
        Err(StartJobError::TenantLocked) => return StatusCode::CONFLICT.into_response(),
    };
    (
        StatusCode::OK,
        Json(serde_json::json!({ "job_id": job_id })),
    )
        .into_response()
}
/// `POST /jobs/config/rollback` — start an async job rolling a domain back.
///
/// Requires `control:write` and a non-empty `idempotency-key` header.
/// `400` for a bad key or unknown domain, `409` on a lock conflict.
async fn start_config_rollback(
    State(state): State<AppState>,
    headers: HeaderMap,
    Extension(principal): Extension<Principal>,
    Json(body): Json<ConfigRollbackRequest>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    // One Option chain replaces the previous `ok_or(Result)` + match detour.
    let key = match headers
        .get(HEADER_IDEMPOTENCY_KEY)
        .and_then(|v| v.to_str().ok())
        .filter(|k| !k.is_empty())
    {
        Some(k) => k,
        None => return StatusCode::BAD_REQUEST.into_response(),
    };
    let Some(domain) = parse_domain(body.domain.as_str()) else {
        return StatusCode::BAD_REQUEST.into_response();
    };
    let engine = JobEngine::new(
        state.jobs.clone(),
        state.audit.clone(),
        state.tenant_locks.clone(),
        state.config_locks.clone(),
    );
    let job_id =
        match engine.start_config_rollback(state.clone(), &principal, domain, body.reason, key) {
            Ok(id) => id,
            Err(StartJobError::TenantLocked) => return StatusCode::CONFLICT.into_response(),
        };
    (
        StatusCode::OK,
        Json(serde_json::json!({ "job_id": job_id })),
    )
        .into_response()
}
/// Request body for `POST /plan/config/apply` (dry-run; nothing persisted).
#[derive(Debug, Deserialize)]
struct ConfigPlanApplyRequest {
    domain: String,
    // The proposed config value to validate and diff against current state.
    value: serde_json::Value,
}
/// `POST /plan/config/apply` — dry-run a config change: validate the proposed
/// value (schema + semantics) and report a before/after diff without
/// persisting anything.
async fn plan_config_apply(
    State(state): State<AppState>,
    Extension(principal): Extension<Principal>,
    Json(body): Json<ConfigPlanApplyRequest>,
) -> impl IntoResponse {
    // Gated on `control:write`, matching the actual apply endpoint.
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    let domain = match body.domain.as_str() {
        "routing" => ConfigDomain::Routing,
        "placement" => ConfigDomain::Placement,
        _ => return StatusCode::BAD_REQUEST.into_response(),
    };
    let Some(source) = state.config.source(domain) else {
        return StatusCode::NOT_FOUND.into_response();
    };
    // Validate proposed config (schema + semantics).
    let validate_res: Result<(), String> = match domain {
        ConfigDomain::Routing => {
            let cfg = match serde_json::from_value::<RoutingConfig>(body.value.clone()) {
                Ok(v) => v,
                Err(e) => {
                    return (
                        StatusCode::BAD_REQUEST,
                        Json(serde_json::json!({ "error": e.to_string() })),
                    )
                        .into_response();
                }
            };
            validate_routing_semantics(&cfg)
        }
        ConfigDomain::Placement => {
            let cfg =
                match serde_json::from_value::<crate::placement::PlacementFile>(body.value.clone())
                {
                    Ok(v) => v,
                    Err(e) => {
                        return (
                            StatusCode::BAD_REQUEST,
                            Json(serde_json::json!({ "error": e.to_string() })),
                        )
                            .into_response();
                    }
                };
            validate_placement_semantics(&cfg)
        }
    };
    if let Err(e) = validate_res {
        return (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({ "error": e })),
        )
            .into_response();
    }
    // NOTE(review): unlike get_config, every load error collapses to 502 here,
    // including decode/not-configured cases — confirm that is intended.
    let (cur_bytes, cur_rev) = match source.load_bytes().await {
        Ok(x) => x,
        Err(_) => return StatusCode::BAD_GATEWAY.into_response(),
    };
    // A missing or unparseable current value diffs against JSON null.
    let cur_value = cur_bytes
        .as_deref()
        .and_then(|b| serde_json::from_slice::<serde_json::Value>(b).ok())
        .unwrap_or(serde_json::Value::Null);
    // The diff payload is pretty-printed JSON of both sides.
    let before = serde_json::to_string_pretty(&cur_value).unwrap_or_default();
    let after = serde_json::to_string_pretty(&body.value).unwrap_or_default();
    let changed = cur_value != body.value;
    // Static mapping of which services consume each domain.
    let impacted_services: Vec<&'static str> = match domain {
        ConfigDomain::Routing => vec!["gateway"],
        ConfigDomain::Placement => vec!["gateway", "control-api"],
    };
    (
        StatusCode::OK,
        Json(serde_json::json!({
            "domain": domain.as_str(),
            "current_revision": cur_rev,
            "changed": changed,
            "impacted_services": impacted_services,
            "diff": {
                "before": before,
                "after": after,
            }
        })),
    )
        .into_response()
}
/// Hex-encoded SHA-256 digest of `bytes`.
fn sha256_hex(bytes: &[u8]) -> String {
    let digest = sha2::Sha256::digest(bytes);
    hex::encode(digest)
}
/// Semantic validation for a routing config, beyond serde schema checks.
///
/// Verifies that every shard has at least one endpoint, every endpoint is a
/// well-formed http(s) URL with a host, and every tenant placement references
/// an existing shard. Returns the first violation as an operator-facing
/// message.
fn validate_routing_semantics(cfg: &RoutingConfig) -> Result<(), String> {
    // The three shard maps share identical endpoint rules.
    let shard_maps = [
        ("aggregate_shards", &cfg.aggregate_shards),
        ("projection_shards", &cfg.projection_shards),
        ("runner_shards", &cfg.runner_shards),
    ];
    for (name, map) in shard_maps {
        for (shard_id, endpoints) in map {
            if endpoints.is_empty() {
                return Err(format!("{name}[{shard_id}] has no endpoints"));
            }
            for ep in endpoints {
                let u = Url::parse(ep)
                    .map_err(|e| format!("{name}[{shard_id}] invalid endpoint {ep:?}: {e}"))?;
                if u.scheme() != "http" && u.scheme() != "https" {
                    return Err(format!(
                        "{name}[{shard_id}] endpoint {ep:?} must be http(s)"
                    ));
                }
                if u.host_str().is_none() {
                    return Err(format!(
                        "{name}[{shard_id}] endpoint {ep:?} must include host"
                    ));
                }
            }
        }
    }
    // Each placement map must only reference shard ids defined in its
    // corresponding shard map.
    let placements = [
        (
            "aggregate_placement",
            &cfg.aggregate_placement,
            &cfg.aggregate_shards,
        ),
        (
            "projection_placement",
            &cfg.projection_placement,
            &cfg.projection_shards,
        ),
        (
            "runner_placement",
            &cfg.runner_placement,
            &cfg.runner_shards,
        ),
    ];
    for (pname, pmap, shards) in placements {
        for (tenant, shard_id) in pmap {
            if shard_id.trim().is_empty() {
                return Err(format!("{pname}[{tenant}] shard_id is empty"));
            }
            if !shards.contains_key(shard_id) {
                return Err(format!(
                    "{pname}[{tenant}] references missing shard_id {shard_id:?}"
                ));
            }
        }
    }
    Ok(())
}
/// Semantic validation for a placement file: every declared placement must
/// have at least one non-blank target. Absent placement kinds are fine.
fn validate_placement_semantics(cfg: &crate::placement::PlacementFile) -> Result<(), String> {
    let kinds = [
        ("aggregate_placement", cfg.aggregate_placement.as_ref()),
        ("projection_placement", cfg.projection_placement.as_ref()),
        ("runner_placement", cfg.runner_placement.as_ref()),
    ];
    // Only validate the kinds that are actually present.
    for (kind, group) in kinds.into_iter().filter_map(|(n, g)| g.map(|g| (n, g))) {
        for p in &group.placements {
            if p.targets.is_empty() {
                return Err(format!("{kind} tenant {} has no targets", p.tenant_id));
            }
            let has_blank = p.targets.iter().any(|t| t.trim().is_empty());
            if has_blank {
                return Err(format!("{kind} tenant {} has empty target", p.tenant_id));
            }
        }
    }
    Ok(())
}
async fn list_tenants(
State(state): State<AppState>,
Extension(principal): Extension<Principal>,
@@ -256,6 +804,7 @@ async fn start_tenant_drain(
state.jobs.clone(),
state.audit.clone(),
state.tenant_locks.clone(),
state.config_locks.clone(),
);
let job_id = match engine.start_tenant_drain(
state.clone(),
@@ -298,6 +847,7 @@ async fn start_tenant_migrate(
state.jobs.clone(),
state.audit.clone(),
state.tenant_locks.clone(),
state.config_locks.clone(),
);
let job_id = match engine.start_tenant_migrate(
state.clone(),

904
control/api/src/billing.rs Normal file
View File

@@ -0,0 +1,904 @@
use crate::{
AppState,
auth::{Principal, has_permission},
};
use async_trait::async_trait;
use axum::{
Json,
extract::{Extension, Path, State},
http::{HeaderMap, StatusCode},
response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use std::{
collections::BTreeMap,
fs,
path::PathBuf,
sync::{Arc, RwLock},
time::SystemTime,
};
use thiserror::Error;
use uuid::Uuid;
const HEADER_TENANT_ID: &str = shared::HEADER_X_TENANT_ID;
/// Ensures the tenant id in the `x-tenant-id` header matches the one in the
/// URL path.
///
/// Returns `400` when the header is missing or not a valid UUID, `403` when
/// it names a different tenant than the path.
fn verify_tenant_isolation(headers: &HeaderMap, path_tenant_id: Uuid) -> Result<(), StatusCode> {
    let raw = headers
        .get(HEADER_TENANT_ID)
        .and_then(|v| v.to_str().ok())
        .ok_or(StatusCode::BAD_REQUEST)?;
    let header_tenant_id = Uuid::parse_str(raw).map_err(|_| StatusCode::BAD_REQUEST)?;
    if header_tenant_id == path_tenant_id {
        Ok(())
    } else {
        Err(StatusCode::FORBIDDEN)
    }
}
/// Subscription plan tiers, serialized as snake_case strings.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Plan {
    Free,
    Pro,
    Enterprise,
}
/// Provider-reported subscription lifecycle states (snake_case on the wire).
/// Only `Trialing` and `Active` grant plan entitlements (see
/// `Entitlements::derive`).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SubscriptionStatus {
    Trialing,
    Active,
    PastDue,
    Paused,
    Canceled,
    Incomplete,
}
/// Effective feature limits derived from a tenant's plan and status.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Entitlements {
    pub max_deployments: u32,
    pub max_runners: u32,
    pub s3_docs_enabled: bool,
    // One of "community" / "standard" / "priority" (see `derive`).
    pub support_tier: String,
}
/// Normalized billing events produced by a `BillingProvider` (from a webhook
/// or a reconciliation fetch) and applied to the `BillingStore`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum BillingEvent {
    /// A new subscription was created for the tenant.
    SubscriptionCreated {
        tenant_id: Uuid,
        // Provider event id; used for webhook deduplication.
        event_id: String,
        provider_customer_id: String,
        provider_subscription_id: String,
        status: SubscriptionStatus,
        plan: Plan,
        // Provider-formatted timestamp string; not normalized here.
        current_period_end: String,
        // Event time in ms; stale events are not applied over newer state.
        ts_ms: u64,
    },
    /// Plan/status/period change on an existing subscription.
    SubscriptionUpdated {
        tenant_id: Uuid,
        event_id: String,
        status: SubscriptionStatus,
        plan: Plan,
        current_period_end: String,
        cancel_at_period_end: bool,
        ts_ms: u64,
    },
    /// The subscription was deleted; maps to `Canceled` status.
    SubscriptionDeleted {
        tenant_id: Uuid,
        event_id: String,
        ts_ms: u64,
    },
}
impl BillingEvent {
    /// Tenant this event applies to.
    pub fn tenant_id(&self) -> Uuid {
        match self {
            Self::SubscriptionCreated { tenant_id, .. }
            | Self::SubscriptionUpdated { tenant_id, .. }
            | Self::SubscriptionDeleted { tenant_id, .. } => *tenant_id,
        }
    }
    /// Provider event id, used for deduplication.
    pub fn event_id(&self) -> &str {
        match self {
            Self::SubscriptionCreated { event_id, .. }
            | Self::SubscriptionUpdated { event_id, .. }
            | Self::SubscriptionDeleted { event_id, .. } => event_id,
        }
    }
    /// Event timestamp in milliseconds, used for the monotonicity check.
    pub fn ts_ms(&self) -> u64 {
        match self {
            Self::SubscriptionCreated { ts_ms, .. }
            | Self::SubscriptionUpdated { ts_ms, .. }
            | Self::SubscriptionDeleted { ts_ms, .. } => *ts_ms,
        }
    }
}
impl Entitlements {
    /// Derives effective entitlements from the stored plan and status.
    ///
    /// Only `Trialing` and `Active` subscriptions grant plan entitlements;
    /// every other state — including no subscription at all — collapses to the
    /// minimal community tier.
    pub fn derive(plan: Option<&Plan>, status: Option<&SubscriptionStatus>) -> Self {
        let active = matches!(
            status,
            Some(SubscriptionStatus::Trialing | SubscriptionStatus::Active)
        );
        let (max_deployments, max_runners, s3_docs_enabled, tier) = if !active {
            (1, 1, false, "community")
        } else {
            match plan.unwrap_or(&Plan::Free) {
                Plan::Free => (3, 1, false, "community"),
                Plan::Pro => (10, 5, true, "standard"),
                Plan::Enterprise => (1000, 50, true, "priority"),
            }
        };
        Self {
            max_deployments,
            max_runners,
            s3_docs_enabled,
            support_tier: tier.to_string(),
        }
    }
}
/// Persisted per-tenant billing record.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TenantBillingState {
    // Provider name; "unknown" until a Created event sets the record up.
    pub provider: String,
    pub provider_customer_id: Option<String>,
    pub provider_subscription_id: Option<String>,
    pub provider_checkout_session_id: Option<String>,
    pub status: Option<SubscriptionStatus>,
    pub plan: Option<Plan>,
    pub current_period_end: Option<String>,
    pub cancel_at_period_end: Option<bool>,
    // Bounded list of webhook event ids already applied (deduplication).
    pub processed_webhook_event_ids: Vec<String>,
    // ts_ms of the last applied event; older events are not applied.
    pub updated_at: u64,
}
/// On-disk shape of the billing state file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BillingStateFile {
    // Revision id regenerated on every save.
    pub revision: Option<String>,
    pub tenants: BTreeMap<Uuid, TenantBillingState>,
}
/// API response shape for `GET .../billing`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BillingResponse {
    // False when the tenant has no billing record at all.
    pub configured: bool,
    pub provider: Option<String>,
    pub plan: Option<Plan>,
    pub status: Option<SubscriptionStatus>,
    pub current_period_end: Option<String>,
    pub cancel_at_period_end: Option<bool>,
    pub entitlements: Entitlements,
}
/// File-backed store of per-tenant billing state.
///
/// Cheap to clone (shared `Arc`); every accessor re-checks the backing file's
/// mtime so changes written by external processes are picked up.
#[derive(Clone)]
pub struct BillingStore {
    inner: Arc<RwLock<Inner>>,
}
/// Guarded interior: the backing file path plus an mtime-validated cache.
struct Inner {
    path: PathBuf,
    // mtime observed at the last (re)load; None forces a reload.
    last_modified: Option<SystemTime>,
    // Parsed file contents; None until a successful load.
    cached: Option<BillingStateFile>,
}
impl BillingStore {
    /// Maximum number of processed webhook event ids kept per tenant for
    /// deduplication; the oldest ids are evicted first.
    const MAX_PROCESSED_EVENT_IDS: usize = 50;

    /// Creates a store backed by the JSON state file at `path`. The file is
    /// loaded lazily on first access.
    pub fn new(path: PathBuf) -> Self {
        Self {
            inner: Arc::new(RwLock::new(Inner {
                path,
                last_modified: None,
                cached: None,
            })),
        }
    }

    /// Returns the billing view for `tenant_id`. Tenants without a record get
    /// `configured: false` and the minimal (inactive) entitlements.
    pub fn get_for_tenant(&self, tenant_id: Uuid) -> BillingResponse {
        let mut inner = self.inner.write().expect("billing lock poisoned");
        inner.reload_if_changed();
        if let Some(state) = inner
            .cached
            .as_ref()
            .and_then(|file| file.tenants.get(&tenant_id))
        {
            return BillingResponse {
                configured: true,
                provider: Some(state.provider.clone()),
                plan: state.plan.clone(),
                status: state.status.clone(),
                current_period_end: state.current_period_end.clone(),
                cancel_at_period_end: state.cancel_at_period_end,
                entitlements: Entitlements::derive(state.plan.as_ref(), state.status.as_ref()),
            };
        }
        BillingResponse {
            configured: false,
            provider: None,
            plan: None,
            status: None,
            current_period_end: None,
            cancel_at_period_end: None,
            entitlements: Entitlements::derive(None, None),
        }
    }

    /// All tenant ids present in the state file.
    pub fn get_all_tenant_ids(&self) -> Vec<Uuid> {
        let mut inner = self.inner.write().expect("billing lock poisoned");
        inner.reload_if_changed();
        inner
            .cached
            .as_ref()
            .map(|f| f.tenants.keys().cloned().collect())
            .unwrap_or_default()
    }

    /// Provider subscription id recorded for `tenant_id`, if any.
    pub fn get_subscription_id(&self, tenant_id: Uuid) -> Option<String> {
        let mut inner = self.inner.write().expect("billing lock poisoned");
        inner.reload_if_changed();
        inner
            .cached
            .as_ref()
            .and_then(|f| f.tenants.get(&tenant_id))
            .and_then(|s| s.provider_subscription_id.clone())
    }

    /// Applies a billing event to the per-tenant state and persists the file.
    ///
    /// Events are deduplicated by provider event id; an event older than the
    /// tenant's last applied update (by `ts_ms`) is remembered for dedup but
    /// does not overwrite newer state.
    pub fn apply_event(&self, event: BillingEvent) -> Result<(), String> {
        let mut inner = self.inner.write().expect("billing lock poisoned");
        inner.reload_if_changed();
        let mut file = inner.cached.clone().unwrap_or(BillingStateFile {
            revision: Some("dev".to_string()),
            tenants: BTreeMap::new(),
        });
        let tenant_id = event.tenant_id();
        let event_id = event.event_id().to_string();
        let ts_ms = event.ts_ms();
        let state = file.tenants.entry(tenant_id).or_insert(TenantBillingState {
            provider: "unknown".to_string(), // Will be updated by Created event
            provider_customer_id: None,
            provider_subscription_id: None,
            provider_checkout_session_id: None,
            status: None,
            plan: None,
            current_period_end: None,
            cancel_at_period_end: None,
            processed_webhook_event_ids: vec![],
            updated_at: 0,
        });
        // Deduplication: an already-processed event is a no-op.
        if state.processed_webhook_event_ids.contains(&event_id) {
            return Ok(());
        }
        // Monotonicity: a stale event is remembered (so a replay still
        // deduplicates) but its payload is not applied over newer state.
        if state.updated_at > ts_ms {
            Self::remember_event(&mut state.processed_webhook_event_ids, event_id);
            inner.save(file)?;
            return Ok(());
        }
        match event {
            BillingEvent::SubscriptionCreated {
                provider_customer_id,
                provider_subscription_id,
                status,
                plan,
                current_period_end,
                ..
            } => {
                state.provider_customer_id = Some(provider_customer_id);
                state.provider_subscription_id = Some(provider_subscription_id);
                state.status = Some(status);
                state.plan = Some(plan);
                state.current_period_end = Some(current_period_end);
            }
            BillingEvent::SubscriptionUpdated {
                status,
                plan,
                current_period_end,
                cancel_at_period_end,
                ..
            } => {
                state.status = Some(status);
                state.plan = Some(plan);
                state.current_period_end = Some(current_period_end);
                state.cancel_at_period_end = Some(cancel_at_period_end);
            }
            BillingEvent::SubscriptionDeleted { .. } => {
                state.status = Some(SubscriptionStatus::Canceled);
            }
        }
        state.updated_at = ts_ms;
        Self::remember_event(&mut state.processed_webhook_event_ids, event_id);
        inner.save(file)?;
        Ok(())
    }

    /// Records `event_id` in the bounded dedup list, evicting the OLDEST
    /// entries once the cap is exceeded.
    ///
    /// Bug fix: the previous `push` + `truncate(50)` kept the oldest 50 ids
    /// and immediately discarded the id that was just pushed once the list was
    /// full, so recent events were no longer deduplicated.
    fn remember_event(ids: &mut Vec<String>, event_id: String) {
        ids.push(event_id);
        if ids.len() > Self::MAX_PROCESSED_EVENT_IDS {
            let excess = ids.len() - Self::MAX_PROCESSED_EVENT_IDS;
            ids.drain(..excess);
        }
    }

    /// Test helper: overwrite a tenant's entire billing state.
    #[cfg(test)]
    pub fn update_tenant_state(
        &self,
        tenant_id: Uuid,
        state: TenantBillingState,
    ) -> Result<String, String> {
        let mut inner = self.inner.write().expect("billing lock poisoned");
        inner.reload_if_changed();
        let mut file = inner.cached.clone().unwrap_or(BillingStateFile {
            revision: Some("dev".to_string()),
            tenants: BTreeMap::new(),
        });
        file.tenants.insert(tenant_id, state);
        inner.save(file)
    }
}
impl Inner {
    /// Persists `file` under a fresh revision id using write-to-temp + rename,
    /// so readers never observe a partially written file. Returns the new
    /// revision.
    fn save(&mut self, mut file: BillingStateFile) -> Result<String, String> {
        let revision = format!("rev-{}", Uuid::new_v4());
        file.revision = Some(revision.clone());
        let raw = serde_json::to_string_pretty(&file).map_err(|e| e.to_string())?;
        let tmp = self.path.with_extension("json.tmp");
        fs::write(&tmp, raw).map_err(|e| e.to_string())?;
        fs::rename(&tmp, &self.path).map_err(|e| e.to_string())?;
        // Clearing last_modified forces the next accessor to re-stat the file.
        self.last_modified = None;
        self.cached = Some(file);
        Ok(revision)
    }
    /// Reloads the cache when the file's mtime differs from the last load.
    ///
    /// NOTE(review): a read or parse failure silently leaves `cached` as
    /// `None`, which downstream treats the same as "no billing state" —
    /// confirm this is the intended behavior for a corrupted state file.
    fn reload_if_changed(&mut self) {
        let meta = fs::metadata(&self.path).ok();
        let modified = meta.and_then(|m| m.modified().ok());
        // Skip the reload only when we have a cache AND an observable mtime
        // that has not moved; a missing file is re-checked every call.
        if self.cached.is_some() && modified.is_some() && modified == self.last_modified {
            return;
        }
        self.last_modified = modified;
        let p = &self.path;
        self.cached = fs::read_to_string(p)
            .ok()
            .and_then(|raw| serde_json::from_str(&raw).ok());
    }
}
/// `GET /tenants/{tenant_id}/billing` — tenant's billing state + entitlements.
pub async fn get_billing(
    State(state): State<AppState>,
    Path(tenant_id): Path<Uuid>,
    headers: HeaderMap,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    // Permission gate first, then the tenant-isolation gate.
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    match verify_tenant_isolation(&headers, tenant_id) {
        Ok(()) => {
            let resp = state.billing.get_for_tenant(tenant_id);
            (StatusCode::OK, Json(resp)).into_response()
        }
        Err(status) => status.into_response(),
    }
}
/// Request body for `POST /tenants/{tenant_id}/billing/checkout`.
#[derive(Debug, Deserialize)]
pub struct CheckoutRequest {
    pub plan: Plan,
    // Optional path to return to after checkout; defaults to "/billing".
    pub return_path: Option<String>,
}
pub async fn checkout(
State(state): State<AppState>,
Path(tenant_id): Path<Uuid>,
headers: HeaderMap,
Extension(principal): Extension<Principal>,
Json(body): Json<CheckoutRequest>,
) -> impl IntoResponse {
if !has_permission(&principal, "control:write") {
return StatusCode::FORBIDDEN.into_response();
}
if let Err(status) = verify_tenant_isolation(&headers, tenant_id) {
return status.into_response();
}
// Check if subscription already exists and is active/trialing
let current = state.billing.get_for_tenant(tenant_id);
if current.configured
&& matches!(
current.status,
Some(SubscriptionStatus::Active | SubscriptionStatus::Trialing)
)
{
return (
StatusCode::CONFLICT,
Json(serde_json::json!({ "error": "tenant already has an active subscription" })),
)
.into_response();
}
// Construct full return URL
// TODO: Validate return_path against ALLOWED_RETURN_ORIGINS if provided
let return_url = body.return_path.unwrap_or_else(|| "/billing".to_string());
match state
.billing_provider
.create_checkout_session(tenant_id, body.plan, return_url)
.await
{
Ok(url) => (StatusCode::OK, Json(serde_json::json!({ "url": url }))).into_response(),
Err(e) => {
let err_msg = e.to_string();
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": err_msg })),
)
.into_response()
}
}
}
pub async fn portal(
State(state): State<AppState>,
Path(tenant_id): Path<Uuid>,
headers: HeaderMap,
Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
if !has_permission(&principal, "control:write") {
return StatusCode::FORBIDDEN.into_response();
}
if let Err(status) = verify_tenant_isolation(&headers, tenant_id) {
return status.into_response();
}
let return_url = "/billing".to_string();
match state
.billing_provider
.create_portal_session(tenant_id, return_url)
.await
{
Ok(url) => (StatusCode::OK, Json(serde_json::json!({ "url": url }))).into_response(),
Err(e) => {
let err_msg = e.to_string();
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": err_msg })),
)
.into_response()
}
}
}
/// Public webhook endpoint for billing providers.
///
/// No principal auth here: authentication is the provider signature check
/// performed inside `verify_webhook`. `400` on a failed verification, `500`
/// when a verified event cannot be applied.
pub async fn webhook(
    State(state): State<AppState>,
    Path(_provider): Path<String>,
    headers: HeaderMap,
    body: axum::body::Bytes,
) -> impl IntoResponse {
    let event = match state.billing_provider.verify_webhook(&body, &headers).await {
        Ok(event) => event,
        Err(e) => {
            metrics::counter!("billing_webhook_requests_total", "status" => "error").increment(1);
            return (
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({ "error": e.to_string() })),
            )
                .into_response();
        }
    };
    if let Err(e) = state.billing.apply_event(event) {
        tracing::error!(error = %e, "failed to apply billing event from webhook");
        // Fix: "status=success" used to be incremented before apply_event, so
        // failed applications were still counted as successes.
        metrics::counter!("billing_webhook_requests_total", "status" => "error").increment(1);
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({ "error": e })),
        )
            .into_response();
    }
    metrics::counter!("billing_webhook_requests_total", "status" => "success").increment(1);
    StatusCode::OK.into_response()
}
/// Periodic background task that reconciles billing state with the provider.
///
/// The interval comes from `CONTROL_BILLING_RECONCILE_INTERVAL_SECS` and
/// defaults to one hour; the first run happens after one full interval.
pub async fn run_reconciliation_loop(state: AppState) {
    let interval_secs: u64 = std::env::var("CONTROL_BILLING_RECONCILE_INTERVAL_SECS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(3600);
    tracing::info!(interval_secs, "starting billing reconciliation loop");
    loop {
        tokio::time::sleep(Duration::from_secs(interval_secs)).await;
        tracing::info!("starting billing reconciliation run");
        reconcile_once(&state).await;
        // Refresh the per-plan/status gauges after each run; the full-scan
        // cost is acceptable at this cadence.
        update_billing_gauges(&state);
    }
}
/// One reconciliation pass: re-fetch every known tenant's subscription from
/// the provider and fold the result back into the store.
///
/// Tenants without a recorded subscription id are skipped; per-run counters
/// and a duration histogram are emitted at the end.
pub async fn reconcile_once(state: &AppState) {
    let start = std::time::Instant::now();
    let mut success = 0;
    let mut error = 0;
    let mut skipped = 0;
    for tenant_id in state.billing.get_all_tenant_ids() {
        let Some(subscription_id) = state.billing.get_subscription_id(tenant_id) else {
            skipped += 1;
            continue;
        };
        match state
            .billing_provider
            .fetch_subscription(tenant_id, &subscription_id)
            .await
        {
            Ok(event) => match state.billing.apply_event(event) {
                Ok(()) => success += 1,
                Err(e) => {
                    tracing::error!(?tenant_id, error = %e, "failed to apply reconciled billing event");
                    error += 1;
                }
            },
            Err(e) => {
                tracing::error!(?tenant_id, error = %e, "failed to fetch subscription for reconciliation");
                error += 1;
            }
        }
    }
    let elapsed = start.elapsed();
    metrics::counter!("billing_reconciliation_runs_total", "result" => "done").increment(1);
    metrics::histogram!("billing_reconciliation_duration_ms").record(elapsed.as_millis() as f64);
    tracing::info!(
        success,
        error,
        skipped,
        duration_ms = elapsed.as_millis(),
        "billing reconciliation run complete"
    );
}
/// Publishes per-(plan, status) tenant-count gauges from current store state.
///
/// NOTE(review): (plan, status) combinations that had tenants on a previous
/// run but have none now are not reset to 0, so those gauges keep their last
/// stale value — confirm whether an explicit zeroing pass is needed.
fn update_billing_gauges(state: &AppState) {
    let tenant_ids = state.billing.get_all_tenant_ids();
    let mut counts: BTreeMap<(String, String), u64> = BTreeMap::new();
    for tenant_id in tenant_ids {
        let resp = state.billing.get_for_tenant(tenant_id);
        // Map enum values to the stable label strings used by dashboards.
        let plan = match resp.plan {
            Some(Plan::Free) => "free",
            Some(Plan::Pro) => "pro",
            Some(Plan::Enterprise) => "enterprise",
            None => "none",
        }
        .to_string();
        let status = match resp.status {
            Some(SubscriptionStatus::Active) => "active",
            Some(SubscriptionStatus::Trialing) => "trialing",
            Some(SubscriptionStatus::PastDue) => "past_due",
            Some(SubscriptionStatus::Paused) => "paused",
            Some(SubscriptionStatus::Canceled) => "canceled",
            Some(SubscriptionStatus::Incomplete) => "incomplete",
            None => "none",
        }
        .to_string();
        *counts.entry((plan, status)).or_insert(0) += 1;
    }
    for ((plan, status), count) in counts {
        metrics::gauge!("billing_tenant_status_count", "plan" => plan, "status" => status)
            .set(count as f64);
    }
}
/// Errors surfaced by `BillingProvider` implementations.
#[derive(Debug, Error)]
pub enum BillingError {
    // Failure reported by (or while talking to) the payment provider.
    #[error("provider error: {0}")]
    Provider(String),
    // Local misconfiguration, e.g. a plan with no checkout price.
    #[error("invalid configuration: {0}")]
    Config(String),
}
/// Abstraction over a payment provider (`StripeProvider` or `MockProvider`).
#[async_trait]
pub trait BillingProvider: Send + Sync {
    /// Creates a hosted checkout session for `tenant_id` on `plan`; returns
    /// the URL the user should be redirected to.
    async fn create_checkout_session(
        &self,
        tenant_id: Uuid,
        plan: Plan,
        return_url: String,
    ) -> Result<String, BillingError>;
    /// Creates a self-service billing portal session; returns its URL.
    async fn create_portal_session(
        &self,
        tenant_id: Uuid,
        return_url: String,
    ) -> Result<String, BillingError>;
    /// Verifies a webhook payload's authenticity and decodes it into a
    /// normalized `BillingEvent`.
    async fn verify_webhook(
        &self,
        payload: &[u8],
        headers: &HeaderMap,
    ) -> Result<BillingEvent, BillingError>;
    /// Fetches the provider's current view of a subscription, used by the
    /// reconciliation loop.
    async fn fetch_subscription(
        &self,
        tenant_id: Uuid,
        subscription_id: &str,
    ) -> Result<BillingEvent, BillingError>;
}
/// Stripe-backed provider. Currently a partial stub: session creation returns
/// simulated URLs and webhook/fetch are unimplemented (see impl below).
pub struct StripeProvider {
    pub secret_key: String,
    // Stripe price ids for the paid plans.
    pub price_pro: String,
    pub price_enterprise: String,
}
// STUB: no real Stripe API calls are made yet; checkout/portal return
// simulated URLs and webhook/fetch return "Not implemented" errors.
#[async_trait]
impl BillingProvider for StripeProvider {
    async fn create_checkout_session(
        &self,
        tenant_id: Uuid,
        plan: Plan,
        _return_url: String,
    ) -> Result<String, BillingError> {
        // The Free plan has no price and therefore no checkout flow.
        let _price = match plan {
            Plan::Pro => &self.price_pro,
            Plan::Enterprise => &self.price_enterprise,
            Plan::Free => {
                return Err(BillingError::Config(
                    "Free plan has no checkout".to_string(),
                ));
            }
        };
        // TODO: Actually call Stripe API
        // For now, returning a simulated Stripe checkout URL
        Ok(format!(
            "https://checkout.stripe.com/pay/cs_test_{}?tenant_id={}",
            Uuid::new_v4(),
            tenant_id
        ))
    }
    async fn create_portal_session(
        &self,
        tenant_id: Uuid,
        _return_url: String,
    ) -> Result<String, BillingError> {
        // TODO: Actually call Stripe API
        Ok(format!(
            "https://billing.stripe.com/p/session/ps_test_{}?tenant_id={}",
            Uuid::new_v4(),
            tenant_id
        ))
    }
    async fn verify_webhook(
        &self,
        _payload: &[u8],
        _headers: &HeaderMap,
    ) -> Result<BillingEvent, BillingError> {
        // TODO: Implement real Stripe signature verification
        // Failing closed here means no webhook is accepted until it lands.
        Err(BillingError::Provider("Not implemented".to_string()))
    }
    async fn fetch_subscription(
        &self,
        _tenant_id: Uuid,
        _subscription_id: &str,
    ) -> Result<BillingEvent, BillingError> {
        // TODO: Actually call Stripe API with timeout
        // let client = reqwest::Client::builder().timeout(Duration::from_secs(10)).build()...
        Err(BillingError::Provider("Not implemented".to_string()))
    }
}
/// In-memory stand-in for a real payment provider, used in dev and tests.
pub struct MockProvider;
#[async_trait]
impl BillingProvider for MockProvider {
    async fn create_checkout_session(
        &self,
        tenant_id: Uuid,
        _plan: Plan,
        _return_url: String,
    ) -> Result<String, BillingError> {
        // Deterministic URL so tests can predict the redirect target.
        Ok(MockProvider::get_checkout_url(tenant_id))
    }
    async fn create_portal_session(
        &self,
        tenant_id: Uuid,
        _return_url: String,
    ) -> Result<String, BillingError> {
        Ok(format!("https://mock.stripe.com/portal/{tenant_id}"))
    }
    async fn verify_webhook(
        &self,
        payload: &[u8],
        _headers: &HeaderMap,
    ) -> Result<BillingEvent, BillingError> {
        // No signature checking in the mock: the payload *is* the event.
        match serde_json::from_slice(payload) {
            Ok(event) => Ok(event),
            Err(e) => Err(BillingError::Provider(e.to_string())),
        }
    }
    async fn fetch_subscription(
        &self,
        tenant_id: Uuid,
        _subscription_id: &str,
    ) -> Result<BillingEvent, BillingError> {
        // The mock always reports a healthy, far-future Pro subscription.
        // A richer mock could store expectations; this is enough for now.
        let now_ms = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_millis() as u64;
        Ok(BillingEvent::SubscriptionUpdated {
            tenant_id,
            event_id: format!("reconcile-{}", Uuid::new_v4()),
            status: SubscriptionStatus::Active,
            plan: Plan::Pro,
            current_period_end: "2099-12-31T23:59:59Z".to_string(),
            cancel_at_period_end: false,
            ts_ms: now_ms,
        })
    }
}
impl MockProvider {
    /// Canonical mock checkout URL for `tenant` (also returned by the trait
    /// impl in this block).
    pub fn get_checkout_url(tenant: Uuid) -> String {
        format!("https://mock.stripe.com/checkout/{tenant}")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;
    // Entitlements follow plan + status: degraded states fall back to
    // free-tier limits, paid plans in good standing unlock their quotas.
    #[test]
    fn test_entitlement_derivation() {
        let e = Entitlements::derive(Some(&Plan::Free), Some(&SubscriptionStatus::PastDue));
        assert_eq!(e.max_deployments, 1);
        let e = Entitlements::derive(Some(&Plan::Pro), Some(&SubscriptionStatus::Active));
        assert_eq!(e.max_deployments, 10);
        assert!(e.s3_docs_enabled);
        let e = Entitlements::derive(Some(&Plan::Enterprise), Some(&SubscriptionStatus::Trialing));
        assert_eq!(e.max_deployments, 1000);
    }
    // A fresh store reports "not configured" defaults; after an update the
    // stored plan/status round-trip and entitlements reflect the plan.
    #[test]
    fn test_billing_state_roundtrip() {
        // Unique temp file per run so parallel tests don't collide.
        let mut path = temp_dir();
        path.push(format!("billing-{}.json", Uuid::new_v4()));
        let store = BillingStore::new(path.clone());
        let tenant_id = Uuid::new_v4();
        // Unknown tenant: unconfigured, free-tier entitlements.
        let resp = store.get_for_tenant(tenant_id);
        assert!(!resp.configured);
        assert_eq!(resp.entitlements.max_deployments, 1);
        let state = TenantBillingState {
            provider: "mock".to_string(),
            provider_customer_id: None,
            provider_subscription_id: None,
            provider_checkout_session_id: None,
            status: Some(SubscriptionStatus::Active),
            plan: Some(Plan::Pro),
            current_period_end: None,
            cancel_at_period_end: Some(false),
            processed_webhook_event_ids: vec![],
            updated_at: 0,
        };
        store.update_tenant_state(tenant_id, state).unwrap();
        // Read back: configured, provider preserved, Pro entitlements applied.
        let resp2 = store.get_for_tenant(tenant_id);
        assert!(resp2.configured);
        assert_eq!(resp2.provider.as_deref(), Some("mock"));
        assert_eq!(resp2.plan, Some(Plan::Pro));
        assert_eq!(resp2.entitlements.max_deployments, 10);
        let _ = fs::remove_file(path);
    }
    // Reconciliation must overwrite stale local state with the provider's
    // view: a PastDue record becomes Active because MockProvider always
    // reports an Active subscription.
    #[tokio::test]
    async fn test_reconciliation_corrects_state() {
        let mut path = temp_dir();
        path.push(format!("billing-reconcile-{}.json", Uuid::new_v4()));
        let store = BillingStore::new(path.clone());
        let tenant_id = Uuid::new_v4();
        // 1. Initial state: PastDue
        store
            .update_tenant_state(
                tenant_id,
                TenantBillingState {
                    provider: "mock".to_string(),
                    provider_customer_id: Some("cus_1".to_string()),
                    provider_subscription_id: Some("sub_1".to_string()),
                    provider_checkout_session_id: None,
                    status: Some(SubscriptionStatus::PastDue),
                    plan: Some(Plan::Pro),
                    current_period_end: None,
                    cancel_at_period_end: Some(false),
                    processed_webhook_event_ids: vec![],
                    updated_at: 100,
                },
            )
            .unwrap();
        // Minimal AppState wired to the temp store + mock provider.
        let state = AppState {
            prometheus: crate::get_test_prometheus_handle(),
            auth: crate::AuthConfig { hs256_secret: None },
            jobs: crate::jobs::JobStore::default(),
            audit: crate::AuditStore::default(),
            tenant_locks: crate::job_engine::TenantLocks::default(),
            config_locks: crate::job_engine::ConfigLocks::default(),
            http: reqwest::Client::new(),
            placement: crate::placement::PlacementStore::new(temp_dir().join("placement.json")),
            billing: store.clone(),
            billing_provider: Arc::new(MockProvider),
            billing_enforcement_enabled: true,
            config: crate::config_registry::ConfigRegistry::new(None, None),
            fleet_services: vec![],
            swarm: crate::swarm::SwarmStore::new(temp_dir().join("swarm.json")),
            docs: None,
        };
        // 2. Run reconciliation. MockProvider returns Active status.
        reconcile_once(&state).await;
        // 3. Verify state is now Active
        let resp = store.get_for_tenant(tenant_id);
        assert_eq!(resp.status, Some(SubscriptionStatus::Active));
        let _ = fs::remove_file(path);
    }
}

View File

@@ -0,0 +1,323 @@
use async_trait::async_trait;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc, time::Duration};
use thiserror::Error;
/// Configuration namespaces managed by the control plane.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConfigDomain {
    /// Routing configuration (see `RoutingConfig`).
    Routing,
    /// Placement configuration.
    Placement,
}
impl ConfigDomain {
    /// Stable lowercase identifier, matching the serde wire form.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Routing => "routing",
            Self::Placement => "placement",
        }
    }
}
/// Errors produced by configuration sources and the registry.
#[derive(Debug, Error)]
pub enum ConfigRegistryError {
    /// Failure in the underlying backend (file I/O, NATS, ...).
    #[error("source error: {0}")]
    Source(String),
    /// Stored bytes could not be decoded into the expected schema.
    #[error("decode error: {0}")]
    Decode(String),
    /// No source was registered for the requested domain.
    #[error("domain not configured")]
    NotConfigured,
}
/// A decoded configuration value plus provenance metadata.
#[derive(Debug, Clone, Serialize)]
pub struct ConfigSnapshot<T> {
    /// Domain name (see [`ConfigDomain::as_str`]).
    pub domain: String,
    /// Backend revision the value was read at (0 when untracked).
    pub revision: u64,
    /// The decoded configuration payload.
    pub value: T,
    /// Where the value came from.
    pub source: ConfigSourceInfo,
}
/// Identifies the backend a configuration value was loaded from.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ConfigSourceInfo {
    /// A single file on local disk.
    File { path: String },
    /// A NATS JetStream key-value entry.
    NatsKv { bucket: String, key: String },
    /// A fixed in-memory payload.
    Fixed,
}
/// Pluggable storage backend for one configuration domain.
#[async_trait]
pub trait ConfigSource: Send + Sync {
    /// Returns `(bytes, revision)`; `None` bytes mean "no value stored yet".
    async fn load_bytes(&self) -> Result<(Option<Vec<u8>>, u64), ConfigRegistryError>;
    /// Stores `value`, optionally as a compare-and-swap against
    /// `expected_revision`; returns the new revision (0 when untracked).
    async fn put_bytes(
        &self,
        expected_revision: Option<u64>,
        value: Vec<u8>,
    ) -> Result<u64, ConfigRegistryError>;
    /// Returns up to `limit` historical `(revision, bytes)` entries.
    async fn history_bytes(&self, limit: usize)
    -> Result<Vec<(u64, Vec<u8>)>, ConfigRegistryError>;
    /// Streams a unit notification whenever the stored value changes.
    async fn watch(
        &self,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), ConfigRegistryError>> + Send>>,
        ConfigRegistryError,
    >;
    /// Static description of the backend for diagnostics/API responses.
    fn info(&self) -> ConfigSourceInfo;
}
/// A read-only [`ConfigSource`] that always serves one in-memory payload.
#[derive(Clone)]
pub struct FixedSource {
    // Shared so clones of the source are cheap.
    bytes: Arc<Vec<u8>>,
}
impl FixedSource {
    /// Wraps `bytes` as an immutable configuration payload.
    pub fn new(bytes: Vec<u8>) -> Self {
        Self { bytes: Arc::new(bytes) }
    }
}
#[async_trait]
impl ConfigSource for FixedSource {
    /// Always present, always revision 1.
    async fn load_bytes(&self) -> Result<(Option<Vec<u8>>, u64), ConfigRegistryError> {
        let payload = self.bytes.as_ref().clone();
        Ok((Some(payload), 1))
    }
    /// Writes are rejected: the payload is immutable.
    async fn put_bytes(
        &self,
        _expected_revision: Option<u64>,
        _value: Vec<u8>,
    ) -> Result<u64, ConfigRegistryError> {
        Err(ConfigRegistryError::Source(
            "fixed source is read-only".to_string(),
        ))
    }
    /// There is only ever one value, so history is unsupported.
    async fn history_bytes(
        &self,
        _limit: usize,
    ) -> Result<Vec<(u64, Vec<u8>)>, ConfigRegistryError> {
        Err(ConfigRegistryError::Source(
            "fixed source has no history".to_string(),
        ))
    }
    /// The payload never changes, so the watch stream ends immediately.
    async fn watch(
        &self,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), ConfigRegistryError>> + Send>>,
        ConfigRegistryError,
    > {
        Ok(Box::pin(futures::stream::empty()))
    }
    fn info(&self) -> ConfigSourceInfo {
        ConfigSourceInfo::Fixed
    }
}
/// A [`ConfigSource`] backed by a single file on local disk.
///
/// Revisions are not tracked (always reported as 0) and no history is kept.
#[derive(Clone)]
pub struct FileSource {
    path: PathBuf,
}
impl FileSource {
    /// Creates a source reading/writing `path`.
    pub fn new(path: PathBuf) -> Self {
        Self { path }
    }
}
#[async_trait]
impl ConfigSource for FileSource {
    /// Loads the file contents. A missing file is reported as "no value
    /// stored yet" (`(None, 0)`) — consistent with the NATS source — instead
    /// of surfacing the NotFound I/O error to the caller.
    async fn load_bytes(&self) -> Result<(Option<Vec<u8>>, u64), ConfigRegistryError> {
        match tokio::fs::read(&self.path).await {
            Ok(raw) => Ok((Some(raw), 0)),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok((None, 0)),
            Err(e) => Err(ConfigRegistryError::Source(e.to_string())),
        }
    }
    /// Writes atomically: write to a sibling `.tmp` file, then rename over
    /// the target so readers never observe a partial write.
    async fn put_bytes(
        &self,
        _expected_revision: Option<u64>,
        value: Vec<u8>,
    ) -> Result<u64, ConfigRegistryError> {
        let tmp = self.path.with_extension("tmp");
        tokio::fs::write(&tmp, &value)
            .await
            .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
        tokio::fs::rename(&tmp, &self.path)
            .await
            .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
        Ok(0)
    }
    /// Plain files keep no revision history.
    async fn history_bytes(
        &self,
        _limit: usize,
    ) -> Result<Vec<(u64, Vec<u8>)>, ConfigRegistryError> {
        Err(ConfigRegistryError::Source(
            "file source has no history".to_string(),
        ))
    }
    /// No file-watching support yet; callers get an empty stream.
    async fn watch(
        &self,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), ConfigRegistryError>> + Send>>,
        ConfigRegistryError,
    > {
        Ok(Box::pin(futures::stream::empty()))
    }
    fn info(&self) -> ConfigSourceInfo {
        ConfigSourceInfo::File {
            path: self.path.to_string_lossy().to_string(),
        }
    }
}
/// A [`ConfigSource`] stored in a NATS JetStream key-value bucket.
#[derive(Clone)]
pub struct NatsKvSource {
    kv: async_nats::jetstream::kv::Store,
    bucket: String,
    key: String,
}
impl NatsKvSource {
    /// Connects to NATS (2s timeout) and opens — creating with default
    /// settings when absent — the key-value bucket backing this source.
    pub async fn connect(
        nats_url: impl Into<String>,
        bucket: impl Into<String>,
        key: impl Into<String>,
    ) -> Result<Self, ConfigRegistryError> {
        let url: String = nats_url.into();
        let bucket: String = bucket.into();
        let key: String = key.into();
        // Bound the connection attempt so a dead NATS endpoint fails fast.
        let client = match tokio::time::timeout(Duration::from_secs(2), async_nats::connect(url))
            .await
        {
            Err(_) => return Err(ConfigRegistryError::Source("connect timeout".to_string())),
            Ok(Err(e)) => return Err(ConfigRegistryError::Source(e.to_string())),
            Ok(Ok(c)) => c,
        };
        let js = async_nats::jetstream::new(client);
        // Reuse the bucket when it already exists; otherwise create it.
        let kv = if let Ok(existing) = js.get_key_value(&bucket).await {
            existing
        } else {
            js.create_key_value(async_nats::jetstream::kv::Config {
                bucket: bucket.clone(),
                ..Default::default()
            })
            .await
            .map_err(|e| ConfigRegistryError::Source(e.to_string()))?
        };
        Ok(Self { kv, bucket, key })
    }
}
#[async_trait]
impl ConfigSource for NatsKvSource {
    /// Reads the current entry; a missing key maps to `(None, 0)`.
    async fn load_bytes(&self) -> Result<(Option<Vec<u8>>, u64), ConfigRegistryError> {
        let entry = self
            .kv
            .entry(&self.key)
            .await
            .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
        Ok(match entry {
            Some(e) => (Some(e.value.to_vec()), e.revision),
            None => (None, 0),
        })
    }
    /// Writes the value. With `expected_revision > 0` this is a
    /// compare-and-swap (`kv.update`); `None` or revision 0 does an
    /// unconditional `kv.put` (also the path for first-time creation).
    async fn put_bytes(
        &self,
        expected_revision: Option<u64>,
        value: Vec<u8>,
    ) -> Result<u64, ConfigRegistryError> {
        let rev = match expected_revision {
            // CAS path: upstream rejects the write if someone else wrote
            // in between.
            Some(expected) if expected > 0 => self
                .kv
                .update(&self.key, value.into(), expected)
                .await
                .map_err(|e| ConfigRegistryError::Source(e.to_string()))?,
            // Blind write.
            _ => self
                .kv
                .put(&self.key, value.into())
                .await
                .map_err(|e| ConfigRegistryError::Source(e.to_string()))?,
        };
        Ok(rev)
    }
    /// Collects up to `limit` `(revision, bytes)` pairs from the KV history
    /// stream, in whatever order the stream yields them.
    async fn history_bytes(
        &self,
        limit: usize,
    ) -> Result<Vec<(u64, Vec<u8>)>, ConfigRegistryError> {
        let mut stream = self
            .kv
            .history(&self.key)
            .await
            .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
        let mut out = Vec::new();
        while let Some(item) = stream.next().await {
            let entry = item.map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
            out.push((entry.revision, entry.value.to_vec()));
            // Stop early once we have enough entries.
            if out.len() >= limit {
                break;
            }
        }
        Ok(out)
    }
    /// Emits `()` for every Put on the key. Delete/Purge operations are
    /// filtered out, so consumers are not re-triggered on removal; watch
    /// errors are forwarded to the consumer.
    async fn watch(
        &self,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), ConfigRegistryError>> + Send>>,
        ConfigRegistryError,
    > {
        let key = self.key.clone();
        let watch = self
            .kv
            .watch(&key)
            .await
            .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
        Ok(Box::pin(watch.filter_map(|entry| async move {
            match entry {
                Ok(entry) => match entry.operation {
                    async_nats::jetstream::kv::Operation::Put => Some(Ok(())),
                    async_nats::jetstream::kv::Operation::Delete
                    | async_nats::jetstream::kv::Operation::Purge => None,
                },
                Err(e) => Some(Err(ConfigRegistryError::Source(e.to_string()))),
            }
        })))
    }
    fn info(&self) -> ConfigSourceInfo {
        ConfigSourceInfo::NatsKv {
            bucket: self.bucket.clone(),
            key: self.key.clone(),
        }
    }
}
/// Holds the (optional) configuration source for each [`ConfigDomain`].
#[derive(Clone)]
pub struct ConfigRegistry {
    routing: Option<Arc<dyn ConfigSource>>,
    placement: Option<Arc<dyn ConfigSource>>,
}
impl ConfigRegistry {
    /// Builds a registry; a `None` entry means that domain is not configured.
    pub fn new(
        routing: Option<Arc<dyn ConfigSource>>,
        placement: Option<Arc<dyn ConfigSource>>,
    ) -> Self {
        Self { routing, placement }
    }
    /// Returns the source for `domain`, if one was configured.
    pub fn source(&self, domain: ConfigDomain) -> Option<Arc<dyn ConfigSource>> {
        let slot = match domain {
            ConfigDomain::Routing => &self.routing,
            ConfigDomain::Placement => &self.placement,
        };
        slot.clone()
    }
}

View File

@@ -0,0 +1,15 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RoutingConfig {
pub revision: u64,
pub aggregate_placement: HashMap<String, String>,
pub projection_placement: HashMap<String, String>,
pub runner_placement: HashMap<String, String>,
pub aggregate_shards: HashMap<String, Vec<String>>,
pub projection_shards: HashMap<String, Vec<String>>,
pub runner_shards: HashMap<String, Vec<String>>,
}

View File

@@ -0,0 +1,353 @@
use crate::auth::{Principal, has_permission};
use crate::{AppState, RequestIds};
use axum::{
Router,
body::Bytes,
extract::{Extension, Path, Query, State},
http::{HeaderMap, StatusCode, header},
response::IntoResponse,
routing::{get, post, put},
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Tenant id header every docs request must carry (cross-checked against the
/// path tenant id as defense-in-depth).
const HEADER_TENANT_ID: &str = shared::HEADER_X_TENANT_ID;
/// Routes for tenant document storage: list, upload, fetch, delete, and
/// presigned upload/download URLs.
pub fn router() -> Router<AppState> {
    Router::new()
        .route("/tenants/{tenant_id}/docs", get(list_docs))
        // upload_doc extracts (tenant_id, doc_type, doc_id, filename); the
        // final path segment is the client-supplied file name.
        .route(
            "/tenants/{tenant_id}/docs/{doc_type}/{doc_id}/{filename}",
            put(upload_doc),
        )
        .route(
            "/tenants/{tenant_id}/docs/object/{*key}",
            get(get_doc).delete(delete_doc),
        )
        .route(
            "/tenants/{tenant_id}/docs/presign/upload",
            post(presign_upload),
        )
        .route(
            "/tenants/{tenant_id}/docs/presign/download",
            post(presign_download),
        )
}
/// Cross-checks the tenant-id header against the path tenant id.
///
/// Missing or unparseable header -> 400; mismatch -> 403.
fn ensure_tenant_header(headers: &HeaderMap, tenant_id: Uuid) -> Result<(), StatusCode> {
    let raw = headers
        .get(HEADER_TENANT_ID)
        .and_then(|v| v.to_str().ok())
        .ok_or(StatusCode::BAD_REQUEST)?;
    let claimed = Uuid::parse_str(raw).map_err(|_| StatusCode::BAD_REQUEST)?;
    if claimed == tenant_id {
        Ok(())
    } else {
        Err(StatusCode::FORBIDDEN)
    }
}
/// Gates docs endpoints on the tenant's S3-docs entitlement.
///
/// Returns 402 when enforcement is on and the plan lacks the feature; a
/// globally disabled enforcement flag bypasses the check entirely.
fn ensure_docs_enabled(state: &AppState, tenant_id: Uuid) -> Result<(), StatusCode> {
    if !state.billing_enforcement_enabled {
        return Ok(());
    }
    let entitlements = state.billing.get_for_tenant(tenant_id).entitlements;
    if entitlements.s3_docs_enabled {
        Ok(())
    } else {
        Err(StatusCode::PAYMENT_REQUIRED)
    }
}
#[derive(Debug, Deserialize)]
struct ListQuery {
    /// Optional key prefix, relative to the tenant's base prefix.
    prefix: Option<String>,
}
#[derive(Debug, Serialize)]
struct ListResponse {
    objects: Vec<crate::s3_docs::DocObject>,
}
/// GET /tenants/{tenant_id}/docs — lists the tenant's stored objects.
async fn list_docs(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path(tenant_id): Path<Uuid>,
    Query(q): Query<ListQuery>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    // Authorization, tenant-header, and entitlement guards, in that order.
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    if let Err(code) = ensure_tenant_header(&headers, tenant_id) {
        return code.into_response();
    }
    if let Err(code) = ensure_docs_enabled(&state, tenant_id) {
        return code.into_response();
    }
    let Some(store) = state.docs.as_ref() else {
        return StatusCode::SERVICE_UNAVAILABLE.into_response();
    };
    // Reject traversal-looking prefixes outright.
    let requested = q.prefix.unwrap_or_default();
    let requested = requested.trim();
    if requested.contains("..") {
        return StatusCode::BAD_REQUEST.into_response();
    }
    // Scope every listing under "<store prefix><tenant_id>/".
    let base = format!("{}{}", store_prefix(store), tenant_id);
    let full_prefix = if requested.is_empty() {
        format!("{base}/")
    } else {
        format!("{base}/{requested}")
    };
    match store
        .list_for_tenant(&tenant_id.to_string(), &full_prefix)
        .await
    {
        Ok(objects) => (StatusCode::OK, axum::Json(ListResponse { objects })).into_response(),
        Err(_) => StatusCode::BAD_GATEWAY.into_response(),
    }
}
/// PUT /tenants/{tenant_id}/docs/{doc_type}/{doc_id}/{filename} — stores the
/// raw request body and returns the object key plus its SHA-256 digest.
async fn upload_doc(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path((tenant_id, doc_type, doc_id, filename)): Path<(Uuid, String, String, String)>,
    Extension(principal): Extension<Principal>,
    Extension(request_ids): Extension<RequestIds>,
    body: Bytes,
) -> impl IntoResponse {
    // Authorization, tenant-header, and entitlement guards, in that order.
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    if let Err(code) = ensure_tenant_header(&headers, tenant_id) {
        return code.into_response();
    }
    if let Err(code) = ensure_docs_enabled(&state, tenant_id) {
        return code.into_response();
    }
    let Some(store) = state.docs.as_ref() else {
        return StatusCode::SERVICE_UNAVAILABLE.into_response();
    };
    // Forward the client's content type (if any) to the object store.
    let content_type = headers
        .get(header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_string());
    let key = match store.key_for(&tenant_id.to_string(), &doc_type, &doc_id, &filename) {
        Ok(k) => k,
        Err(_) => return StatusCode::BAD_REQUEST.into_response(),
    };
    let payload = body.to_vec();
    // Hash before the upload so the digest covers exactly what was sent.
    let digest = crate::s3_docs::DocsStore::content_hash_sha256_hex(&payload);
    if let Err(e) = store
        .put_for_tenant(&tenant_id.to_string(), &key, payload, content_type)
        .await
    {
        tracing::warn!(
            request_id = %request_ids.request_id,
            correlation_id = ?request_ids.correlation_id,
            error = %e,
            "docs upload failed"
        );
        return StatusCode::BAD_GATEWAY.into_response();
    }
    (
        StatusCode::OK,
        axum::Json(serde_json::json!({
            "key": key,
            "sha256": digest,
        })),
    )
        .into_response()
}
/// GET /tenants/{tenant_id}/docs/object/{*key} — returns a stored object's
/// bytes, propagating its stored content type when available.
async fn get_doc(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path((tenant_id, key)): Path<(Uuid, String)>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    if let Err(code) = ensure_tenant_header(&headers, tenant_id) {
        return code.into_response();
    }
    if let Err(code) = ensure_docs_enabled(&state, tenant_id) {
        return code.into_response();
    }
    let Some(store) = state.docs.as_ref() else {
        return StatusCode::SERVICE_UNAVAILABLE.into_response();
    };
    // The key must live under this tenant's namespace; anything else looks
    // like a cross-tenant access attempt.
    let base = format!("{}{}", store_prefix(store), tenant_id);
    if !key.starts_with(&base) {
        return StatusCode::FORBIDDEN.into_response();
    }
    let fetched = store
        .get_bytes_for_tenant(&tenant_id.to_string(), &key)
        .await;
    match fetched {
        Ok((bytes, content_type)) => {
            let mut res = axum::response::Response::new(axum::body::Body::from(bytes));
            *res.status_mut() = StatusCode::OK;
            // Attach the stored content type only when it is a valid header value.
            if let Some(ct) = content_type
                && let Ok(v) = axum::http::HeaderValue::from_str(&ct)
            {
                res.headers_mut().insert(header::CONTENT_TYPE, v);
            }
            res
        }
        Err(_) => StatusCode::NOT_FOUND.into_response(),
    }
}
/// DELETE /tenants/{tenant_id}/docs/object/{*key} — removes a stored object.
async fn delete_doc(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path((tenant_id, key)): Path<(Uuid, String)>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    if let Err(code) = ensure_tenant_header(&headers, tenant_id) {
        return code.into_response();
    }
    if let Err(code) = ensure_docs_enabled(&state, tenant_id) {
        return code.into_response();
    }
    let Some(store) = state.docs.as_ref() else {
        return StatusCode::SERVICE_UNAVAILABLE.into_response();
    };
    // Refuse keys outside this tenant's namespace.
    let base = format!("{}{}", store_prefix(store), tenant_id);
    if !key.starts_with(&base) {
        return StatusCode::FORBIDDEN.into_response();
    }
    match store.delete_for_tenant(&tenant_id.to_string(), &key).await {
        Ok(_) => StatusCode::NO_CONTENT.into_response(),
        Err(_) => StatusCode::BAD_GATEWAY.into_response(),
    }
}
#[derive(Debug, Deserialize)]
struct PresignUploadRequest {
    doc_type: String,
    /// Generated server-side when omitted.
    doc_id: Option<String>,
    filename: String,
    content_type: Option<String>,
}
/// POST .../docs/presign/upload — returns a short-lived (300s) presigned PUT
/// URL plus the object key it will write.
async fn presign_upload(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path(tenant_id): Path<Uuid>,
    Extension(principal): Extension<Principal>,
    axum::Json(body): axum::Json<PresignUploadRequest>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    if let Err(code) = ensure_tenant_header(&headers, tenant_id) {
        return code.into_response();
    }
    if let Err(code) = ensure_docs_enabled(&state, tenant_id) {
        return code.into_response();
    }
    let Some(store) = state.docs.as_ref() else {
        return StatusCode::SERVICE_UNAVAILABLE.into_response();
    };
    // Mint a fresh doc id unless the caller provided one.
    let doc_id = body.doc_id.unwrap_or_else(|| Uuid::new_v4().to_string());
    let key = match store.key_for(
        &tenant_id.to_string(),
        &body.doc_type,
        &doc_id,
        &body.filename,
    ) {
        Ok(k) => k,
        Err(_) => return StatusCode::BAD_REQUEST.into_response(),
    };
    let presigned = store
        .presign_put_for_tenant(
            &tenant_id.to_string(),
            &key,
            body.content_type,
            std::time::Duration::from_secs(300),
        )
        .await;
    match presigned {
        Ok(url) => (
            StatusCode::OK,
            axum::Json(serde_json::json!({
                "method": "PUT",
                "url": url,
                "key": key,
            })),
        )
            .into_response(),
        Err(_) => StatusCode::BAD_GATEWAY.into_response(),
    }
}
#[derive(Debug, Deserialize)]
struct PresignDownloadRequest {
    /// Full object key; must lie under the tenant's namespace.
    key: String,
}
/// POST .../docs/presign/download — returns a short-lived (300s) presigned
/// GET URL for an existing object key.
async fn presign_download(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path(tenant_id): Path<Uuid>,
    Extension(principal): Extension<Principal>,
    axum::Json(body): axum::Json<PresignDownloadRequest>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    if let Err(code) = ensure_tenant_header(&headers, tenant_id) {
        return code.into_response();
    }
    if let Err(code) = ensure_docs_enabled(&state, tenant_id) {
        return code.into_response();
    }
    let Some(store) = state.docs.as_ref() else {
        return StatusCode::SERVICE_UNAVAILABLE.into_response();
    };
    // Refuse keys outside this tenant's namespace.
    let base = format!("{}{}", store_prefix(store), tenant_id);
    if !body.key.starts_with(&base) {
        return StatusCode::FORBIDDEN.into_response();
    }
    let presigned = store
        .presign_get_for_tenant(
            &tenant_id.to_string(),
            &body.key,
            std::time::Duration::from_secs(300),
        )
        .await;
    match presigned {
        Ok(url) => (
            StatusCode::OK,
            axum::Json(serde_json::json!({
                "method": "GET",
                "url": url,
                "key": body.key,
            })),
        )
            .into_response(),
        Err(_) => StatusCode::BAD_GATEWAY.into_response(),
    }
}
/// Thin alias for the store's configured key prefix; kept so handlers build
/// tenant base paths (`"<prefix><tenant_id>"`) through one uniform call.
fn store_prefix(store: &crate::s3_docs::DocsStore) -> &str {
    store.prefix()
}

127
control/api/src/drift.rs Normal file
View File

@@ -0,0 +1,127 @@
use crate::{AppState, build_info::extract_build_info, fleet, swarm::SwarmService};
use serde::Serialize;
use std::collections::{BTreeMap, BTreeSet};
/// Category of divergence between desired and observed fleet state.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DriftKind {
    /// Desired service not observed in the swarm.
    Missing,
    /// Observed service that is not part of the desired set.
    Extra,
    /// Service present but failing health/readiness probes.
    Unhealthy,
    /// Multiple build versions observed for the same service.
    VersionMismatch,
}
/// One detected divergence, with kind-specific JSON context for operators.
#[derive(Debug, Clone, Serialize)]
pub struct DriftItem {
    pub kind: DriftKind,
    /// Service name the finding applies to.
    pub service: String,
    /// Kind-specific details (probe flags, URLs, seen versions, ...).
    pub details: serde_json::Value,
}
/// Full drift report: per-kind counts plus the sorted list of findings.
#[derive(Debug, Clone, Serialize)]
pub struct DriftResponse {
    /// kind name -> number of items of that kind.
    pub summary: BTreeMap<String, u64>,
    pub items: Vec<DriftItem>,
}
/// Computes drift between the desired fleet and what is actually observed,
/// returning per-kind counts plus a deterministically sorted item list.
pub async fn compute(state: &AppState) -> DriftResponse {
    let mut items: Vec<DriftItem> = Vec::new();
    // Desired service set: what the Control API was configured to observe.
    // (In production, this should evolve into "desired stacks + required services".)
    let desired: BTreeSet<String> = state
        .fleet_services
        .iter()
        .map(|s| s.name.clone())
        .collect();
    // Observed service set: what Swarm reports (dev: from file snapshot).
    let observed_services: Vec<SwarmService> = state.swarm.list_services();
    let observed: BTreeSet<String> = observed_services.iter().map(|s| s.name.clone()).collect();
    for missing in desired.difference(&observed) {
        items.push(DriftItem {
            kind: DriftKind::Missing,
            service: missing.clone(),
            details: serde_json::json!({ "expected": true }),
        });
    }
    for extra in observed.difference(&desired) {
        items.push(DriftItem {
            kind: DriftKind::Extra,
            service: extra.clone(),
            details: serde_json::json!({ "observed": true }),
        });
    }
    // Poll the fleet ONCE and reuse the snapshot for both the health and the
    // version checks (it was previously fetched twice, doubling probe traffic
    // and allowing the two checks to see inconsistent states).
    let snapshots = fleet::snapshot(&state.http, &state.fleet_services).await;
    // Health drift: based on the fleet snapshot.
    for s in &snapshots {
        if !s.health_ok || !s.ready_ok {
            items.push(DriftItem {
                kind: DriftKind::Unhealthy,
                service: s.name.clone(),
                details: serde_json::json!({
                    "health_ok": s.health_ok,
                    "ready_ok": s.ready_ok,
                    "metrics_ok": s.metrics_ok,
                    "base_url": s.base_url,
                }),
            });
        }
    }
    // Version drift: compare build_info between services when present.
    // Desired is not yet explicit; for now we flag when multiple versions
    // exist for the same service.
    let mut versions_by_service: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
    for s in snapshots {
        if let Ok(metrics) = state
            .http
            .get(format!("{}/metrics", s.base_url))
            .send()
            .await
            && let Ok(body) = metrics.text().await
        {
            for bi in extract_build_info(&body) {
                versions_by_service
                    .entry(bi.service.clone())
                    .or_default()
                    .insert(format!("{}@{}", bi.version, bi.git_sha));
            }
        }
    }
    for (svc, vs) in versions_by_service {
        if vs.len() > 1 {
            items.push(DriftItem {
                kind: DriftKind::VersionMismatch,
                service: svc,
                details: serde_json::json!({ "seen": vs.into_iter().collect::<Vec<_>>() }),
            });
        }
    }
    // Deterministic ordering: by severity class, then service name.
    fn ord(k: &DriftKind) -> u8 {
        match k {
            DriftKind::Missing => 0,
            DriftKind::Extra => 1,
            DriftKind::Unhealthy => 2,
            DriftKind::VersionMismatch => 3,
        }
    }
    items.sort_by(|a, b| (ord(&a.kind), &a.service).cmp(&(ord(&b.kind), &b.service)));
    let mut summary: BTreeMap<String, u64> = BTreeMap::new();
    for item in &items {
        let k = match item.kind {
            DriftKind::Missing => "missing",
            DriftKind::Extra => "extra",
            DriftKind::Unhealthy => "unhealthy",
            DriftKind::VersionMismatch => "version_mismatch",
        };
        *summary.entry(k.to_string()).or_insert(0) += 1;
    }
    DriftResponse { summary, items }
}

View File

@@ -1,14 +1,19 @@
use crate::{
AppState, Principal,
audit::{AuditEvent, AuditStore},
config_registry::{ConfigDomain, ConfigRegistryError},
config_schemas::RoutingConfig,
fleet,
jobs::{Job, JobStatus, JobStep, JobStore},
placement::PlacementFile,
};
use std::{
collections::HashMap,
path::PathBuf,
sync::{Arc, Mutex},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use url::Url;
use uuid::Uuid;
#[derive(Clone, Default)]
@@ -34,20 +39,52 @@ impl TenantLocks {
}
}
/// Per-domain mutual exclusion for config jobs: at most one
/// apply/validate/rollback job may hold a given [`ConfigDomain`] at a time.
#[derive(Clone, Default)]
pub struct ConfigLocks {
    // domain name -> job currently holding the lock
    inner: Arc<Mutex<HashMap<String, Uuid>>>,
}
impl ConfigLocks {
    /// Attempts to acquire `domain` for `job_id`; returns `false` when some
    /// job already holds it. Uses the entry API so the check-and-insert is a
    /// single map lookup.
    pub fn try_lock(&self, domain: ConfigDomain, job_id: Uuid) -> bool {
        let mut map = self.inner.lock().expect("config locks poisoned");
        match map.entry(domain.as_str().to_string()) {
            std::collections::hash_map::Entry::Occupied(_) => false,
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(job_id);
                true
            }
        }
    }
    /// Releases `domain`, but only if `job_id` is the current holder — a
    /// stale unlock from a superseded job must not drop another job's lock.
    /// Looks up by `&str` to avoid allocating a key `String`.
    pub fn unlock(&self, domain: ConfigDomain, job_id: Uuid) {
        let mut map = self.inner.lock().expect("config locks poisoned");
        if map.get(domain.as_str()).copied() == Some(job_id) {
            map.remove(domain.as_str());
        }
    }
}
#[derive(Clone)]
pub struct JobEngine {
pub jobs: JobStore,
pub audit: AuditStore,
pub tenant_locks: TenantLocks,
pub config_locks: ConfigLocks,
pub step_timeout: Duration,
}
impl JobEngine {
pub fn new(jobs: JobStore, audit: AuditStore, tenant_locks: TenantLocks) -> Self {
pub fn new(
jobs: JobStore,
audit: AuditStore,
tenant_locks: TenantLocks,
config_locks: ConfigLocks,
) -> Self {
Self {
jobs,
audit,
tenant_locks,
config_locks,
step_timeout: Duration::from_millis(500),
}
}
@@ -93,7 +130,7 @@ impl JobEngine {
let engine = self.clone();
tokio::spawn(async move {
engine
.run_job(state, inserted, Some(tenant_id), RunSpec::Drain)
.run_job(state, inserted, Some(tenant_id), None, RunSpec::Drain)
.await;
});
@@ -152,6 +189,7 @@ impl JobEngine {
state,
inserted,
Some(tenant_id),
None,
RunSpec::Migrate { runner_target },
)
.await;
@@ -160,7 +198,238 @@ impl JobEngine {
Ok(inserted)
}
async fn run_job(&self, state: AppState, job_id: Uuid, tenant_id: Option<Uuid>, spec: RunSpec) {
    /// Starts an asynchronous "config apply" job for `domain`.
    ///
    /// Steps: preflight -> validate -> backup -> apply -> reload -> verify.
    /// Requests are deduplicated by `idempotency_key`, and config jobs are
    /// serialized per domain via `config_locks` (busy domain => `TenantLocked`).
    #[allow(clippy::too_many_arguments)]
    pub fn start_config_apply(
        &self,
        state: AppState,
        principal: &Principal,
        domain: ConfigDomain,
        reason: String,
        expected_revision: Option<u64>,
        value: serde_json::Value,
        idempotency_key: &str,
    ) -> Result<Uuid, StartJobError> {
        // Idempotent replay: hand back the job created by the first request.
        if let Some(existing) = self.jobs.get_idempotent(idempotency_key) {
            return Ok(existing);
        }
        let job_id = Uuid::new_v4();
        // Only one config job per domain at a time.
        if !self.config_locks.try_lock(domain, job_id) {
            return Err(StartJobError::TenantLocked);
        }
        // NOTE(review): if a concurrent request with the same key wins the
        // insert below, `inserted` != `job_id` and the lock taken for `job_id`
        // is never released by that job's run — confirm `insert_idempotent`
        // semantics cover this race.
        let now = now_ms();
        let job = Job {
            job_id,
            status: JobStatus::Pending,
            steps: vec![
                step("preflight"),
                step("validate_config"),
                step("backup_config"),
                step("apply_config"),
                step("reload_config"),
                step("verify_config"),
            ],
            error: None,
            created_at_ms: now,
            started_at_ms: None,
            finished_at_ms: None,
        };
        let inserted = self.jobs.insert_idempotent(idempotency_key, job);
        // Record the audit event before kicking off the background task.
        self.audit.record(AuditEvent {
            ts_ms: now,
            principal_sub: principal.sub.clone(),
            action: format!("config.{}.apply", domain.as_str()),
            tenant_id: None,
            reason,
            job_id: Some(inserted),
        });
        let engine = self.clone();
        // Detached background execution; callers poll the job by id.
        tokio::spawn(async move {
            engine
                .run_job(
                    state,
                    inserted,
                    None,
                    Some(domain),
                    RunSpec::ConfigApply {
                        domain,
                        expected_revision,
                        value,
                    },
                )
                .await;
        });
        Ok(inserted)
    }
    /// Starts an asynchronous single-step "validate config" job for `domain`.
    ///
    /// Deduplicated by `idempotency_key`; takes the per-domain config lock so
    /// validation cannot interleave with an apply or rollback.
    pub fn start_config_validate(
        &self,
        state: AppState,
        principal: &Principal,
        domain: ConfigDomain,
        reason: String,
        value: serde_json::Value,
        idempotency_key: &str,
    ) -> Result<Uuid, StartJobError> {
        // Idempotent replay: reuse the job from the first submission.
        if let Some(existing) = self.jobs.get_idempotent(idempotency_key) {
            return Ok(existing);
        }
        let job_id = Uuid::new_v4();
        // Only one config job per domain at a time.
        if !self.config_locks.try_lock(domain, job_id) {
            return Err(StartJobError::TenantLocked);
        }
        let now = now_ms();
        let job = Job {
            job_id,
            status: JobStatus::Pending,
            steps: vec![step("validate_config")],
            error: None,
            created_at_ms: now,
            started_at_ms: None,
            finished_at_ms: None,
        };
        let inserted = self.jobs.insert_idempotent(idempotency_key, job);
        // Record the audit event before kicking off the background task.
        self.audit.record(AuditEvent {
            ts_ms: now,
            principal_sub: principal.sub.clone(),
            action: format!("config.{}.validate", domain.as_str()),
            tenant_id: None,
            reason,
            job_id: Some(inserted),
        });
        let engine = self.clone();
        // Detached background execution; callers poll the job by id.
        tokio::spawn(async move {
            engine
                .run_job(
                    state,
                    inserted,
                    None,
                    Some(domain),
                    RunSpec::ConfigValidate { domain, value },
                )
                .await;
        });
        Ok(inserted)
    }
    /// Starts an asynchronous "config rollback" job for `domain`.
    ///
    /// Steps: rollback -> reload -> verify. Deduplicated by
    /// `idempotency_key`; takes the per-domain config lock so a rollback
    /// cannot interleave with an apply or validate.
    pub fn start_config_rollback(
        &self,
        state: AppState,
        principal: &Principal,
        domain: ConfigDomain,
        reason: String,
        idempotency_key: &str,
    ) -> Result<Uuid, StartJobError> {
        // Idempotent replay: reuse the job from the first submission.
        if let Some(existing) = self.jobs.get_idempotent(idempotency_key) {
            return Ok(existing);
        }
        let job_id = Uuid::new_v4();
        // Only one config job per domain at a time.
        if !self.config_locks.try_lock(domain, job_id) {
            return Err(StartJobError::TenantLocked);
        }
        let now = now_ms();
        let job = Job {
            job_id,
            status: JobStatus::Pending,
            steps: vec![
                step("rollback_config"),
                step("reload_config"),
                step("verify_config"),
            ],
            error: None,
            created_at_ms: now,
            started_at_ms: None,
            finished_at_ms: None,
        };
        let inserted = self.jobs.insert_idempotent(idempotency_key, job);
        // Record the audit event before kicking off the background task.
        self.audit.record(AuditEvent {
            ts_ms: now,
            principal_sub: principal.sub.clone(),
            action: format!("config.{}.rollback", domain.as_str()),
            tenant_id: None,
            reason,
            job_id: Some(inserted),
        });
        let engine = self.clone();
        // Detached background execution; callers poll the job by id.
        tokio::spawn(async move {
            engine
                .run_job(
                    state,
                    inserted,
                    None,
                    Some(domain),
                    RunSpec::ConfigRollback { domain },
                )
                .await;
        });
        Ok(inserted)
    }
pub fn start_platform_verify(
&self,
state: AppState,
principal: &Principal,
reason: String,
idempotency_key: &str,
) -> Result<Uuid, StartJobError> {
if let Some(existing) = self.jobs.get_idempotent(idempotency_key) {
return Ok(existing);
}
let job_id = Uuid::new_v4();
let now = now_ms();
let job = Job {
job_id,
status: JobStatus::Pending,
steps: vec![step("preflight"), step("platform_verify")],
error: None,
created_at_ms: now,
started_at_ms: None,
finished_at_ms: None,
};
let inserted = self.jobs.insert_idempotent(idempotency_key, job);
self.audit.record(AuditEvent {
ts_ms: now,
principal_sub: principal.sub.clone(),
action: "platform.verify".to_string(),
tenant_id: None,
reason,
job_id: Some(inserted),
});
let engine = self.clone();
tokio::spawn(async move {
engine
.run_job(state, inserted, None, None, RunSpec::PlatformVerify)
.await;
});
Ok(inserted)
}
async fn run_job(
&self,
state: AppState,
job_id: Uuid,
tenant_id: Option<Uuid>,
config_domain: Option<ConfigDomain>,
spec: RunSpec,
) {
self.jobs.update(job_id, |j| {
j.status = JobStatus::Running;
j.started_at_ms = Some(now_ms());
@@ -265,6 +534,9 @@ impl JobEngine {
if let Some(tid) = tenant_id {
self.tenant_locks.unlock(tid, job_id);
}
if let Some(domain) = config_domain {
self.config_locks.unlock(domain, job_id);
}
}
}
@@ -276,7 +548,22 @@ pub enum StartJobError {
#[derive(Clone)]
enum RunSpec {
Drain,
Migrate { runner_target: String },
Migrate {
runner_target: String,
},
ConfigValidate {
domain: ConfigDomain,
value: serde_json::Value,
},
ConfigApply {
domain: ConfigDomain,
expected_revision: Option<u64>,
value: serde_json::Value,
},
ConfigRollback {
domain: ConfigDomain,
},
PlatformVerify,
}
fn step(name: &str) -> JobStep {
@@ -316,9 +603,14 @@ async fn run_step(
"update_placement" => match spec {
RunSpec::Migrate { runner_target } => {
let tenant_id = tenant_id.ok_or_else(|| "missing tenant_id".to_string())?;
let entitlements = state.billing.get_for_tenant(tenant_id).entitlements;
state
.placement
.update_runner_target(tenant_id, runner_target.clone())
.update_runner_target(
tenant_id,
runner_target.clone(),
entitlements.max_runners as usize,
)
.map(|_| ())
}
_ => Ok(()),
@@ -343,6 +635,400 @@ async fn run_step(
}
_ => Ok(()),
},
"validate_config" => match spec {
RunSpec::ConfigValidate { domain, value }
| RunSpec::ConfigApply { domain, value, .. } => match domain {
ConfigDomain::Routing => {
let cfg = serde_json::from_value::<RoutingConfig>(value.clone())
.map_err(|e| format!("invalid routing config: {e}"))?;
validate_routing_semantic(&cfg)?;
Ok(())
}
ConfigDomain::Placement => {
let cfg = serde_json::from_value::<PlacementFile>(value.clone())
.map_err(|e| format!("invalid placement config: {e}"))?;
validate_placement_semantic(state, &cfg)?;
Ok(())
}
},
_ => Ok(()),
},
"backup_config" => match spec {
RunSpec::ConfigApply { domain, .. } => {
let Some(source) = state.config.source(*domain) else {
return Err("config domain not configured".to_string());
};
let (cur, _) = source
.load_bytes()
.await
.map_err(|e| format!("failed to load config: {e}"))?;
let cur = cur.unwrap_or_else(|| b"null".to_vec());
let backup_key_value = serde_json::json!({ "backup": serde_json::from_slice::<serde_json::Value>(&cur).unwrap_or(serde_json::Value::Null) });
let bytes =
serde_json::to_vec_pretty(&backup_key_value).map_err(|e| e.to_string())?;
let backup_source = backup_source_for(&source.info(), *domain)
.await
.map_err(|e| format!("failed to build backup source: {e}"))?;
let _ = backup_source
.put_bytes(None, bytes)
.await
.map_err(|e| format!("failed to write backup: {e}"))?;
Ok(())
}
_ => Ok(()),
},
"apply_config" => match spec {
RunSpec::ConfigApply {
domain,
expected_revision,
value,
} => {
let Some(source) = state.config.source(*domain) else {
return Err("config domain not configured".to_string());
};
let bytes =
serde_json::to_vec_pretty(value).map_err(|e| format!("encode error: {e}"))?;
let _ = source
.put_bytes(*expected_revision, bytes)
.await
.map_err(|e| format!("apply failed: {e}"))?;
Ok(())
}
_ => Ok(()),
},
"rollback_config" => match spec {
RunSpec::ConfigRollback { domain } => {
let Some(source) = state.config.source(*domain) else {
return Err("config domain not configured".to_string());
};
let backup_source = backup_source_for(&source.info(), *domain)
.await
.map_err(|e| format!("failed to build backup source: {e}"))?;
let (bytes, _) = backup_source
.load_bytes()
.await
.map_err(|e| format!("failed to load backup: {e}"))?;
let Some(bytes) = bytes else {
return Err("no backup available".to_string());
};
let v: serde_json::Value = serde_json::from_slice(&bytes)
.map_err(|e| format!("invalid backup json: {e}"))?;
let backup = v.get("backup").cloned().unwrap_or(serde_json::Value::Null);
let next =
serde_json::to_vec_pretty(&backup).map_err(|e| format!("encode error: {e}"))?;
let _ = source
.put_bytes(None, next)
.await
.map_err(|e| format!("rollback failed: {e}"))?;
Ok(())
}
_ => Ok(()),
},
"reload_config" => Ok(()),
"verify_config" => match spec {
RunSpec::ConfigValidate { domain, .. }
| RunSpec::ConfigApply { domain, .. }
| RunSpec::ConfigRollback { domain } => {
let Some(source) = state.config.source(*domain) else {
return Err("config domain not configured".to_string());
};
let (bytes, _) = source
.load_bytes()
.await
.map_err(|e| format!("failed to load config: {e}"))?;
let bytes = bytes.unwrap_or_else(|| b"null".to_vec());
let v: serde_json::Value = serde_json::from_slice(&bytes)
.map_err(|e| format!("invalid stored json: {e}"))?;
match domain {
ConfigDomain::Routing => {
let cfg = serde_json::from_value::<RoutingConfig>(v)
.map_err(|e| format!("invalid routing config: {e}"))?;
validate_routing_semantic(&cfg)?;
Ok(())
}
ConfigDomain::Placement => {
let cfg = serde_json::from_value::<PlacementFile>(v)
.map_err(|e| format!("invalid placement config: {e}"))?;
validate_placement_semantic(state, &cfg)?;
Ok(())
}
}
}
_ => Ok(()),
},
"platform_verify" => match spec {
RunSpec::PlatformVerify => {
let snapshots = fleet::snapshot(&state.http, &state.fleet_services).await;
let bad: Vec<_> = snapshots
.into_iter()
.filter(|s| !(s.health_ok && s.ready_ok))
.map(|s| {
format!(
"{} health_ok={} ready_ok={}",
s.name, s.health_ok, s.ready_ok
)
})
.collect();
if !bad.is_empty() {
return Err(format!("platform verify failed: {}", bad.join("; ")));
}
Ok(())
}
_ => Ok(()),
},
_ => Ok(()),
}
}
/// Derives the backup storage location that shadows a primary config source.
///
/// File-backed sources get a sibling file named `<stem>.<domain>.bak.json`;
/// NATS KV sources get a `<key>.bak` key in the same bucket; fixed in-memory
/// sources have nowhere to store a backup and yield an error.
async fn backup_source_for(
    info: &crate::config_registry::ConfigSourceInfo,
    domain: ConfigDomain,
) -> Result<Arc<dyn crate::config_registry::ConfigSource>, ConfigRegistryError> {
    use crate::config_registry::{ConfigSource, ConfigSourceInfo, FileSource, NatsKvSource};
    match info {
        ConfigSourceInfo::File { path } => {
            // with_extension replaces the final extension, so e.g.
            // "dev.json" becomes "dev.routing.bak.json".
            let backup_path =
                PathBuf::from(path).with_extension(format!("{}.bak.json", domain.as_str()));
            Ok(Arc::new(FileSource::new(backup_path)) as Arc<dyn ConfigSource>)
        }
        ConfigSourceInfo::NatsKv { bucket, key } => {
            // NOTE(review): this reads CONTROL_CONFIG_NATS_URL, while the
            // primary sources are wired from CONTROL_ROUTING_NATS_URL /
            // CONTROL_PLACEMENT_NATS_URL — confirm deployments that use NATS
            // also set this variable, or backups will fail at runtime.
            let nats_url = std::env::var("CONTROL_CONFIG_NATS_URL").map_err(|_| {
                ConfigRegistryError::Source("missing CONTROL_CONFIG_NATS_URL".to_string())
            })?;
            let backup = NatsKvSource::connect(nats_url, bucket.clone(), format!("{key}.bak"))
                .await
                .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
            Ok(Arc::new(backup) as Arc<dyn ConfigSource>)
        }
        ConfigSourceInfo::Fixed => Err(ConfigRegistryError::Source(
            "no backups for fixed source".to_string(),
        )),
    }
}
/// Semantic checks for a routing config that already deserialized cleanly.
///
/// Verifies that every shard exposes at least one endpoint, that each
/// endpoint is a parseable http(s) URL with a host component, and that every
/// tenant placement references a declared, non-empty shard id. Returns a
/// human-readable description of the first violation found.
fn validate_routing_semantic(cfg: &RoutingConfig) -> Result<(), String> {
    // Endpoint well-formedness, per shard map.
    for (label, shards) in [
        ("aggregate_shards", &cfg.aggregate_shards),
        ("projection_shards", &cfg.projection_shards),
        ("runner_shards", &cfg.runner_shards),
    ] {
        for (shard_id, endpoints) in shards {
            if endpoints.is_empty() {
                return Err(format!("{label}[{shard_id}] has no endpoints"));
            }
            for ep in endpoints {
                let url = Url::parse(ep)
                    .map_err(|e| format!("{label}[{shard_id}] invalid endpoint {ep:?}: {e}"))?;
                if !matches!(url.scheme(), "http" | "https") {
                    return Err(format!(
                        "{label}[{shard_id}] endpoint {ep:?} must be http(s)"
                    ));
                }
                if url.host_str().is_none() {
                    return Err(format!(
                        "{label}[{shard_id}] endpoint {ep:?} must include host"
                    ));
                }
            }
        }
    }
    // Placement entries must point at shard ids declared above.
    for (label, placement, shards) in [
        (
            "aggregate_placement",
            &cfg.aggregate_placement,
            &cfg.aggregate_shards,
        ),
        (
            "projection_placement",
            &cfg.projection_placement,
            &cfg.projection_shards,
        ),
        (
            "runner_placement",
            &cfg.runner_placement,
            &cfg.runner_shards,
        ),
    ] {
        for (tenant, shard_id) in placement {
            if shard_id.trim().is_empty() {
                return Err(format!("{label}[{tenant}] shard_id is empty"));
            }
            if !shards.contains_key(shard_id) {
                return Err(format!(
                    "{label}[{tenant}] references missing shard_id {shard_id:?}"
                ));
            }
        }
    }
    Ok(())
}
/// Enforces per-tenant billing entitlements against a placement file.
///
/// No-op when billing enforcement is disabled on the state. Otherwise every
/// placement entry must carry at least one non-blank target; runner targets
/// and deployment (aggregate/projection) targets are tallied per tenant and
/// compared against the tenant's plan limits.
fn validate_placement_semantic(state: &AppState, cfg: &PlacementFile) -> Result<(), String> {
    if !state.billing_enforcement_enabled {
        return Ok(());
    }
    // tenant -> (deployment target count, runner target count)
    let mut usage: std::collections::HashMap<_, (usize, usize)> =
        std::collections::HashMap::new();
    for (kind_name, maybe_kind) in [
        ("aggregate_placement", cfg.aggregate_placement.as_ref()),
        ("projection_placement", cfg.projection_placement.as_ref()),
        ("runner_placement", cfg.runner_placement.as_ref()),
    ] {
        let Some(kind) = maybe_kind else { continue };
        for placement in &kind.placements {
            if placement.targets.is_empty() {
                return Err(format!(
                    "{kind_name} tenant {} has no targets",
                    placement.tenant_id
                ));
            }
            if placement.targets.iter().any(|t| t.trim().is_empty()) {
                return Err(format!(
                    "{kind_name} tenant {} has empty target",
                    placement.tenant_id
                ));
            }
            let counts = usage.entry(placement.tenant_id).or_insert((0usize, 0usize));
            if kind_name == "runner_placement" {
                counts.1 += placement.targets.len();
            } else {
                counts.0 += placement.targets.len();
            }
        }
    }
    // Compare tallies against the tenant's current entitlements.
    for (tenant_id, (deployments, runners)) in usage {
        let entitlements = state.billing.get_for_tenant(tenant_id).entitlements;
        if deployments > entitlements.max_deployments as usize {
            return Err(format!(
                "tenant {} exceeds max_deployments limit ({} > {})",
                tenant_id, deployments, entitlements.max_deployments
            ));
        }
        if runners > entitlements.max_runners as usize {
            return Err(format!(
                "tenant {} exceeds max_runners limit ({} > {})",
                tenant_id, runners, entitlements.max_runners
            ));
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::billing::{BillingStore, Plan, SubscriptionStatus, TenantBillingState};
    use crate::placement::{PlacementFile, PlacementKind, TenantPlacement};

    /// Builds a minimal `AppState` around the given billing store so that
    /// placement validation can run in-process. Billing enforcement is
    /// enabled so entitlement limits are actually checked.
    fn mock_state(billing: BillingStore) -> AppState {
        let handle = crate::get_test_prometheus_handle();
        let root = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        AppState {
            prometheus: handle,
            auth: crate::AuthConfig {
                hs256_secret: Some(b"secret".to_vec()),
            },
            jobs: JobStore::default(),
            audit: AuditStore::default(),
            tenant_locks: TenantLocks::default(),
            config_locks: ConfigLocks::default(),
            http: reqwest::Client::new(),
            placement: crate::placement::PlacementStore::new(
                std::env::temp_dir().join("placement.json"),
            ),
            billing,
            billing_provider: Arc::new(crate::billing::MockProvider),
            billing_enforcement_enabled: true,
            config: crate::config_registry::ConfigRegistry::new(None, None),
            fleet_services: vec![],
            swarm: crate::swarm::SwarmStore::new(root.join("swarm/dev.json")),
            docs: None,
        }
    }

    /// Entitlement limits track the tenant's plan: the default (Free) plan
    /// rejects a placement file with two deployment targets, a trimmed-down
    /// file passes, and upgrading the tenant to Pro makes the original file
    /// valid again.
    #[test]
    fn test_validate_placement_limits() {
        let tenant_id = Uuid::new_v4();
        // Per-test temp file so parallel tests do not share billing state.
        let billing_path =
            std::env::temp_dir().join(format!("billing-unit-{}.json", Uuid::new_v4()));
        let billing = BillingStore::new(billing_path.clone());
        let state = mock_state(billing.clone());

        // 1. Free plan (default): max_deployments=1, max_runners=1
        let cfg = PlacementFile {
            revision: Some("v1".to_string()),
            aggregate_placement: Some(PlacementKind {
                placements: vec![TenantPlacement {
                    tenant_id,
                    targets: vec!["a1".to_string()],
                }],
            }),
            projection_placement: Some(PlacementKind {
                placements: vec![TenantPlacement {
                    tenant_id,
                    targets: vec!["p1".to_string()],
                }],
            }),
            runner_placement: Some(PlacementKind {
                placements: vec![TenantPlacement {
                    tenant_id,
                    targets: vec!["r1".to_string()],
                }],
            }),
        };
        // aggregate(1) + projection(1) = 2 deployments. Limit is 1. Should fail.
        let err = validate_placement_semantic(&state, &cfg).unwrap_err();
        assert!(err.contains("exceeds max_deployments limit"));

        // 2. Reduce to 1 deployment
        let cfg2 = PlacementFile {
            revision: Some("v2".to_string()),
            aggregate_placement: Some(PlacementKind {
                placements: vec![TenantPlacement {
                    tenant_id,
                    targets: vec!["a1".to_string()],
                }],
            }),
            projection_placement: None,
            runner_placement: Some(PlacementKind {
                placements: vec![TenantPlacement {
                    tenant_id,
                    targets: vec!["r1".to_string()],
                }],
            }),
        };
        validate_placement_semantic(&state, &cfg2).unwrap();

        // 3. Upgrade to Pro: max_deployments=10, max_runners=10
        billing
            .update_tenant_state(
                tenant_id,
                TenantBillingState {
                    provider: "mock".to_string(),
                    provider_customer_id: None,
                    provider_subscription_id: None,
                    provider_checkout_session_id: None,
                    status: Some(SubscriptionStatus::Active),
                    plan: Some(Plan::Pro),
                    current_period_end: None,
                    cancel_at_period_end: None,
                    processed_webhook_event_ids: vec![],
                    updated_at: 100,
                },
            )
            .unwrap();
        // Now the first cfg should pass
        validate_placement_semantic(&state, &cfg).unwrap();
        // Teardown: remove the temp billing file; errors are ignored.
        let _ = std::fs::remove_file(billing_path);
    }
}

View File

@@ -1,14 +1,22 @@
mod admin;
mod audit;
mod auth;
pub mod billing;
mod build_info;
pub mod config_registry;
mod config_schemas;
mod deployments;
mod documents;
mod drift;
mod fleet;
mod job_engine;
mod jobs;
mod placement;
pub mod s3_docs;
mod swarm;
use std::sync::Arc;
pub use audit::AuditStore;
pub use auth::{AuthConfig, Principal};
use axum::{
@@ -20,8 +28,10 @@ use axum::{
routing::get,
};
pub use build_info::{BuildInfo, extract_build_info};
pub use config_registry::{ConfigDomain, ConfigRegistry};
pub use deployments::{DeployAnnotationArgs, GrafanaAnnotation, build_grafana_deploy_annotation};
pub use fleet::FleetService;
pub use job_engine::ConfigLocks;
pub use job_engine::TenantLocks;
pub use jobs::JobStore;
use metrics_exporter_prometheus::PrometheusHandle;
@@ -40,10 +50,16 @@ pub struct AppState {
pub jobs: JobStore,
pub audit: AuditStore,
pub tenant_locks: TenantLocks,
pub config_locks: ConfigLocks,
pub http: reqwest::Client,
pub placement: PlacementStore,
pub billing: billing::BillingStore,
pub billing_provider: Arc<dyn billing::BillingProvider>,
pub billing_enforcement_enabled: bool,
pub config: ConfigRegistry,
pub fleet_services: Vec<FleetService>,
pub swarm: SwarmStore,
pub docs: Option<s3_docs::DocsStore>,
}
#[derive(Clone, Debug)]
@@ -93,13 +109,18 @@ pub fn build_app(state: AppState) -> Router {
},
);
let admin =
admin::admin_router().layer(from_fn_with_state(state.clone(), auth::auth_middleware));
let admin = admin::admin_router()
.merge(documents::router())
.layer(from_fn_with_state(state.clone(), auth::auth_middleware));
Router::new()
.route("/health", get(health))
.route("/ready", get(ready))
.route("/metrics", get(metrics))
.route(
"/admin/v1/billing/webhooks/{provider}",
axum::routing::post(billing::webhook),
)
.nest("/admin/v1", admin)
.with_state(state)
.layer(trace)
@@ -167,25 +188,46 @@ async fn request_id_middleware(mut req: Request<axum::body::Body>, next: Next) -
res
}
#[cfg(test)]
static TEST_PROMETHEUS_HANDLE: std::sync::OnceLock<PrometheusHandle> = std::sync::OnceLock::new();
/// Returns the shared Prometheus handle used by tests.
///
/// A Prometheus recorder can only be installed once per process, so the
/// handle is created lazily and cached in `TEST_PROMETHEUS_HANDLE` for all
/// tests to share.
#[cfg(test)]
pub(crate) fn get_test_prometheus_handle() -> PrometheusHandle {
    TEST_PROMETHEUS_HANDLE
        .get_or_init(|| {
            metrics_exporter_prometheus::PrometheusBuilder::new()
                .install_recorder()
                .unwrap_or_else(|_| {
                    // This can happen if another test already installed it.
                    // We might not get the ACTUAL handle to the global recorder here if we don't share it,
                    // but for tests it's usually fine to have a dummy one if we are not asserting on metrics.
                    metrics_exporter_prometheus::PrometheusBuilder::new()
                        .build()
                        .expect("failed to build prometheus recorder")
                        .0
                        .handle()
                })
        })
        .clone()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config_registry::{FileSource, FixedSource};
use crate::jobs::JobStatus;
use axum::{
body::Body,
http::{Request, StatusCode, header},
};
use jsonwebtoken::{EncodingKey, Header, encode};
use metrics_exporter_prometheus::PrometheusBuilder;
use serde::Serialize;
use std::fs;
use std::path::PathBuf;
use std::sync::OnceLock;
use std::sync::Arc;
use tower::ServiceExt;
use uuid::Uuid;
static HANDLE: OnceLock<PrometheusHandle> = OnceLock::new();
#[derive(Serialize)]
struct TestClaims {
sub: String,
@@ -199,15 +241,10 @@ mod tests {
}
fn test_app_with_fleet(fleet_services: Vec<FleetService>) -> Router {
let handle = HANDLE
.get_or_init(|| {
PrometheusBuilder::new()
.install_recorder()
.expect("failed to install prometheus recorder")
})
.clone();
let handle = get_test_prometheus_handle();
let placement_path = temp_placement_file();
let root = repo_root();
build_app(AppState {
prometheus: handle,
@@ -217,10 +254,23 @@ mod tests {
jobs: JobStore::default(),
audit: AuditStore::default(),
tenant_locks: TenantLocks::default(),
config_locks: ConfigLocks::default(),
http: reqwest::Client::new(),
placement: PlacementStore::new(placement_path),
billing: crate::billing::BillingStore::new(
std::env::temp_dir().join(format!("billing-test-{}.json", Uuid::new_v4())),
),
billing_provider: Arc::new(crate::billing::MockProvider),
billing_enforcement_enabled: true,
config: ConfigRegistry::new(
Some(Arc::new(FileSource::new(
root.join("config/routing/dev.json"),
))),
Some(Arc::new(FixedSource::new(b"{}".to_vec()))),
),
fleet_services,
swarm: SwarmStore::new(repo_root().join("swarm/dev.json")),
docs: None,
})
}
@@ -234,14 +284,14 @@ mod tests {
fn temp_placement_file() -> PathBuf {
let root = repo_root();
let src = root.join("placement/dev.json");
let src = root.join("config/placement/dev.json");
let mut dst = std::env::temp_dir();
dst.push(format!(
"cloudlysis-control-placement-{}-{}.json",
std::process::id(),
Uuid::new_v4()
));
let raw = fs::read_to_string(src).expect("missing placement/dev.json");
let raw = fs::read_to_string(src).expect("missing config/placement/dev.json");
fs::write(&dst, raw).expect("failed to write temp placement file");
dst
}
@@ -689,4 +739,467 @@ mod tests {
&serde_json::json!(["preflight", "drain", "update_placement", "reload", "verify"])
);
}
/// With no billing state on file, the billing endpoint reports
/// `configured: false` and falls back to the default (Free) plan
/// entitlements (max_deployments = 1).
#[tokio::test]
async fn billing_returns_not_configured_by_default() {
    let token = make_token(&["control:read"]);
    let tenant_id = Uuid::new_v4();
    let res = test_app()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(v.get("configured").unwrap(), &serde_json::json!(false));
    assert_eq!(
        v.get("entitlements")
            .unwrap()
            .get("max_deployments")
            .unwrap(),
        &serde_json::json!(1)
    );
}
/// When a tenant has recorded billing state, the billing endpoint reports
/// `configured: true` plus the plan and its entitlements
/// (Pro => max_deployments = 10).
#[tokio::test]
async fn billing_returns_configured_state() {
    let token = make_token(&["control:read"]);
    let tenant_id = Uuid::new_v4();
    let handle = get_test_prometheus_handle();
    // Seed a fresh billing store with an active Pro subscription.
    let billing_path =
        std::env::temp_dir().join(format!("billing-test-cfg-{}.json", Uuid::new_v4()));
    let billing = crate::billing::BillingStore::new(billing_path.clone());
    billing
        .update_tenant_state(
            tenant_id,
            crate::billing::TenantBillingState {
                provider: "stripe".to_string(),
                provider_customer_id: Some("cus_123".to_string()),
                provider_subscription_id: Some("sub_123".to_string()),
                provider_checkout_session_id: None,
                status: Some(crate::billing::SubscriptionStatus::Active),
                plan: Some(crate::billing::Plan::Pro),
                current_period_end: Some("2026-04-30T00:00:00Z".to_string()),
                cancel_at_period_end: Some(false),
                processed_webhook_event_ids: vec![],
                updated_at: 1234567890,
            },
        )
        .unwrap();
    let root = repo_root();
    // Hand-built app so the seeded billing store is used instead of
    // test_app()'s empty default store.
    let app = build_app(AppState {
        prometheus: handle,
        auth: AuthConfig {
            hs256_secret: Some(b"test_secret".to_vec()),
        },
        jobs: JobStore::default(),
        audit: AuditStore::default(),
        tenant_locks: TenantLocks::default(),
        config_locks: ConfigLocks::default(),
        http: reqwest::Client::new(),
        placement: PlacementStore::new(temp_placement_file()),
        billing,
        billing_provider: Arc::new(crate::billing::MockProvider),
        billing_enforcement_enabled: true,
        config: ConfigRegistry::new(
            Some(Arc::new(FileSource::new(
                root.join("config/routing/dev.json"),
            ))),
            Some(Arc::new(FixedSource::new(b"{}".to_vec()))),
        ),
        fleet_services: vec![],
        swarm: SwarmStore::new(repo_root().join("swarm/dev.json")),
        docs: None,
    });
    let res = app
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(v.get("configured").unwrap(), &serde_json::json!(true));
    assert_eq!(v.get("plan").unwrap(), &serde_json::json!("pro"));
    assert_eq!(
        v.get("entitlements")
            .unwrap()
            .get("max_deployments")
            .unwrap(),
        &serde_json::json!(10)
    );
    // Teardown: remove the temp billing file; errors are ignored.
    let _ = std::fs::remove_file(billing_path);
}
/// Checkout with the mock provider returns a deterministic mock Stripe
/// checkout URL keyed by tenant id.
#[tokio::test]
async fn checkout_returns_mock_url() {
    let token = make_token(&["control:write"]);
    let tenant_id = Uuid::new_v4();
    let body = serde_json::json!({
        "plan": "pro",
        "return_path": "/custom-return"
    });
    let res = test_app()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing/checkout"))
                .method("POST")
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(body.to_string()))
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(
        v.get("url").unwrap(),
        &serde_json::json!(format!("https://mock.stripe.com/checkout/{}", tenant_id))
    );
}
/// Starting a checkout while the tenant already has an active subscription
/// is rejected with 409 Conflict.
#[tokio::test]
async fn checkout_fails_if_already_active() {
    let token = make_token(&["control:write"]);
    let tenant_id = Uuid::new_v4();
    // Setup app with active subscription
    let billing_path =
        std::env::temp_dir().join(format!("billing-test-active-{}.json", Uuid::new_v4()));
    let billing = crate::billing::BillingStore::new(billing_path.clone());
    billing
        .update_tenant_state(
            tenant_id,
            crate::billing::TenantBillingState {
                provider: "mock".to_string(),
                provider_customer_id: None,
                provider_subscription_id: None,
                provider_checkout_session_id: None,
                status: Some(crate::billing::SubscriptionStatus::Active),
                plan: Some(crate::billing::Plan::Pro),
                current_period_end: None,
                cancel_at_period_end: None,
                processed_webhook_event_ids: vec![],
                updated_at: 0,
            },
        )
        .unwrap();
    let handle = get_test_prometheus_handle();
    let root = repo_root();
    // Hand-built app so the seeded billing store is visible to the handler.
    let app = build_app(AppState {
        prometheus: handle,
        auth: AuthConfig {
            hs256_secret: Some(b"test_secret".to_vec()),
        },
        jobs: JobStore::default(),
        audit: AuditStore::default(),
        tenant_locks: TenantLocks::default(),
        config_locks: ConfigLocks::default(),
        http: reqwest::Client::new(),
        placement: PlacementStore::new(temp_placement_file()),
        billing,
        billing_provider: Arc::new(crate::billing::MockProvider),
        billing_enforcement_enabled: true,
        config: ConfigRegistry::new(
            Some(Arc::new(FileSource::new(
                root.join("config/routing/dev.json"),
            ))),
            Some(Arc::new(FixedSource::new(b"{}".to_vec()))),
        ),
        fleet_services: vec![],
        swarm: SwarmStore::new(repo_root().join("swarm/dev.json")),
        docs: None,
    });
    let body = serde_json::json!({ "plan": "pro" });
    let res = app
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing/checkout"))
                .method("POST")
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(body.to_string()))
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::CONFLICT);
    // Teardown: remove the temp billing file; errors are ignored.
    let _ = std::fs::remove_file(billing_path);
}
/// The billing portal endpoint with the mock provider returns a
/// deterministic mock Stripe portal URL keyed by tenant id.
#[tokio::test]
async fn portal_returns_mock_url() {
    let token = make_token(&["control:write"]);
    let tenant_id = Uuid::new_v4();
    let res = test_app()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing/portal"))
                .method("POST")
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(
        v.get("url").unwrap(),
        &serde_json::json!(format!("https://mock.stripe.com/portal/{}", tenant_id))
    );
}
/// A subscription-created webhook updates tenant billing state, and
/// redelivering the exact same event (same event id) is accepted again
/// without error.
#[tokio::test]
async fn webhook_updates_state_idempotently() {
    let tenant_id = Uuid::new_v4();
    let event_id = "evt_123".to_string();
    let app = test_app();
    let event = crate::billing::BillingEvent::SubscriptionCreated {
        tenant_id,
        event_id: event_id.clone(),
        provider_customer_id: "cus_123".to_string(),
        provider_subscription_id: "sub_123".to_string(),
        status: crate::billing::SubscriptionStatus::Active,
        plan: crate::billing::Plan::Pro,
        current_period_end: "2026-04-30T00:00:00Z".to_string(),
        ts_ms: 1000,
    };
    let body = serde_json::to_string(&event).unwrap();
    // 1. Send webhook
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri("/admin/v1/billing/webhooks/mock")
                .method("POST")
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(body.clone()))
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    // 2. Verify state
    let token = make_token(&["control:read"]);
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    let body_bytes = axum::body::to_bytes(res.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
    assert_eq!(v.get("configured").unwrap(), &serde_json::json!(true));
    assert_eq!(v.get("plan").unwrap(), &serde_json::json!("pro"));
    // 3. Send same webhook again (idempotency)
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri("/admin/v1/billing/webhooks/mock")
                .method("POST")
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(body))
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
}
#[tokio::test]
async fn webhook_ignores_stale_events() {
let tenant_id = Uuid::new_v4();
let app = test_app();
// 1. Send recent event (ts=2000)
let event1 = crate::billing::BillingEvent::SubscriptionUpdated {
tenant_id,
event_id: "evt_new".to_string(),
status: crate::billing::SubscriptionStatus::Active,
plan: crate::billing::Plan::Enterprise,
current_period_end: "2026-05-30T00:00:00Z".to_string(),
cancel_at_period_end: false,
ts_ms: 2000,
};
app.clone()
.oneshot(
Request::builder()
.uri("/admin/v1/billing/webhooks/mock")
.method("POST")
.body(Body::from(serde_json::to_string(&event1).unwrap()))
.unwrap(),
)
.await
.unwrap();
// 2. Send stale event (ts=1000)
let event2 = crate::billing::BillingEvent::SubscriptionUpdated {
tenant_id,
event_id: "evt_old".to_string(),
status: crate::billing::SubscriptionStatus::PastDue,
plan: crate::billing::Plan::Pro,
current_period_end: "2026-04-30T00:00:00Z".to_string(),
cancel_at_period_end: false,
ts_ms: 1000,
};
app.clone()
.oneshot(
Request::builder()
.uri("/admin/v1/billing/webhooks/mock")
.method("POST")
.body(Body::from(serde_json::to_string(&event2).unwrap()))
.unwrap(),
)
.await
.unwrap();
// 3. Verify state is still Enterprise
let token = make_token(&["control:read"]);
let res = app
.clone()
.oneshot(
Request::builder()
.uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
.header(header::AUTHORIZATION, format!("Bearer {token}"))
.header("x-tenant-id", tenant_id.to_string())
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();
let body_bytes = axum::body::to_bytes(res.into_body(), 1024 * 1024)
.await
.unwrap();
let v: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(v.get("plan").unwrap(), &serde_json::json!("enterprise"));
}
/// S3 document access is gated on billing entitlements: a Free tenant gets
/// 402 Payment Required, and after upgrading to Pro via webhook the request
/// passes the entitlement check (then hits 503 because no docs store is
/// configured in tests).
#[tokio::test]
async fn s3_docs_requires_pro_plan() {
    let token = make_token(&["control:read", "control:write"]);
    let tenant_id = Uuid::new_v4();
    let app = test_app();
    // 1. Try to list docs (Free plan by default)
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/docs"))
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::PAYMENT_REQUIRED);
    // 2. Update to Pro plan via webhook
    let event = crate::billing::BillingEvent::SubscriptionCreated {
        tenant_id,
        event_id: "evt_pro".to_string(),
        provider_customer_id: "cus_pro".to_string(),
        provider_subscription_id: "sub_pro".to_string(),
        status: crate::billing::SubscriptionStatus::Active,
        plan: crate::billing::Plan::Pro,
        current_period_end: "2099-01-01T00:00:00Z".to_string(),
        ts_ms: 2000,
    };
    app.clone()
        .oneshot(
            Request::builder()
                .uri("/admin/v1/billing/webhooks/mock")
                .method("POST")
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(serde_json::to_string(&event).unwrap()))
                .unwrap(),
        )
        .await
        .unwrap();
    // 3. Try to list docs again (Should fail with 503 if S3 not configured in tests, or 200/502 if it is)
    // In test_app(), docs is None by default.
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/docs"))
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    // Since docs is None in test_app(), it returns SERVICE_UNAVAILABLE (503) AFTER passing the entitlement check.
    // If it was still PAYMENT_REQUIRED, it would return 402.
    assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
}
}

View File

@@ -1,6 +1,8 @@
use clap::Parser;
use metrics_exporter_prometheus::PrometheusBuilder;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tracing_subscriber::EnvFilter;
#[derive(Parser, Debug)]
@@ -33,16 +35,32 @@ async fn main() {
.build()
.expect("failed to build http client");
let placement_path = std::env::var("CONTROL_PLACEMENT_PATH")
let placement_path: PathBuf = std::env::var("CONTROL_PLACEMENT_PATH")
.ok()
.unwrap_or_else(|| "placement/dev.json".to_string())
.unwrap_or_else(|| "config/placement/dev.json".to_string())
.into();
let swarm_path = std::env::var("CONTROL_SWARM_STATE_PATH")
let billing_path: PathBuf = std::env::var("CONTROL_BILLING_STATE_PATH")
.ok()
.unwrap_or_else(|| "swarm/dev.json".to_string())
.unwrap_or_else(|| "billing/dev.json".to_string())
.into();
let routing_path: PathBuf = std::env::var("CONTROL_ROUTING_PATH")
.ok()
.unwrap_or_else(|| "config/routing/dev.json".to_string())
.into();
let swarm_mode = std::env::var("CONTROL_SWARM_MODE").ok();
let swarm = if swarm_mode.as_deref() == Some("docker") {
api::SwarmStore::new_docker_cli()
} else {
let swarm_path: PathBuf = std::env::var("CONTROL_SWARM_STATE_PATH")
.ok()
.unwrap_or_else(|| "swarm/dev.json".to_string())
.into();
api::SwarmStore::new(swarm_path)
};
let self_url = std::env::var("CONTROL_SELF_URL")
.ok()
.unwrap_or_else(|| "http://127.0.0.1:8080".to_string());
@@ -55,7 +73,70 @@ async fn main() {
fleet_services.extend(parse_fleet_services(&spec));
}
let app = api::build_app(api::AppState {
let docs_cfg =
api::s3_docs::DocsConfig::from_env().expect("missing S3 document storage configuration");
let docs = api::s3_docs::DocsStore::new(docs_cfg)
.await
.expect("failed to initialize S3 document storage client");
let config = {
let routing = if let (Ok(nats_url), Ok(bucket), Ok(key)) = (
std::env::var("CONTROL_ROUTING_NATS_URL"),
std::env::var("CONTROL_ROUTING_NATS_BUCKET"),
std::env::var("CONTROL_ROUTING_NATS_KEY"),
) {
Some(Arc::new(
api::config_registry::NatsKvSource::connect(nats_url, bucket, key)
.await
.expect("failed to connect to routing config nats kv"),
) as Arc<dyn api::config_registry::ConfigSource>)
} else {
Some(
Arc::new(api::config_registry::FileSource::new(routing_path))
as Arc<dyn api::config_registry::ConfigSource>,
)
};
let placement = if let (Ok(nats_url), Ok(bucket), Ok(key)) = (
std::env::var("CONTROL_PLACEMENT_NATS_URL"),
std::env::var("CONTROL_PLACEMENT_NATS_BUCKET"),
std::env::var("CONTROL_PLACEMENT_NATS_KEY"),
) {
Some(Arc::new(
api::config_registry::NatsKvSource::connect(nats_url, bucket, key)
.await
.expect("failed to connect to placement config nats kv"),
) as Arc<dyn api::config_registry::ConfigSource>)
} else {
Some(Arc::new(api::config_registry::FileSource::new(
placement_path.clone(),
))
as Arc<dyn api::config_registry::ConfigSource>)
};
api::ConfigRegistry::new(routing, placement)
};
let billing_provider: Arc<dyn api::billing::BillingProvider> =
match std::env::var("CONTROL_BILLING_PROVIDER").as_deref() {
Ok("stripe") => {
let secret_key = std::env::var("CONTROL_STRIPE_SECRET_KEY")
.expect("CONTROL_STRIPE_SECRET_KEY required for stripe provider");
let price_pro = std::env::var("CONTROL_STRIPE_PRICE_ID_PRO")
.expect("CONTROL_STRIPE_PRICE_ID_PRO required for stripe provider");
let price_enterprise = std::env::var("CONTROL_STRIPE_PRICE_ID_ENTERPRISE")
.expect("CONTROL_STRIPE_PRICE_ID_ENTERPRISE required for stripe provider");
Arc::new(api::billing::StripeProvider {
secret_key,
price_pro,
price_enterprise,
})
}
_ => Arc::new(api::billing::MockProvider),
};
let state = api::AppState {
prometheus: recorder,
auth: api::AuthConfig {
hs256_secret: std::env::var("CONTROL_GATEWAY_JWT_HS256_SECRET")
@@ -65,11 +146,25 @@ async fn main() {
jobs: api::JobStore::default(),
audit: api::AuditStore::default(),
tenant_locks: api::TenantLocks::default(),
config_locks: api::ConfigLocks::default(),
http,
placement: api::PlacementStore::new(placement_path),
billing: api::billing::BillingStore::new(billing_path),
billing_provider,
billing_enforcement_enabled: std::env::var("CONTROL_BILLING_ENFORCEMENT_ENABLED")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(false),
config,
fleet_services,
swarm: api::SwarmStore::new(swarm_path),
});
swarm,
docs: Some(docs),
};
// Spawn reconciliation loop
tokio::spawn(api::billing::run_reconciliation_loop(state.clone()));
let app = api::build_app(state);
let listener = tokio::net::TcpListener::bind(args.addr)
.await

View File

@@ -157,6 +157,7 @@ impl PlacementStore {
&self,
tenant_id: Uuid,
runner_target: String,
max_runners: usize,
) -> Result<String, String> {
let mut inner = self.inner.write().expect("placement lock poisoned");
inner.reload_if_changed();
@@ -178,8 +179,17 @@ impl PlacementStore {
.iter_mut()
.find(|p| p.tenant_id == tenant_id)
{
// If already at or above limit, and we are adding a NEW target (not replacing), it would fail.
// But here update_runner_target REPLACES the target list with a single target for now.
// If in the future we want to append, we check targets.len().
if 1 > max_runners {
return Err(format!("exceeds max_runners limit of {}", max_runners));
}
existing.targets = vec![runner_target];
} else {
if 1 > max_runners {
return Err(format!("exceeds max_runners limit of {}", max_runners));
}
runner.placements.push(TenantPlacement {
tenant_id,
targets: vec![runner_target],

508
control/api/src/s3_docs.rs Normal file
View File

@@ -0,0 +1,508 @@
use aws_config::Region;
use aws_credential_types::Credentials;
use aws_sdk_s3::presigning::PresigningConfig;
use aws_sdk_s3::types::BucketCannedAcl;
use aws_sdk_s3::{Client, config::Builder as S3ConfigBuilder};
use sha2::Digest;
use std::time::Duration;
#[derive(Clone, Debug)]
pub struct DocsConfig {
    /// Internal S3 endpoint used for server-side operations.
    pub endpoint: String,
    /// Externally reachable endpoint embedded in presigned URLs
    /// (falls back to `endpoint` when unset).
    pub public_endpoint: Option<String>,
    pub region: String,
    pub access_key_id: String,
    pub secret_access_key: String,
    /// Path-style addressing (`host/bucket/key`), e.g. for MinIO; defaults to true.
    pub force_path_style: bool,
    /// Permits plain-http endpoints for local development; rejected for https.
    pub insecure: bool,
    /// One or more bucket names; tenants are deterministically sharded across them.
    pub buckets: Vec<String>,
    /// Object-key prefix, normalized to always end with `/`.
    pub prefix: String,
}
impl DocsConfig {
    /// Load configuration from `CONTROL_S3_*` environment variables, falling
    /// back to the corresponding `S3_*` variables. Secrets may alternatively
    /// be supplied through `*_FILE` variables pointing at a file whose
    /// contents hold the value.
    ///
    /// # Errors
    /// Returns a human-readable message when a required variable is missing,
    /// a secret file cannot be read, or `CONTROL_S3_INSECURE=true` is paired
    /// with an `https://` endpoint.
    pub fn from_env() -> Result<Self, String> {
        // Read an env var, treating unset or blank values as absent.
        fn get(name: &str) -> Option<String> {
            std::env::var(name)
                .ok()
                .map(|s| s.trim().to_string())
                .filter(|s| !s.is_empty())
        }
        // Secret resolution: the `*_FILE` variant (file contents) wins over the
        // direct variable; an unreadable file is an error, not "absent".
        fn get_secret(name: &str, file_name: &str) -> Result<Option<String>, String> {
            if let Some(path) = get(file_name) {
                let raw = std::fs::read_to_string(path).map_err(|e| e.to_string())?;
                let v = raw.trim().to_string();
                if v.is_empty() {
                    return Ok(None);
                }
                return Ok(Some(v));
            }
            Ok(get(name))
        }
        let endpoint = get("CONTROL_S3_ENDPOINT")
            .or_else(|| get("S3_ENDPOINT"))
            .ok_or_else(|| "Missing CONTROL_S3_ENDPOINT".to_string())?;
        let public_endpoint =
            get("CONTROL_S3_PUBLIC_ENDPOINT").or_else(|| get("S3_PUBLIC_ENDPOINT"));
        let region = get("CONTROL_S3_REGION")
            .or_else(|| get("S3_REGION"))
            .unwrap_or_else(|| "us-east-1".to_string());
        // FIX: the fallback lookups previously used `.ok().flatten()`, which
        // swallowed secret-file read errors and turned an unreadable
        // `S3_*_FILE` into a misleading "Missing CONTROL_..." error.
        // File-read errors now propagate from both lookup paths.
        let access_key_id =
            match get_secret("CONTROL_S3_ACCESS_KEY_ID", "CONTROL_S3_ACCESS_KEY_ID_FILE")? {
                Some(v) => v,
                None => get_secret("S3_ACCESS_KEY_ID", "S3_ACCESS_KEY_ID_FILE")?
                    .ok_or_else(|| "Missing CONTROL_S3_ACCESS_KEY_ID".to_string())?,
            };
        let secret_access_key = match get_secret(
            "CONTROL_S3_SECRET_ACCESS_KEY",
            "CONTROL_S3_SECRET_ACCESS_KEY_FILE",
        )? {
            Some(v) => v,
            None => get_secret("S3_SECRET_ACCESS_KEY", "S3_SECRET_ACCESS_KEY_FILE")?
                .ok_or_else(|| "Missing CONTROL_S3_SECRET_ACCESS_KEY".to_string())?,
        };
        // Booleans accept "true" or "1"; anything else (or unset) is the default.
        let force_path_style = get("CONTROL_S3_FORCE_PATH_STYLE")
            .or_else(|| get("S3_FORCE_PATH_STYLE"))
            .as_deref()
            .map(|v| v == "true" || v == "1")
            .unwrap_or(true);
        let insecure = get("CONTROL_S3_INSECURE")
            .or_else(|| get("S3_INSECURE"))
            .as_deref()
            .map(|v| v == "true" || v == "1")
            .unwrap_or(false);
        let bucket_raw = get("CONTROL_S3_BUCKET_DOCS")
            .or_else(|| get("S3_BUCKET_DOCS"))
            .ok_or_else(|| "Missing CONTROL_S3_BUCKET_DOCS".to_string())?;
        // Comma-separated bucket list; blank entries are ignored.
        let buckets: Vec<String> = bucket_raw
            .split(',')
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty())
            .collect();
        if buckets.is_empty() {
            return Err("Missing CONTROL_S3_BUCKET_DOCS".to_string());
        }
        let prefix = get("CONTROL_S3_PREFIX_DOCS")
            .or_else(|| get("S3_PREFIX_DOCS"))
            .unwrap_or_else(|| "docs/".to_string());
        // Normalize the prefix so key construction can blindly append segments.
        let prefix = if prefix.ends_with('/') {
            prefix
        } else {
            format!("{prefix}/")
        };
        // SECURITY: `*_INSECURE=true` is intended for local MinIO setups that use plain HTTP.
        // We currently do not disable TLS certificate verification for HTTPS endpoints.
        if insecure && endpoint.trim_start().starts_with("https://") {
            return Err(
                "CONTROL_S3_INSECURE=true is not supported with https:// endpoints (TLS verification is not disabled). Use http:// for local MinIO, or set CONTROL_S3_INSECURE=false for production."
                    .to_string(),
            );
        }
        Ok(Self {
            endpoint,
            public_endpoint,
            region,
            access_key_id,
            secret_access_key,
            force_path_style,
            insecure,
            buckets,
            prefix,
        })
    }
}
/// Tenant-sharded document storage backed by an S3-compatible object store.
#[derive(Clone)]
pub struct DocsStore {
    // Resolved configuration (endpoints, buckets, key prefix).
    cfg: DocsConfig,
    // Client for server-side operations against the internal endpoint.
    client: Client,
    // Client bound to the public endpoint (when configured); used by the
    // presign_* methods so generated URLs are reachable from outside.
    presign_client: Client,
}
impl DocsStore {
    /// Build the S3 client pair from `cfg`.
    ///
    /// Two clients are kept: `client` talks to the internal endpoint for
    /// server-side operations, while `presign_client` targets
    /// `public_endpoint` (when set) so presigned URLs embed an address that
    /// external callers can reach.
    pub async fn new(cfg: DocsConfig) -> Result<Self, String> {
        let creds = Credentials::new(
            cfg.access_key_id.clone(),
            cfg.secret_access_key.clone(),
            None,
            None,
            "static",
        );
        let shared = aws_config::from_env()
            .region(Region::new(cfg.region.clone()))
            .credentials_provider(creds.clone())
            .endpoint_url(cfg.endpoint.clone())
            .load()
            .await;
        let s3_conf = S3ConfigBuilder::from(&shared)
            .force_path_style(cfg.force_path_style)
            .build();
        let client = Client::from_conf(s3_conf);
        // Presigned URLs must use the externally reachable endpoint.
        let presign_endpoint = cfg
            .public_endpoint
            .clone()
            .unwrap_or_else(|| cfg.endpoint.clone());
        let presign_shared = aws_config::from_env()
            .region(Region::new(cfg.region.clone()))
            .credentials_provider(creds)
            .endpoint_url(presign_endpoint)
            .load()
            .await;
        let presign_conf = S3ConfigBuilder::from(&presign_shared)
            .force_path_style(cfg.force_path_style)
            .build();
        let presign_client = Client::from_conf(presign_conf);
        Ok(Self {
            cfg,
            client,
            presign_client,
        })
    }
    /// Build the object key `<prefix><tenant_id>/<doc_type>/<doc_id>/<filename>`
    /// after validating every caller-supplied segment (no separators, no `..`).
    ///
    /// # Errors
    /// Returns a message naming the offending segment when validation fails.
    pub fn key_for(
        &self,
        tenant_id: &str,
        doc_type: &str,
        doc_id: &str,
        filename: &str,
    ) -> Result<String, String> {
        validate_segment("tenant_id", tenant_id)?;
        validate_segment("doc_type", doc_type)?;
        validate_segment("doc_id", doc_id)?;
        validate_filename(filename)?;
        Ok(format!(
            "{}{}/{}/{}/{}",
            self.cfg.prefix, tenant_id, doc_type, doc_id, filename
        ))
    }
    /// Configured key prefix (always ends with `/`).
    pub fn prefix(&self) -> &str {
        self.cfg.prefix.as_str()
    }
    /// All configured bucket names.
    pub fn buckets(&self) -> &[String] {
        self.cfg.buckets.as_slice()
    }
    /// Pick the bucket for a tenant via SHA-256 of the tenant id modulo the
    /// bucket count.
    fn bucket_for_tenant(&self, tenant_id: &str) -> &str {
        // Deterministic sharding across buckets. Note: if the bucket list changes, the mapping changes.
        // For production, set the full planned bucket set up-front (e.g. `-0,-1,-2`) to keep mapping stable.
        let n = self.cfg.buckets.len();
        if n == 1 {
            return self.cfg.buckets[0].as_str();
        }
        let mut hasher = sha2::Sha256::new();
        hasher.update(tenant_id.as_bytes());
        let digest = hasher.finalize();
        let mut b = [0u8; 8];
        b.copy_from_slice(&digest[..8]);
        let v = u64::from_be_bytes(b);
        let idx = (v as usize) % n;
        self.cfg.buckets[idx].as_str()
    }
    /// Lowercase hex encoding of the SHA-256 digest of `bytes`.
    pub fn content_hash_sha256_hex(bytes: &[u8]) -> String {
        let mut hasher = sha2::Sha256::new();
        hasher.update(bytes);
        let digest = hasher.finalize();
        let mut out = String::with_capacity(digest.len() * 2);
        for b in digest {
            use std::fmt::Write;
            let _ = write!(&mut out, "{:02x}", b);
        }
        out
    }
    /// Upload `bytes` to `key` in the tenant's bucket, optionally tagging the
    /// object with a content type.
    pub async fn put_for_tenant(
        &self,
        tenant_id: &str,
        key: &str,
        bytes: Vec<u8>,
        content_type: Option<String>,
    ) -> Result<(), String> {
        let mut req = self
            .client
            .put_object()
            .bucket(self.bucket_for_tenant(tenant_id))
            .key(key)
            .body(aws_sdk_s3::primitives::ByteStream::from(bytes));
        if let Some(ct) = content_type {
            req = req.content_type(ct);
        }
        req.send().await.map_err(|e| e.to_string())?;
        Ok(())
    }
    /// Download `key` from the tenant's bucket, returning the bytes and the
    /// stored content type (if any).
    pub async fn get_bytes_for_tenant(
        &self,
        tenant_id: &str,
        key: &str,
    ) -> Result<(Vec<u8>, Option<String>), String> {
        let out = self
            .client
            .get_object()
            .bucket(self.bucket_for_tenant(tenant_id))
            .key(key)
            .send()
            .await
            .map_err(|e| e.to_string())?;
        let ct = out.content_type().map(|s| s.to_string());
        let bytes = out
            .body
            .collect()
            .await
            .map_err(|e| e.to_string())?
            .into_bytes()
            .to_vec();
        Ok((bytes, ct))
    }
    /// Delete `key` from the tenant's bucket.
    pub async fn delete_for_tenant(&self, tenant_id: &str, key: &str) -> Result<(), String> {
        self.client
            .delete_object()
            .bucket(self.bucket_for_tenant(tenant_id))
            .key(key)
            .send()
            .await
            .map_err(|e| e.to_string())?;
        Ok(())
    }
    /// List every object under `prefix` in the tenant's bucket.
    ///
    /// FIX: follows `ListObjectsV2` continuation tokens, so results are
    /// complete even when they exceed the per-page cap (1000 keys); the
    /// previous single-request version silently truncated large listings.
    pub async fn list_for_tenant(
        &self,
        tenant_id: &str,
        prefix: &str,
    ) -> Result<Vec<DocObject>, String> {
        let bucket = self.bucket_for_tenant(tenant_id);
        let mut items = Vec::new();
        let mut continuation: Option<String> = None;
        loop {
            let mut req = self
                .client
                .list_objects_v2()
                .bucket(bucket)
                .prefix(prefix);
            if let Some(token) = continuation.take() {
                req = req.continuation_token(token);
            }
            let out = req.send().await.map_err(|e| e.to_string())?;
            for o in out.contents() {
                if let Some(key) = o.key() {
                    items.push(DocObject {
                        key: key.to_string(),
                        size: o.size().unwrap_or(0),
                        last_modified: o.last_modified().map(|d| d.to_string()),
                    });
                }
            }
            // A token in the response means more pages remain.
            continuation = out.next_continuation_token().map(|s| s.to_string());
            if continuation.is_none() {
                break;
            }
        }
        Ok(items)
    }
    /// Create any configured bucket that does not already exist.
    ///
    /// A failed `HeadBucket` (for any reason) is treated as "missing" and
    /// triggers creation; creation errors are propagated.
    pub async fn ensure_buckets_exist(&self) -> Result<(), String> {
        for bucket in &self.cfg.buckets {
            let head = self.client.head_bucket().bucket(bucket).send().await;
            if head.is_ok() {
                continue;
            }
            self.client
                .create_bucket()
                .bucket(bucket)
                .acl(BucketCannedAcl::Private)
                .send()
                .await
                .map_err(|e| e.to_string())?;
        }
        Ok(())
    }
    /// Presign a PUT for `key` in the tenant's bucket, valid for `expires`.
    /// The URL is generated against the public endpoint.
    pub async fn presign_put_for_tenant(
        &self,
        tenant_id: &str,
        key: &str,
        content_type: Option<String>,
        expires: Duration,
    ) -> Result<String, String> {
        let mut req = self
            .presign_client
            .put_object()
            .bucket(self.bucket_for_tenant(tenant_id))
            .key(key);
        if let Some(ct) = content_type {
            req = req.content_type(ct);
        }
        let presigned = req
            .presigned(PresigningConfig::expires_in(expires).map_err(|e| e.to_string())?)
            .await
            .map_err(|e| e.to_string())?;
        Ok(presigned.uri().to_string())
    }
    /// Presign a GET for `key` in the tenant's bucket, valid for `expires`.
    /// The URL is generated against the public endpoint.
    pub async fn presign_get_for_tenant(
        &self,
        tenant_id: &str,
        key: &str,
        expires: Duration,
    ) -> Result<String, String> {
        let req = self
            .presign_client
            .get_object()
            .bucket(self.bucket_for_tenant(tenant_id))
            .key(key);
        let presigned = req
            .presigned(PresigningConfig::expires_in(expires).map_err(|e| e.to_string())?)
            .await
            .map_err(|e| e.to_string())?;
        Ok(presigned.uri().to_string())
    }
}
/// Listing entry returned by `DocsStore::list_for_tenant`.
#[derive(Clone, Debug, serde::Serialize)]
pub struct DocObject {
    /// Full object key, including the configured prefix.
    pub key: String,
    /// Object size in bytes as reported by the backend (0 when unreported).
    pub size: i64,
    /// Last-modified timestamp rendered to a string, when the backend reports one.
    pub last_modified: Option<String>,
}
/// Validate one caller-supplied object-key segment.
///
/// Rejects empty values, values longer than 128 bytes, and values that could
/// escape the intended key layout (`/`, `\`, or `..`). `name` is only used to
/// label the error message.
fn validate_segment(name: &str, value: &str) -> Result<(), String> {
    match value {
        v if v.is_empty() => Err(format!("{name} is required")),
        v if v.len() > 128 => Err(format!("{name} too long")),
        v if v.contains('/') || v.contains('\\') || v.contains("..") => {
            Err(format!("{name} contains invalid characters"))
        }
        _ => Ok(()),
    }
}

/// Filename-specific wrapper; currently applies the same rules as any other
/// key segment, labelled as `filename` in errors.
fn validate_filename(value: &str) -> Result<(), String> {
    validate_segment("filename", value)
}
#[cfg(test)]
mod tests {
    use super::*;
    // Serializes access to process-global environment variables: these tests
    // mutate env vars and the default test harness runs tests in parallel.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
        LOCK.get_or_init(|| std::sync::Mutex::new(()))
            .lock()
            .unwrap()
    }
    // Happy path: every CONTROL_S3_* variable set, all fields parsed through.
    #[test]
    fn config_from_env_parses_expected_fields() {
        let _guard = env_lock();
        // set_var/remove_var are process-global; guarded by env_lock above.
        unsafe {
            std::env::set_var("CONTROL_S3_ENDPOINT", "http://minio:9000");
            std::env::set_var("CONTROL_S3_REGION", "us-east-1");
            std::env::set_var("CONTROL_S3_ACCESS_KEY_ID", "minioadmin");
            std::env::set_var("CONTROL_S3_SECRET_ACCESS_KEY", "minioadmin");
            std::env::set_var("CONTROL_S3_BUCKET_DOCS", "cloudlysis-docs");
            std::env::set_var("CONTROL_S3_PREFIX_DOCS", "docs/");
            std::env::set_var("CONTROL_S3_FORCE_PATH_STYLE", "true");
            std::env::set_var("CONTROL_S3_INSECURE", "true");
        }
        let cfg = DocsConfig::from_env().unwrap();
        assert_eq!(cfg.endpoint, "http://minio:9000");
        assert_eq!(cfg.buckets, vec!["cloudlysis-docs".to_string()]);
        assert_eq!(cfg.prefix, "docs/");
        assert!(cfg.force_path_style);
        assert!(cfg.insecure);
        // Clean up so later tests observe an unpolluted environment.
        unsafe {
            std::env::remove_var("CONTROL_S3_ENDPOINT");
            std::env::remove_var("CONTROL_S3_REGION");
            std::env::remove_var("CONTROL_S3_ACCESS_KEY_ID");
            std::env::remove_var("CONTROL_S3_SECRET_ACCESS_KEY");
            std::env::remove_var("CONTROL_S3_BUCKET_DOCS");
            std::env::remove_var("CONTROL_S3_PREFIX_DOCS");
            std::env::remove_var("CONTROL_S3_FORCE_PATH_STYLE");
            std::env::remove_var("CONTROL_S3_INSECURE");
        }
    }
    // INSECURE=true with an https endpoint must be refused (TLS verification
    // is not actually disabled, so the combination would be misleading).
    #[test]
    fn config_rejects_insecure_with_https_endpoint() {
        let _guard = env_lock();
        unsafe {
            std::env::set_var("CONTROL_S3_ENDPOINT", "https://s3.example.com");
            std::env::set_var("CONTROL_S3_ACCESS_KEY_ID", "a");
            std::env::set_var("CONTROL_S3_SECRET_ACCESS_KEY", "b");
            std::env::set_var(
                "CONTROL_S3_BUCKET_DOCS",
                "cloudlysis-docs-0,cloudlysis-docs-1",
            );
            std::env::set_var("CONTROL_S3_INSECURE", "true");
        }
        let err = DocsConfig::from_env().unwrap_err();
        assert!(
            err.contains("CONTROL_S3_INSECURE=true") && err.contains("https://"),
            "unexpected error: {err}"
        );
        unsafe {
            std::env::remove_var("CONTROL_S3_ENDPOINT");
            std::env::remove_var("CONTROL_S3_ACCESS_KEY_ID");
            std::env::remove_var("CONTROL_S3_SECRET_ACCESS_KEY");
            std::env::remove_var("CONTROL_S3_BUCKET_DOCS");
            std::env::remove_var("CONTROL_S3_INSECURE");
        }
    }
    // Pins the object-key layout so already-stored documents stay addressable.
    #[tokio::test]
    async fn key_scheme_is_stable() {
        let cfg = DocsConfig {
            endpoint: "http://minio:9000".to_string(),
            public_endpoint: None,
            region: "us-east-1".to_string(),
            access_key_id: "x".to_string(),
            secret_access_key: "y".to_string(),
            force_path_style: true,
            insecure: true,
            buckets: vec![
                "cloudlysis-docs-0".to_string(),
                "cloudlysis-docs-1".to_string(),
            ],
            prefix: "docs/".to_string(),
        };
        let store = DocsStore::new(cfg).await.unwrap();
        let key = store
            .key_for("tenant-a", "deployments", "v1", "bundle.tar.gz")
            .unwrap();
        assert_eq!(key, "docs/tenant-a/deployments/v1/bundle.tar.gz");
    }
    // Path-traversal style segments must be rejected by key construction.
    #[tokio::test]
    async fn key_scheme_rejects_invalid_segments() {
        let cfg = DocsConfig {
            endpoint: "http://minio:9000".to_string(),
            public_endpoint: None,
            region: "us-east-1".to_string(),
            access_key_id: "x".to_string(),
            secret_access_key: "y".to_string(),
            force_path_style: true,
            insecure: true,
            buckets: vec!["cloudlysis-docs".to_string()],
            prefix: "docs/".to_string(),
        };
        let store = DocsStore::new(cfg).await.unwrap();
        assert!(store.key_for("t/a", "x", "y", "z").is_err());
        assert!(store.key_for("t", "x", "../y", "z").is_err());
        assert!(store.key_for("t", "x", "y", "a/b").is_err());
    }
}

View File

@@ -28,31 +28,49 @@ pub struct SwarmStateFile {
#[derive(Clone)]
pub struct SwarmStore {
path: std::path::PathBuf,
inner: SwarmStoreInner,
}
#[derive(Clone)]
enum SwarmStoreInner {
File { path: std::path::PathBuf },
DockerCli,
}
impl SwarmStore {
pub fn new(path: std::path::PathBuf) -> Self {
Self { path }
Self {
inner: SwarmStoreInner::File { path },
}
}
pub fn new_docker_cli() -> Self {
Self {
inner: SwarmStoreInner::DockerCli,
}
}
pub fn list_services(&self) -> Vec<SwarmService> {
self.load().map(|s| s.services).unwrap_or_default()
match &self.inner {
SwarmStoreInner::File { path } => {
load_state(path).map(|s| s.services).unwrap_or_default()
}
SwarmStoreInner::DockerCli => list_services_docker_cli().unwrap_or_default(),
}
}
pub fn list_tasks(&self, service_name: &str) -> Vec<SwarmTask> {
self.load()
.map(|s| {
s.tasks
.into_iter()
.filter(|t| t.service == service_name)
.collect()
})
.unwrap_or_default()
}
fn load(&self) -> Option<SwarmStateFile> {
load_state(&self.path)
match &self.inner {
SwarmStoreInner::File { path } => load_state(path)
.map(|s| {
s.tasks
.into_iter()
.filter(|t| t.service == service_name)
.collect()
})
.unwrap_or_default(),
SwarmStoreInner::DockerCli => list_tasks_docker_cli(service_name).unwrap_or_default(),
}
}
}
@@ -60,3 +78,120 @@ fn load_state(path: &Path) -> Option<SwarmStateFile> {
let raw = fs::read_to_string(path).ok()?;
serde_json::from_str(&raw).ok()
}
/// Enumerate swarm services by shelling out to `docker service ls` and
/// parsing its `{{json .}}` line-per-service output.
fn list_services_docker_cli() -> Result<Vec<SwarmService>, String> {
    // Shape of one `docker service ls --format '{{json .}}'` output line.
    #[derive(Deserialize)]
    struct ServiceRow {
        #[serde(rename = "Name")]
        name: String,
        #[serde(rename = "Image")]
        image: Option<String>,
        #[serde(rename = "Mode")]
        mode: Option<String>,
        #[serde(rename = "Replicas")]
        replicas: Option<String>,
        #[serde(rename = "UpdatedAt")]
        updated_at: Option<String>,
    }
    let output = std::process::Command::new("docker")
        .args(["service", "ls", "--format", "{{json .}}"])
        .output()
        .map_err(|e| format!("docker exec failed: {e}"))?;
    if !output.status.success() {
        return Err(format!(
            "docker service ls failed: {}",
            String::from_utf8_lossy(&output.stderr)
        ));
    }
    // One JSON document per non-blank line; stop at the first malformed row.
    String::from_utf8_lossy(&output.stdout)
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .map(|line| {
            let row: ServiceRow =
                serde_json::from_str(line).map_err(|e| format!("invalid json row: {e}"))?;
            Ok(SwarmService {
                name: row.name,
                image: row.image,
                mode: row.mode,
                replicas: row.replicas,
                updated_at: row.updated_at,
            })
        })
        .collect()
}
/// Enumerate the tasks of one swarm service via `docker service ps`,
/// parsing its `{{json .}}` line-per-task output.
fn list_tasks_docker_cli(service_name: &str) -> Result<Vec<SwarmTask>, String> {
    // Shape of one `docker service ps --format '{{json .}}'` output line.
    #[derive(Deserialize)]
    struct TaskRow {
        #[serde(rename = "ID")]
        id: String,
        #[serde(rename = "Name")]
        name: Option<String>,
        #[serde(rename = "Node")]
        node: Option<String>,
        #[serde(rename = "DesiredState")]
        desired_state: Option<String>,
        #[serde(rename = "CurrentState")]
        current_state: Option<String>,
        #[serde(rename = "Error")]
        error: Option<String>,
    }
    let output = std::process::Command::new("docker")
        .args([
            "service",
            "ps",
            service_name,
            "--no-trunc",
            "--format",
            "{{json .}}",
        ])
        .output()
        .map_err(|e| format!("docker exec failed: {e}"))?;
    if !output.status.success() {
        return Err(format!(
            "docker service ps failed: {}",
            String::from_utf8_lossy(&output.stderr)
        ));
    }
    let mut tasks = Vec::new();
    for raw in String::from_utf8_lossy(&output.stdout).lines() {
        let raw = raw.trim();
        if raw.is_empty() {
            continue;
        }
        let row: TaskRow =
            serde_json::from_str(raw).map_err(|e| format!("invalid json row: {e}"))?;
        // Task names follow `<service>.<slot>`; when that shape is absent,
        // fall back to the service name we queried for.
        let service = match row.name.as_deref().and_then(|n| n.split_once('.')) {
            Some((svc, _slot)) => svc.to_string(),
            None => service_name.to_string(),
        };
        tasks.push(SwarmTask {
            id: row.id,
            service,
            node: row.node,
            desired_state: row.desired_state,
            current_state: row.current_state,
            error: row.error,
        });
    }
    Ok(tasks)
}
#[cfg(test)]
mod tests {
    use super::*;
    // Guards the on-disk swarm state JSON schema: the file-backed store must
    // keep deserializing this shape.
    #[test]
    fn state_file_parses() {
        let raw = r#"{"services":[{"name":"a","image":null,"mode":null,"replicas":null,"updated_at":null}],"tasks":[]}"#;
        let parsed: SwarmStateFile = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.services.len(), 1);
    }
}