feat(billing): implement tenant subscription entitlements system (milestones 0-6)
Some checks failed
ci / ui (push) Failing after 28s
ci / rust (push) Failing after 2m40s
images / build-and-push (push) Failing after 19s

This commit is contained in:
2026-03-30 18:41:23 +03:00
parent 5992044b7e
commit 2595e7f1c5
63 changed files with 8448 additions and 321 deletions

View File

@@ -5,22 +5,32 @@ edition = "2024"
publish = ["madapes"]
[dependencies]
async-nats = "0.42.0"
async-trait = "0.1.89"
axum = "0.8.6"
aws-config = { version = "1.8.6", features = ["behavior-version-latest"] }
aws-credential-types = "1.2.6"
aws-sdk-s3 = "1.106.0"
clap = { version = "4.5.48", features = ["derive", "env"] }
futures = "0.3.31"
jsonwebtoken = "9.3.1"
metrics = "0.23.0"
metrics-exporter-prometheus = "0.16.0"
reqwest = { version = "0.12.23", default-features = false, features = ["json", "rustls-tls"] }
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.149"
sha2 = "0.10.9"
hex = "0.4.3"
shared = { path = "../../shared" }
thiserror = "2.0.16"
tokio = { version = "1.45.0", features = ["macros", "net", "process", "rt-multi-thread", "signal", "time"] }
tower-http = { version = "0.6.6", features = ["trace"] }
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }
url = "2.5.4"
uuid = { version = "1.18.1", features = ["serde", "v4"] }
[dev-dependencies]
serde_yaml = "0.9.34"
tower = "0.5.2"
urlencoding = "2.1.3"

View File

@@ -1,7 +1,9 @@
use crate::{
AppState, RequestIds,
auth::{Principal, has_permission},
fleet,
config_registry::{ConfigDomain, ConfigRegistryError},
config_schemas::RoutingConfig,
drift, fleet,
job_engine::{JobEngine, StartJobError},
jobs::{Job, JobStatus, JobStep},
placement::{PlacementResponse, ServiceKind},
@@ -15,7 +17,9 @@ use axum::{
routing::{get, post},
};
use serde::Deserialize;
use sha2::Digest;
use std::time::{SystemTime, UNIX_EPOCH};
use url::Url;
use uuid::Uuid;
const HEADER_IDEMPOTENCY_KEY: &str = "idempotency-key";
@@ -25,21 +29,125 @@ pub fn admin_router() -> Router<AppState> {
Router::new()
.route("/whoami", get(whoami))
.route("/platform/info", get(platform_info))
.route("/platform/drift", get(platform_drift))
.route("/fleet/snapshot", get(fleet_snapshot))
.route("/tenants", get(list_tenants))
.route("/placement/{kind}", get(get_placement))
.route("/config", get(list_config))
.route("/config/{domain}", get(get_config))
.route("/config/{domain}/history", get(get_config_history))
.route("/jobs/platform/verify", post(start_platform_verify))
.route("/jobs/config/validate", post(start_config_validate))
.route("/jobs/config/apply", post(start_config_apply))
.route("/jobs/config/rollback", post(start_config_rollback))
.route("/tenants/echo", get(tenant_echo))
.route(
"/tenants/{tenant_id}/billing",
get(crate::billing::get_billing),
)
.route(
"/tenants/{tenant_id}/billing/checkout",
post(crate::billing::checkout),
)
.route(
"/tenants/{tenant_id}/billing/portal",
post(crate::billing::portal),
)
.route("/jobs/echo", post(create_echo_job))
.route("/jobs/{job_id}", get(get_job))
.route("/jobs/{job_id}/cancel", post(cancel_job))
.route("/jobs/tenant/drain", post(start_tenant_drain))
.route("/jobs/tenant/migrate", post(start_tenant_migrate))
.route("/plan/tenant/migrate", post(plan_tenant_migrate))
.route("/plan/config/apply", post(plan_config_apply))
.route("/audit", get(list_audit))
.route("/swarm/services", get(list_swarm_services))
.route("/swarm/services/{name}/tasks", get(list_swarm_tasks))
}
/// Request body for POST /jobs/platform/verify.
#[derive(Debug, Deserialize)]
struct PlatformVerifyRequest {
    // Free-form operator-supplied justification; recorded with the job/audit trail.
    reason: String,
}
/// POST /jobs/platform/verify — start an asynchronous platform verification job.
///
/// Requires `control:write` and a non-empty `idempotency-key` header.
/// Returns `{"job_id": ...}` on success, 409 when a conflicting job holds a lock.
async fn start_platform_verify(
    State(state): State<AppState>,
    headers: HeaderMap,
    Extension(principal): Extension<Principal>,
    Json(body): Json<PlatformVerifyRequest>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    // The idempotency key is mandatory: missing, non-UTF-8, and empty values
    // are all rejected with 400.
    let Some(key) = headers
        .get(HEADER_IDEMPOTENCY_KEY)
        .and_then(|v| v.to_str().ok())
        .filter(|k| !k.is_empty())
    else {
        return StatusCode::BAD_REQUEST.into_response();
    };
    let engine = JobEngine::new(
        state.jobs.clone(),
        state.audit.clone(),
        state.tenant_locks.clone(),
        state.config_locks.clone(),
    );
    match engine.start_platform_verify(state.clone(), &principal, body.reason, key) {
        Ok(job_id) => (
            StatusCode::OK,
            Json(serde_json::json!({ "job_id": job_id })),
        )
            .into_response(),
        Err(StartJobError::TenantLocked) => StatusCode::CONFLICT.into_response(),
    }
}
/// GET /config/{domain}/history — return up to the 50 most recent revisions
/// of a config domain, each with its revision id, SHA-256 content hash, and
/// parsed JSON value.
///
/// Responses: 403 without `control:read`, 404 for an unknown domain or a
/// domain without a configured source, 502 when the backing source fails,
/// 501 for any other registry error (e.g. history not supported).
async fn get_config_history(
    State(state): State<AppState>,
    Path(domain): Path<String>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    let domain = match domain.as_str() {
        "routing" => ConfigDomain::Routing,
        "placement" => ConfigDomain::Placement,
        _ => return StatusCode::NOT_FOUND.into_response(),
    };
    let Some(source) = state.config.source(domain) else {
        return StatusCode::NOT_FOUND.into_response();
    };
    let rows = match source.history_bytes(50).await {
        Ok(items) => items
            .into_iter()
            // Revisions whose payload is not valid JSON are silently skipped
            // rather than failing the whole listing.
            .filter_map(|(rev, bytes)| {
                let v = serde_json::from_slice::<serde_json::Value>(&bytes).ok()?;
                Some(serde_json::json!({
                    "revision": rev,
                    "sha256": sha256_hex(&bytes),
                    "value": v
                }))
            })
            .collect::<Vec<_>>(),
        Err(ConfigRegistryError::Source(_)) => return StatusCode::BAD_GATEWAY.into_response(),
        Err(_) => return StatusCode::NOT_IMPLEMENTED.into_response(),
    };
    (
        StatusCode::OK,
        Json(serde_json::json!({ "domain": domain.as_str(), "items": rows })),
    )
        .into_response()
}
async fn whoami(Extension(principal): Extension<Principal>) -> impl IntoResponse {
if !has_permission(&principal, "control:read") {
return StatusCode::FORBIDDEN.into_response();
@@ -70,6 +178,18 @@ async fn platform_info(Extension(principal): Extension<Principal>) -> impl IntoR
.into_response()
}
/// GET /platform/drift — compute and return the current platform drift report.
/// Requires `control:read`.
async fn platform_drift(
    State(state): State<AppState>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    match has_permission(&principal, "control:read") {
        false => StatusCode::FORBIDDEN.into_response(),
        true => {
            let report = drift::compute(&state).await;
            (StatusCode::OK, Json(report)).into_response()
        }
    }
}
async fn fleet_snapshot(
State(state): State<AppState>,
Extension(principal): Extension<Principal>,
@@ -109,6 +229,434 @@ async fn get_placement(
(StatusCode::OK, Json(resp)).into_response()
}
/// GET /config — list the config domains that have a backing source configured.
/// Requires `control:read`.
async fn list_config(
    State(state): State<AppState>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    // Only advertise domains whose source is actually wired up.
    let mut domains: Vec<&'static str> = Vec::new();
    for domain in [ConfigDomain::Routing, ConfigDomain::Placement] {
        if state.config.source(domain).is_some() {
            domains.push(domain.as_str());
        }
    }
    (
        StatusCode::OK,
        Json(serde_json::json!({ "domains": domains })),
    )
        .into_response()
}
/// GET /config/{domain} — return the current value of a config domain along
/// with its revision, SHA-256 content hash, and source metadata.
///
/// Responses: 403 without `control:read`, 404 for an unknown domain / no
/// source / not configured, 502 on source failure, 400 on undecodable or
/// non-JSON content. A domain with no stored bytes yields `"value": null`.
async fn get_config(
    State(state): State<AppState>,
    Path(domain): Path<String>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    let domain = match domain.as_str() {
        "routing" => ConfigDomain::Routing,
        "placement" => ConfigDomain::Placement,
        _ => return StatusCode::NOT_FOUND.into_response(),
    };
    let Some(source) = state.config.source(domain) else {
        return StatusCode::NOT_FOUND.into_response();
    };
    let loaded = source.load_bytes().await;
    let (bytes, revision) = match loaded {
        Ok(x) => x,
        Err(ConfigRegistryError::Source(_)) => return StatusCode::BAD_GATEWAY.into_response(),
        Err(ConfigRegistryError::Decode(_)) => return StatusCode::BAD_REQUEST.into_response(),
        Err(ConfigRegistryError::NotConfigured) => return StatusCode::NOT_FOUND.into_response(),
    };
    let json_value = match bytes {
        Some(ref b) => match serde_json::from_slice::<serde_json::Value>(b) {
            Ok(v) => v,
            Err(e) => {
                return (
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({ "error": format!("invalid json: {e}") })),
                )
                    .into_response();
            }
        },
        // No stored bytes means the domain exists but has never been written.
        None => serde_json::Value::Null,
    };
    let sha256 = bytes.as_deref().map(sha256_hex);
    (
        StatusCode::OK,
        Json(serde_json::json!({
            "domain": domain.as_str(),
            "revision": revision,
            "sha256": sha256,
            "source": source.info(),
            "value": json_value,
        })),
    )
        .into_response()
}
/// Request body for POST /jobs/config/apply.
#[derive(Debug, Deserialize)]
struct ConfigApplyRequest {
    // Config domain name ("routing" or "placement").
    domain: String,
    // Optimistic-concurrency guard: apply only if the stored revision matches.
    expected_revision: Option<u64>,
    // Operator-supplied justification, recorded with the job/audit trail.
    reason: String,
    // Proposed full config value.
    value: serde_json::Value,
}
/// Request body for POST /jobs/config/validate.
#[derive(Debug, Deserialize)]
struct ConfigValidateRequest {
    domain: String,
    reason: String,
    value: serde_json::Value,
}
/// Request body for POST /jobs/config/rollback.
#[derive(Debug, Deserialize)]
struct ConfigRollbackRequest {
    domain: String,
    reason: String,
}
/// Map a domain path/body string to its `ConfigDomain`; `None` for unknown names.
fn parse_domain(domain: &str) -> Option<ConfigDomain> {
    match domain {
        "routing" => Some(ConfigDomain::Routing),
        "placement" => Some(ConfigDomain::Placement),
        _ => None,
    }
}
/// POST /jobs/config/validate — start an asynchronous validation job for a
/// proposed config value.
///
/// Requires `control:write` and a non-empty `idempotency-key` header.
/// Returns `{"job_id": ...}` on success, 409 when a conflicting job holds a lock.
async fn start_config_validate(
    State(state): State<AppState>,
    headers: HeaderMap,
    Extension(principal): Extension<Principal>,
    Json(body): Json<ConfigValidateRequest>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    // Missing, non-UTF-8, and empty idempotency keys are all 400s.
    let Some(key) = headers
        .get(HEADER_IDEMPOTENCY_KEY)
        .and_then(|v| v.to_str().ok())
        .filter(|k| !k.is_empty())
    else {
        return StatusCode::BAD_REQUEST.into_response();
    };
    let Some(domain) = parse_domain(body.domain.as_str()) else {
        return StatusCode::BAD_REQUEST.into_response();
    };
    let engine = JobEngine::new(
        state.jobs.clone(),
        state.audit.clone(),
        state.tenant_locks.clone(),
        state.config_locks.clone(),
    );
    match engine.start_config_validate(
        state.clone(),
        &principal,
        domain,
        body.reason,
        body.value,
        key,
    ) {
        Ok(job_id) => (
            StatusCode::OK,
            Json(serde_json::json!({ "job_id": job_id })),
        )
            .into_response(),
        Err(StartJobError::TenantLocked) => StatusCode::CONFLICT.into_response(),
    }
}
/// POST /jobs/config/apply — start an asynchronous apply job for a proposed
/// config value, optionally guarded by `expected_revision`.
///
/// Requires `control:write` and a non-empty `idempotency-key` header.
/// Returns `{"job_id": ...}` on success, 409 when a conflicting job holds a lock.
async fn start_config_apply(
    State(state): State<AppState>,
    headers: HeaderMap,
    Extension(principal): Extension<Principal>,
    Json(body): Json<ConfigApplyRequest>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    // Missing, non-UTF-8, and empty idempotency keys are all 400s.
    let Some(key) = headers
        .get(HEADER_IDEMPOTENCY_KEY)
        .and_then(|v| v.to_str().ok())
        .filter(|k| !k.is_empty())
    else {
        return StatusCode::BAD_REQUEST.into_response();
    };
    let Some(domain) = parse_domain(body.domain.as_str()) else {
        return StatusCode::BAD_REQUEST.into_response();
    };
    let engine = JobEngine::new(
        state.jobs.clone(),
        state.audit.clone(),
        state.tenant_locks.clone(),
        state.config_locks.clone(),
    );
    match engine.start_config_apply(
        state.clone(),
        &principal,
        domain,
        body.reason,
        body.expected_revision,
        body.value,
        key,
    ) {
        Ok(job_id) => (
            StatusCode::OK,
            Json(serde_json::json!({ "job_id": job_id })),
        )
            .into_response(),
        Err(StartJobError::TenantLocked) => StatusCode::CONFLICT.into_response(),
    }
}
/// POST /jobs/config/rollback — start an asynchronous rollback job for a
/// config domain.
///
/// Requires `control:write` and a non-empty `idempotency-key` header.
/// Returns `{"job_id": ...}` on success, 409 when a conflicting job holds a lock.
async fn start_config_rollback(
    State(state): State<AppState>,
    headers: HeaderMap,
    Extension(principal): Extension<Principal>,
    Json(body): Json<ConfigRollbackRequest>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    // Missing, non-UTF-8, and empty idempotency keys are all 400s.
    let Some(key) = headers
        .get(HEADER_IDEMPOTENCY_KEY)
        .and_then(|v| v.to_str().ok())
        .filter(|k| !k.is_empty())
    else {
        return StatusCode::BAD_REQUEST.into_response();
    };
    let Some(domain) = parse_domain(body.domain.as_str()) else {
        return StatusCode::BAD_REQUEST.into_response();
    };
    let engine = JobEngine::new(
        state.jobs.clone(),
        state.audit.clone(),
        state.tenant_locks.clone(),
        state.config_locks.clone(),
    );
    match engine.start_config_rollback(state.clone(), &principal, domain, body.reason, key) {
        Ok(job_id) => (
            StatusCode::OK,
            Json(serde_json::json!({ "job_id": job_id })),
        )
            .into_response(),
        Err(StartJobError::TenantLocked) => StatusCode::CONFLICT.into_response(),
    }
}
/// Request body for POST /plan/config/apply (dry-run diff, no mutation).
#[derive(Debug, Deserialize)]
struct ConfigPlanApplyRequest {
    // Config domain name ("routing" or "placement").
    domain: String,
    // Proposed full config value to diff against the current one.
    value: serde_json::Value,
}
async fn plan_config_apply(
State(state): State<AppState>,
Extension(principal): Extension<Principal>,
Json(body): Json<ConfigPlanApplyRequest>,
) -> impl IntoResponse {
if !has_permission(&principal, "control:write") {
return StatusCode::FORBIDDEN.into_response();
}
let domain = match body.domain.as_str() {
"routing" => ConfigDomain::Routing,
"placement" => ConfigDomain::Placement,
_ => return StatusCode::BAD_REQUEST.into_response(),
};
let Some(source) = state.config.source(domain) else {
return StatusCode::NOT_FOUND.into_response();
};
// Validate proposed config (schema + semantics).
let validate_res: Result<(), String> = match domain {
ConfigDomain::Routing => {
let cfg = match serde_json::from_value::<RoutingConfig>(body.value.clone()) {
Ok(v) => v,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": e.to_string() })),
)
.into_response();
}
};
validate_routing_semantics(&cfg)
}
ConfigDomain::Placement => {
let cfg =
match serde_json::from_value::<crate::placement::PlacementFile>(body.value.clone())
{
Ok(v) => v,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": e.to_string() })),
)
.into_response();
}
};
validate_placement_semantics(&cfg)
}
};
if let Err(e) = validate_res {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": e })),
)
.into_response();
}
let (cur_bytes, cur_rev) = match source.load_bytes().await {
Ok(x) => x,
Err(_) => return StatusCode::BAD_GATEWAY.into_response(),
};
let cur_value = cur_bytes
.as_deref()
.and_then(|b| serde_json::from_slice::<serde_json::Value>(b).ok())
.unwrap_or(serde_json::Value::Null);
let before = serde_json::to_string_pretty(&cur_value).unwrap_or_default();
let after = serde_json::to_string_pretty(&body.value).unwrap_or_default();
let changed = cur_value != body.value;
let impacted_services: Vec<&'static str> = match domain {
ConfigDomain::Routing => vec!["gateway"],
ConfigDomain::Placement => vec!["gateway", "control-api"],
};
(
StatusCode::OK,
Json(serde_json::json!({
"domain": domain.as_str(),
"current_revision": cur_rev,
"changed": changed,
"impacted_services": impacted_services,
"diff": {
"before": before,
"after": after,
}
})),
)
.into_response()
}
/// Hex-encoded SHA-256 digest of `bytes`, used to fingerprint config content.
fn sha256_hex(bytes: &[u8]) -> String {
    // `Digest::digest` performs new + update + finalize in one call.
    hex::encode(sha2::Sha256::digest(bytes))
}
/// Semantic validation of a routing config beyond schema checks:
/// every shard must have at least one endpoint, every endpoint must be an
/// http(s) URL with a host, and every tenant placement must reference an
/// existing, non-empty shard id. Returns the first violation as an error string.
fn validate_routing_semantics(cfg: &RoutingConfig) -> Result<(), String> {
    let shard_maps = [
        ("aggregate_shards", &cfg.aggregate_shards),
        ("projection_shards", &cfg.projection_shards),
        ("runner_shards", &cfg.runner_shards),
    ];
    for (name, map) in shard_maps {
        for (shard_id, endpoints) in map {
            if endpoints.is_empty() {
                return Err(format!("{name}[{shard_id}] has no endpoints"));
            }
            for ep in endpoints {
                // URL must parse, be http(s), and carry an explicit host.
                let u = Url::parse(ep)
                    .map_err(|e| format!("{name}[{shard_id}] invalid endpoint {ep:?}: {e}"))?;
                if u.scheme() != "http" && u.scheme() != "https" {
                    return Err(format!(
                        "{name}[{shard_id}] endpoint {ep:?} must be http(s)"
                    ));
                }
                if u.host_str().is_none() {
                    return Err(format!(
                        "{name}[{shard_id}] endpoint {ep:?} must include host"
                    ));
                }
            }
        }
    }
    // Cross-check each placement map against its corresponding shard map so no
    // tenant is routed to a shard that does not exist.
    let placements = [
        (
            "aggregate_placement",
            &cfg.aggregate_placement,
            &cfg.aggregate_shards,
        ),
        (
            "projection_placement",
            &cfg.projection_placement,
            &cfg.projection_shards,
        ),
        (
            "runner_placement",
            &cfg.runner_placement,
            &cfg.runner_shards,
        ),
    ];
    for (pname, pmap, shards) in placements {
        for (tenant, shard_id) in pmap {
            if shard_id.trim().is_empty() {
                return Err(format!("{pname}[{tenant}] shard_id is empty"));
            }
            if !shards.contains_key(shard_id) {
                return Err(format!(
                    "{pname}[{tenant}] references missing shard_id {shard_id:?}"
                ));
            }
        }
    }
    Ok(())
}
/// Semantic validation of a placement file: every present placement section
/// must give each tenant at least one target, and no target may be blank.
/// Returns the first violation as an error string.
fn validate_placement_semantics(cfg: &crate::placement::PlacementFile) -> Result<(), String> {
    let sections = [
        ("aggregate_placement", cfg.aggregate_placement.as_ref()),
        ("projection_placement", cfg.projection_placement.as_ref()),
        ("runner_placement", cfg.runner_placement.as_ref()),
    ];
    for (kind, maybe_section) in sections {
        // Absent sections are valid; only present ones are checked.
        if let Some(section) = maybe_section {
            for placement in &section.placements {
                if placement.targets.is_empty() {
                    return Err(format!(
                        "{kind} tenant {} has no targets",
                        placement.tenant_id
                    ));
                }
                if placement.targets.iter().any(|t| t.trim().is_empty()) {
                    return Err(format!(
                        "{kind} tenant {} has empty target",
                        placement.tenant_id
                    ));
                }
            }
        }
    }
    Ok(())
}
async fn list_tenants(
State(state): State<AppState>,
Extension(principal): Extension<Principal>,
@@ -256,6 +804,7 @@ async fn start_tenant_drain(
state.jobs.clone(),
state.audit.clone(),
state.tenant_locks.clone(),
state.config_locks.clone(),
);
let job_id = match engine.start_tenant_drain(
state.clone(),
@@ -298,6 +847,7 @@ async fn start_tenant_migrate(
state.jobs.clone(),
state.audit.clone(),
state.tenant_locks.clone(),
state.config_locks.clone(),
);
let job_id = match engine.start_tenant_migrate(
state.clone(),

904
control/api/src/billing.rs Normal file
View File

@@ -0,0 +1,904 @@
use crate::{
AppState,
auth::{Principal, has_permission},
};
use async_trait::async_trait;
use axum::{
Json,
extract::{Extension, Path, State},
http::{HeaderMap, StatusCode},
response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use std::{
collections::BTreeMap,
fs,
path::PathBuf,
sync::{Arc, RwLock},
time::SystemTime,
};
use thiserror::Error;
use uuid::Uuid;
const HEADER_TENANT_ID: &str = shared::HEADER_X_TENANT_ID;
/// Defense-in-depth tenant isolation: the tenant id in the request headers
/// must match the tenant id in the URL path.
///
/// Returns 400 when the header is missing, non-UTF-8, or not a UUID, and
/// 403 when it names a different tenant than the path.
fn verify_tenant_isolation(headers: &HeaderMap, path_tenant_id: Uuid) -> Result<(), StatusCode> {
    let raw = headers
        .get(HEADER_TENANT_ID)
        .and_then(|v| v.to_str().ok())
        .ok_or(StatusCode::BAD_REQUEST)?;
    let header_tenant_id = Uuid::parse_str(raw).map_err(|_| StatusCode::BAD_REQUEST)?;
    if header_tenant_id == path_tenant_id {
        Ok(())
    } else {
        Err(StatusCode::FORBIDDEN)
    }
}
/// Subscription plan tier. Serialized as snake_case strings.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Plan {
    Free,
    Pro,
    Enterprise,
}
/// Provider-side subscription lifecycle state. Serialized as snake_case strings.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SubscriptionStatus {
    Trialing,
    Active,
    PastDue,
    Paused,
    Canceled,
    Incomplete,
}
/// Concrete feature limits derived from a tenant's plan and subscription
/// status (see `Entitlements::derive`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Entitlements {
    pub max_deployments: u32,
    pub max_runners: u32,
    pub s3_docs_enabled: bool,
    // e.g. "community", "standard", "priority".
    pub support_tier: String,
}
/// Normalized billing event, produced either by webhook verification or by
/// periodic reconciliation, and applied to the billing store.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum BillingEvent {
    /// A subscription was created for the tenant.
    SubscriptionCreated {
        tenant_id: Uuid,
        // Provider-assigned event id, used for deduplication.
        event_id: String,
        provider_customer_id: String,
        provider_subscription_id: String,
        status: SubscriptionStatus,
        plan: Plan,
        current_period_end: String,
        // Event timestamp in milliseconds; used for monotonicity checks.
        ts_ms: u64,
    },
    /// Plan/status change on an existing subscription.
    SubscriptionUpdated {
        tenant_id: Uuid,
        event_id: String,
        status: SubscriptionStatus,
        plan: Plan,
        current_period_end: String,
        cancel_at_period_end: bool,
        ts_ms: u64,
    },
    /// Subscription ended; the tenant is marked canceled.
    SubscriptionDeleted {
        tenant_id: Uuid,
        event_id: String,
        ts_ms: u64,
    },
}
impl BillingEvent {
    /// Tenant the event applies to.
    pub fn tenant_id(&self) -> Uuid {
        match self {
            Self::SubscriptionCreated { tenant_id, .. }
            | Self::SubscriptionUpdated { tenant_id, .. }
            | Self::SubscriptionDeleted { tenant_id, .. } => *tenant_id,
        }
    }
    /// Provider-assigned event id (used for deduplication).
    pub fn event_id(&self) -> &str {
        match self {
            Self::SubscriptionCreated { event_id, .. }
            | Self::SubscriptionUpdated { event_id, .. }
            | Self::SubscriptionDeleted { event_id, .. } => event_id.as_str(),
        }
    }
    /// Event timestamp in milliseconds (used for monotonicity checks).
    pub fn ts_ms(&self) -> u64 {
        match self {
            Self::SubscriptionCreated { ts_ms, .. }
            | Self::SubscriptionUpdated { ts_ms, .. }
            | Self::SubscriptionDeleted { ts_ms, .. } => *ts_ms,
        }
    }
}
impl Entitlements {
    /// Derive feature limits from a tenant's plan and subscription status.
    ///
    /// Only `Trialing`/`Active` subscriptions receive their plan's limits;
    /// any other (or absent) status falls back to a minimal community tier.
    /// An active subscription with no recorded plan is treated as `Free`.
    pub fn derive(plan: Option<&Plan>, status: Option<&SubscriptionStatus>) -> Self {
        let is_active = matches!(
            status,
            Some(SubscriptionStatus::Trialing | SubscriptionStatus::Active)
        );
        // (max_deployments, max_runners, s3_docs_enabled, support_tier)
        let (max_deployments, max_runners, s3_docs_enabled, tier) = if !is_active {
            (1, 1, false, "community")
        } else {
            match plan.unwrap_or(&Plan::Free) {
                Plan::Free => (3, 1, false, "community"),
                Plan::Pro => (10, 5, true, "standard"),
                Plan::Enterprise => (1000, 50, true, "priority"),
            }
        };
        Self {
            max_deployments,
            max_runners,
            s3_docs_enabled,
            support_tier: tier.to_string(),
        }
    }
}
/// Per-tenant billing state as persisted in the billing state file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TenantBillingState {
    // Billing provider name (e.g. "stripe"); "unknown" until a Created event arrives.
    pub provider: String,
    pub provider_customer_id: Option<String>,
    pub provider_subscription_id: Option<String>,
    pub provider_checkout_session_id: Option<String>,
    pub status: Option<SubscriptionStatus>,
    pub plan: Option<Plan>,
    pub current_period_end: Option<String>,
    pub cancel_at_period_end: Option<bool>,
    // Recently processed webhook event ids, kept for deduplication.
    pub processed_webhook_event_ids: Vec<String>,
    // Timestamp (ms) of the last applied event; used for monotonicity checks.
    pub updated_at: u64,
}
/// On-disk billing state: a revision marker plus all tenants' billing state.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BillingStateFile {
    pub revision: Option<String>,
    pub tenants: BTreeMap<Uuid, TenantBillingState>,
}
/// API response shape for GET .../billing.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BillingResponse {
    // False when the tenant has no billing state on record.
    pub configured: bool,
    pub provider: Option<String>,
    pub plan: Option<Plan>,
    pub status: Option<SubscriptionStatus>,
    pub current_period_end: Option<String>,
    pub cancel_at_period_end: Option<bool>,
    pub entitlements: Entitlements,
}
/// Thread-safe, file-backed store of tenant billing state with an mtime-based
/// read cache. Cloning shares the same underlying state.
#[derive(Clone)]
pub struct BillingStore {
    inner: Arc<RwLock<Inner>>,
}
// Mutable core guarded by the store's RwLock.
struct Inner {
    // Path of the JSON state file.
    path: PathBuf,
    // mtime observed at the last successful load; None forces a reload.
    last_modified: Option<SystemTime>,
    // Last parsed state; None when the file is missing or unparsable.
    cached: Option<BillingStateFile>,
}
impl BillingStore {
pub fn new(path: PathBuf) -> Self {
Self {
inner: Arc::new(RwLock::new(Inner {
path,
last_modified: None,
cached: None,
})),
}
}
pub fn get_for_tenant(&self, tenant_id: Uuid) -> BillingResponse {
let mut inner = self.inner.write().expect("billing lock poisoned");
inner.reload_if_changed();
if let Some(state) = inner
.cached
.as_ref()
.and_then(|file| file.tenants.get(&tenant_id))
{
return BillingResponse {
configured: true,
provider: Some(state.provider.clone()),
plan: state.plan.clone(),
status: state.status.clone(),
current_period_end: state.current_period_end.clone(),
cancel_at_period_end: state.cancel_at_period_end,
entitlements: Entitlements::derive(state.plan.as_ref(), state.status.as_ref()),
};
}
BillingResponse {
configured: false,
provider: None,
plan: None,
status: None,
current_period_end: None,
cancel_at_period_end: None,
entitlements: Entitlements::derive(None, None),
}
}
pub fn get_all_tenant_ids(&self) -> Vec<Uuid> {
let mut inner = self.inner.write().expect("billing lock poisoned");
inner.reload_if_changed();
inner
.cached
.as_ref()
.map(|f| f.tenants.keys().cloned().collect())
.unwrap_or_default()
}
pub fn get_subscription_id(&self, tenant_id: Uuid) -> Option<String> {
let mut inner = self.inner.write().expect("billing lock poisoned");
inner.reload_if_changed();
inner
.cached
.as_ref()
.and_then(|f| f.tenants.get(&tenant_id))
.and_then(|s| s.provider_subscription_id.clone())
}
pub fn apply_event(&self, event: BillingEvent) -> Result<(), String> {
let mut inner = self.inner.write().expect("billing lock poisoned");
inner.reload_if_changed();
let mut file = inner.cached.clone().unwrap_or(BillingStateFile {
revision: Some("dev".to_string()),
tenants: BTreeMap::new(),
});
let tenant_id = event.tenant_id();
let event_id = event.event_id().to_string();
let ts_ms = event.ts_ms();
let state = file.tenants.entry(tenant_id).or_insert(TenantBillingState {
provider: "unknown".to_string(), // Will be updated by Created event
provider_customer_id: None,
provider_subscription_id: None,
provider_checkout_session_id: None,
status: None,
plan: None,
current_period_end: None,
cancel_at_period_end: None,
processed_webhook_event_ids: vec![],
updated_at: 0,
});
// Deduplication
if state.processed_webhook_event_ids.contains(&event_id) {
return Ok(());
}
// Monotonicity check
if state.updated_at > ts_ms {
state.processed_webhook_event_ids.push(event_id);
state.processed_webhook_event_ids.truncate(50);
inner.save(file)?;
return Ok(());
}
match event {
BillingEvent::SubscriptionCreated {
provider_customer_id,
provider_subscription_id,
status,
plan,
current_period_end,
..
} => {
state.provider_customer_id = Some(provider_customer_id);
state.provider_subscription_id = Some(provider_subscription_id);
state.status = Some(status);
state.plan = Some(plan);
state.current_period_end = Some(current_period_end);
}
BillingEvent::SubscriptionUpdated {
status,
plan,
current_period_end,
cancel_at_period_end,
..
} => {
state.status = Some(status);
state.plan = Some(plan);
state.current_period_end = Some(current_period_end);
state.cancel_at_period_end = Some(cancel_at_period_end);
}
BillingEvent::SubscriptionDeleted { .. } => {
state.status = Some(SubscriptionStatus::Canceled);
}
}
state.updated_at = ts_ms;
state.processed_webhook_event_ids.push(event_id);
state.processed_webhook_event_ids.truncate(50);
inner.save(file)?;
Ok(())
}
#[cfg(test)]
pub fn update_tenant_state(
&self,
tenant_id: Uuid,
state: TenantBillingState,
) -> Result<String, String> {
let mut inner = self.inner.write().expect("billing lock poisoned");
inner.reload_if_changed();
let mut file = inner.cached.clone().unwrap_or(BillingStateFile {
revision: Some("dev".to_string()),
tenants: BTreeMap::new(),
});
file.tenants.insert(tenant_id, state);
inner.save(file)
}
}
impl Inner {
    /// Persist `file` to disk atomically (write temp file, then rename) and
    /// refresh the in-memory cache. Assigns and returns a fresh revision id.
    fn save(&mut self, mut file: BillingStateFile) -> Result<String, String> {
        let revision = format!("rev-{}", Uuid::new_v4());
        file.revision = Some(revision.clone());
        let raw = serde_json::to_string_pretty(&file).map_err(|e| e.to_string())?;
        // Write-then-rename so readers never observe a partially written file.
        let tmp = self.path.with_extension("json.tmp");
        fs::write(&tmp, raw).map_err(|e| e.to_string())?;
        fs::rename(&tmp, &self.path).map_err(|e| e.to_string())?;
        // Clearing last_modified forces the next read to reload from disk,
        // keeping the cache conservative right after our own write.
        self.last_modified = None;
        self.cached = Some(file);
        Ok(revision)
    }
    /// Reload the state file unless its mtime matches the last observed one.
    /// A missing or unparsable file yields an empty cache (`None`).
    /// NOTE(review): mtime granularity is coarse on some filesystems, so two
    /// external writes in the same tick could be missed — confirm acceptable.
    fn reload_if_changed(&mut self) {
        let meta = fs::metadata(&self.path).ok();
        let modified = meta.and_then(|m| m.modified().ok());
        if self.cached.is_some() && modified.is_some() && modified == self.last_modified {
            return;
        }
        self.last_modified = modified;
        let p = &self.path;
        self.cached = fs::read_to_string(p)
            .ok()
            .and_then(|raw| serde_json::from_str(&raw).ok());
    }
}
/// GET /tenants/{tenant_id}/billing — current billing state and entitlements
/// for a tenant. Requires `control:read` and a matching tenant header.
pub async fn get_billing(
    State(state): State<AppState>,
    Path(tenant_id): Path<Uuid>,
    headers: HeaderMap,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    // Header tenant id must agree with the path tenant id.
    match verify_tenant_isolation(&headers, tenant_id) {
        Ok(()) => {
            let resp = state.billing.get_for_tenant(tenant_id);
            (StatusCode::OK, Json(resp)).into_response()
        }
        Err(status) => status.into_response(),
    }
}
/// Request body for POST .../billing/checkout.
#[derive(Debug, Deserialize)]
pub struct CheckoutRequest {
    // Plan to purchase; Free has no checkout and is rejected by providers.
    pub plan: Plan,
    // Optional path to return the user to after checkout; defaults to "/billing".
    pub return_path: Option<String>,
}
pub async fn checkout(
State(state): State<AppState>,
Path(tenant_id): Path<Uuid>,
headers: HeaderMap,
Extension(principal): Extension<Principal>,
Json(body): Json<CheckoutRequest>,
) -> impl IntoResponse {
if !has_permission(&principal, "control:write") {
return StatusCode::FORBIDDEN.into_response();
}
if let Err(status) = verify_tenant_isolation(&headers, tenant_id) {
return status.into_response();
}
// Check if subscription already exists and is active/trialing
let current = state.billing.get_for_tenant(tenant_id);
if current.configured
&& matches!(
current.status,
Some(SubscriptionStatus::Active | SubscriptionStatus::Trialing)
)
{
return (
StatusCode::CONFLICT,
Json(serde_json::json!({ "error": "tenant already has an active subscription" })),
)
.into_response();
}
// Construct full return URL
// TODO: Validate return_path against ALLOWED_RETURN_ORIGINS if provided
let return_url = body.return_path.unwrap_or_else(|| "/billing".to_string());
match state
.billing_provider
.create_checkout_session(tenant_id, body.plan, return_url)
.await
{
Ok(url) => (StatusCode::OK, Json(serde_json::json!({ "url": url }))).into_response(),
Err(e) => {
let err_msg = e.to_string();
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": err_msg })),
)
.into_response()
}
}
}
pub async fn portal(
State(state): State<AppState>,
Path(tenant_id): Path<Uuid>,
headers: HeaderMap,
Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
if !has_permission(&principal, "control:write") {
return StatusCode::FORBIDDEN.into_response();
}
if let Err(status) = verify_tenant_isolation(&headers, tenant_id) {
return status.into_response();
}
let return_url = "/billing".to_string();
match state
.billing_provider
.create_portal_session(tenant_id, return_url)
.await
{
Ok(url) => (StatusCode::OK, Json(serde_json::json!({ "url": url }))).into_response(),
Err(e) => {
let err_msg = e.to_string();
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": err_msg })),
)
.into_response()
}
}
}
pub async fn webhook(
State(state): State<AppState>,
Path(_provider): Path<String>,
headers: HeaderMap,
body: axum::body::Bytes,
) -> impl IntoResponse {
// Note: We don't require auth here as this is a public endpoint called by the provider.
// Security is handled via signature verification in the provider trait.
match state.billing_provider.verify_webhook(&body, &headers).await {
Ok(event) => {
metrics::counter!("billing_webhook_requests_total", "status" => "success").increment(1);
if let Err(e) = state.billing.apply_event(event) {
tracing::error!(error = %e, "failed to apply billing event from webhook");
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e })),
)
.into_response();
}
StatusCode::OK.into_response()
}
Err(e) => {
metrics::counter!("billing_webhook_requests_total", "status" => "error").increment(1);
(
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": e.to_string() })),
)
.into_response()
}
}
}
/// Background task: periodically reconcile billing state with the provider.
///
/// The interval comes from CONTROL_BILLING_RECONCILE_INTERVAL_SECS (default
/// 3600s; unset or unparsable values fall back to the default). The first
/// run happens only after one full interval has elapsed.
pub async fn run_reconciliation_loop(state: AppState) {
    const DEFAULT_INTERVAL_SECS: u64 = 3600;
    let interval_secs: u64 = match std::env::var("CONTROL_BILLING_RECONCILE_INTERVAL_SECS") {
        Ok(raw) => raw.parse().unwrap_or(DEFAULT_INTERVAL_SECS),
        Err(_) => DEFAULT_INTERVAL_SECS,
    };
    tracing::info!(interval_secs, "starting billing reconciliation loop");
    loop {
        tokio::time::sleep(Duration::from_secs(interval_secs)).await;
        tracing::info!("starting billing reconciliation run");
        reconcile_once(&state).await;
        // Update tenant status gauges
        // Note: This is an expensive operation if there are many tenants,
        // but for reconciliation it's fine once per hour.
        update_billing_gauges(&state);
    }
}
/// One reconciliation pass: for each known tenant with a subscription id,
/// fetch the provider's current view and apply it to the local store.
/// Tenants without a subscription id are skipped. Emits run metrics and a
/// summary log line (success/error/skipped counts, duration).
pub async fn reconcile_once(state: &AppState) {
    let start = std::time::Instant::now();
    let (mut success, mut error, mut skipped) = (0, 0, 0);
    for tenant_id in state.billing.get_all_tenant_ids() {
        let Some(subscription_id) = state.billing.get_subscription_id(tenant_id) else {
            skipped += 1;
            continue;
        };
        match state
            .billing_provider
            .fetch_subscription(tenant_id, &subscription_id)
            .await
        {
            Ok(event) => match state.billing.apply_event(event) {
                Ok(()) => success += 1,
                Err(e) => {
                    tracing::error!(?tenant_id, error = %e, "failed to apply reconciled billing event");
                    error += 1;
                }
            },
            Err(e) => {
                tracing::error!(?tenant_id, error = %e, "failed to fetch subscription for reconciliation");
                error += 1;
            }
        }
    }
    let elapsed = start.elapsed();
    metrics::counter!("billing_reconciliation_runs_total", "result" => "done").increment(1);
    metrics::histogram!("billing_reconciliation_duration_ms").record(elapsed.as_millis() as f64);
    tracing::info!(
        success,
        error,
        skipped,
        duration_ms = elapsed.as_millis(),
        "billing reconciliation run complete"
    );
}
/// Refresh the per-(plan, status) tenant count gauges from the billing store.
/// NOTE(review): combos that existed previously but have no tenants now are
/// not reset to 0 here — confirm whether stale gauge series matter.
fn update_billing_gauges(state: &AppState) {
    let mut counts: BTreeMap<(String, String), u64> = BTreeMap::new();
    for tenant_id in state.billing.get_all_tenant_ids() {
        let resp = state.billing.get_for_tenant(tenant_id);
        let plan_label = match resp.plan {
            None => "none",
            Some(Plan::Free) => "free",
            Some(Plan::Pro) => "pro",
            Some(Plan::Enterprise) => "enterprise",
        };
        let status_label = match resp.status {
            None => "none",
            Some(SubscriptionStatus::Active) => "active",
            Some(SubscriptionStatus::Trialing) => "trialing",
            Some(SubscriptionStatus::PastDue) => "past_due",
            Some(SubscriptionStatus::Paused) => "paused",
            Some(SubscriptionStatus::Canceled) => "canceled",
            Some(SubscriptionStatus::Incomplete) => "incomplete",
        };
        let key = (plan_label.to_string(), status_label.to_string());
        *counts.entry(key).or_insert(0) += 1;
    }
    for ((plan, status), count) in counts {
        metrics::gauge!("billing_tenant_status_count", "plan" => plan, "status" => status)
            .set(count as f64);
    }
}
/// Errors surfaced by billing provider implementations.
#[derive(Debug, Error)]
pub enum BillingError {
    // Failure reported by (or while talking to) the external provider.
    #[error("provider error: {0}")]
    Provider(String),
    // Local misconfiguration (e.g. checkout requested for the Free plan).
    #[error("invalid configuration: {0}")]
    Config(String),
}
/// Abstraction over an external billing provider (Stripe in production, a
/// mock in tests). All methods are async and may perform network I/O.
#[async_trait]
pub trait BillingProvider: Send + Sync {
    /// Create a checkout session for `tenant_id` purchasing `plan`; returns
    /// the URL the user should be redirected to.
    async fn create_checkout_session(
        &self,
        tenant_id: Uuid,
        plan: Plan,
        return_url: String,
    ) -> Result<String, BillingError>;
    /// Create a self-service billing portal session; returns its URL.
    async fn create_portal_session(
        &self,
        tenant_id: Uuid,
        return_url: String,
    ) -> Result<String, BillingError>;
    /// Verify a webhook payload's signature (from `headers`) and decode it
    /// into a normalized `BillingEvent`.
    async fn verify_webhook(
        &self,
        payload: &[u8],
        headers: &HeaderMap,
    ) -> Result<BillingEvent, BillingError>;
    /// Fetch the provider's current view of a subscription as an event,
    /// used by the reconciliation loop.
    async fn fetch_subscription(
        &self,
        tenant_id: Uuid,
        subscription_id: &str,
    ) -> Result<BillingEvent, BillingError>;
}
/// Stripe-backed [`BillingProvider`].
///
/// NOTE(review): all methods are currently stubs or simulations — no real
/// Stripe API calls are made yet (see the TODOs in the impl).
pub struct StripeProvider {
    /// Secret API key used to authenticate against Stripe.
    pub secret_key: String,
    /// Stripe price ID for the Pro plan.
    pub price_pro: String,
    /// Stripe price ID for the Enterprise plan.
    pub price_enterprise: String,
}
#[async_trait]
impl BillingProvider for StripeProvider {
    /// Returns a (currently simulated) Stripe checkout URL for `plan`.
    /// The free plan has nothing to purchase and is rejected up front.
    async fn create_checkout_session(
        &self,
        tenant_id: Uuid,
        plan: Plan,
        _return_url: String,
    ) -> Result<String, BillingError> {
        let _price = match plan {
            Plan::Free => {
                return Err(BillingError::Config(
                    "Free plan has no checkout".to_string(),
                ));
            }
            Plan::Pro => &self.price_pro,
            Plan::Enterprise => &self.price_enterprise,
        };
        // TODO: Actually call Stripe API
        // For now, returning a simulated Stripe checkout URL
        let session_marker = Uuid::new_v4();
        Ok(format!(
            "https://checkout.stripe.com/pay/cs_test_{}?tenant_id={}",
            session_marker, tenant_id
        ))
    }
    /// Returns a (currently simulated) Stripe billing-portal URL.
    async fn create_portal_session(
        &self,
        tenant_id: Uuid,
        _return_url: String,
    ) -> Result<String, BillingError> {
        // TODO: Actually call Stripe API
        let session_marker = Uuid::new_v4();
        Ok(format!(
            "https://billing.stripe.com/p/session/ps_test_{}?tenant_id={}",
            session_marker, tenant_id
        ))
    }
    /// Stripe webhook signature verification is not implemented yet; always errors.
    async fn verify_webhook(
        &self,
        _payload: &[u8],
        _headers: &HeaderMap,
    ) -> Result<BillingEvent, BillingError> {
        // TODO: Implement real Stripe signature verification
        Err(BillingError::Provider("Not implemented".to_string()))
    }
    /// Subscription fetch against Stripe is not implemented yet; always errors.
    async fn fetch_subscription(
        &self,
        _tenant_id: Uuid,
        _subscription_id: &str,
    ) -> Result<BillingEvent, BillingError> {
        // TODO: Actually call Stripe API with timeout
        // let client = reqwest::Client::builder().timeout(Duration::from_secs(10)).build()...
        Err(BillingError::Provider("Not implemented".to_string()))
    }
}
/// In-memory [`BillingProvider`] used in tests and local development.
pub struct MockProvider;
#[async_trait]
impl BillingProvider for MockProvider {
    /// Hands out a deterministic per-tenant mock checkout URL.
    async fn create_checkout_session(
        &self,
        tenant_id: Uuid,
        _plan: Plan,
        _return_url: String,
    ) -> Result<String, BillingError> {
        Ok(format!("https://mock.stripe.com/checkout/{}", tenant_id))
    }
    /// Hands out a deterministic per-tenant mock portal URL.
    async fn create_portal_session(
        &self,
        tenant_id: Uuid,
        _return_url: String,
    ) -> Result<String, BillingError> {
        Ok(format!("https://mock.stripe.com/portal/{}", tenant_id))
    }
    /// Mock "verification": the payload itself is the JSON-encoded event.
    async fn verify_webhook(
        &self,
        payload: &[u8],
        _headers: &HeaderMap,
    ) -> Result<BillingEvent, BillingError> {
        match serde_json::from_slice(payload) {
            Ok(event) => Ok(event),
            Err(e) => Err(BillingError::Provider(e.to_string())),
        }
    }
    /// Mock reconciliation answer: always an Active/Pro subscription with a
    /// far-future period end. A richer mock could store expectations instead.
    async fn fetch_subscription(
        &self,
        tenant_id: Uuid,
        _subscription_id: &str,
    ) -> Result<BillingEvent, BillingError> {
        let now_ms = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_millis() as u64;
        let event = BillingEvent::SubscriptionUpdated {
            tenant_id,
            event_id: format!("reconcile-{}", Uuid::new_v4()),
            status: SubscriptionStatus::Active,
            plan: Plan::Pro,
            current_period_end: "2099-12-31T23:59:59Z".to_string(),
            cancel_at_period_end: false,
            ts_ms: now_ms,
        };
        Ok(event)
    }
}
impl MockProvider {
    /// Deterministic checkout URL the mock provider produces for `tenant`
    /// (matches `create_checkout_session` so tests can predict the value).
    pub fn get_checkout_url(tenant: Uuid) -> String {
        let mut url = String::from("https://mock.stripe.com/checkout/");
        url.push_str(&tenant.to_string());
        url
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;
    /// Plan/status combinations map to the expected entitlement tiers.
    #[test]
    fn test_entitlement_derivation() {
        // Free (even past-due) stays at the minimum deployment budget.
        let e = Entitlements::derive(Some(&Plan::Free), Some(&SubscriptionStatus::PastDue));
        assert_eq!(e.max_deployments, 1);
        let e = Entitlements::derive(Some(&Plan::Pro), Some(&SubscriptionStatus::Active));
        assert_eq!(e.max_deployments, 10);
        assert!(e.s3_docs_enabled);
        // Trialing enterprise already gets full enterprise limits.
        let e = Entitlements::derive(Some(&Plan::Enterprise), Some(&SubscriptionStatus::Trialing));
        assert_eq!(e.max_deployments, 1000);
    }
    /// Writing tenant billing state and reading it back yields the configured
    /// plan plus the derived entitlements.
    #[test]
    fn test_billing_state_roundtrip() {
        // Use a unique temp file so parallel test runs do not collide.
        let mut path = temp_dir();
        path.push(format!("billing-{}.json", Uuid::new_v4()));
        let store = BillingStore::new(path.clone());
        let tenant_id = Uuid::new_v4();
        // Unknown tenant: unconfigured, free-tier entitlements.
        let resp = store.get_for_tenant(tenant_id);
        assert!(!resp.configured);
        assert_eq!(resp.entitlements.max_deployments, 1);
        let state = TenantBillingState {
            provider: "mock".to_string(),
            provider_customer_id: None,
            provider_subscription_id: None,
            provider_checkout_session_id: None,
            status: Some(SubscriptionStatus::Active),
            plan: Some(Plan::Pro),
            current_period_end: None,
            cancel_at_period_end: Some(false),
            processed_webhook_event_ids: vec![],
            updated_at: 0,
        };
        store.update_tenant_state(tenant_id, state).unwrap();
        let resp2 = store.get_for_tenant(tenant_id);
        assert!(resp2.configured);
        assert_eq!(resp2.provider.as_deref(), Some("mock"));
        assert_eq!(resp2.plan, Some(Plan::Pro));
        assert_eq!(resp2.entitlements.max_deployments, 10);
        // Best-effort cleanup of the temp file.
        let _ = fs::remove_file(path);
    }
    /// Reconciliation overwrites stale local state with what the provider
    /// reports (MockProvider always answers Active/Pro).
    #[tokio::test]
    async fn test_reconciliation_corrects_state() {
        let mut path = temp_dir();
        path.push(format!("billing-reconcile-{}.json", Uuid::new_v4()));
        let store = BillingStore::new(path.clone());
        let tenant_id = Uuid::new_v4();
        // 1. Initial state: PastDue
        store
            .update_tenant_state(
                tenant_id,
                TenantBillingState {
                    provider: "mock".to_string(),
                    provider_customer_id: Some("cus_1".to_string()),
                    provider_subscription_id: Some("sub_1".to_string()),
                    provider_checkout_session_id: None,
                    status: Some(SubscriptionStatus::PastDue),
                    plan: Some(Plan::Pro),
                    current_period_end: None,
                    cancel_at_period_end: Some(false),
                    processed_webhook_event_ids: vec![],
                    updated_at: 100,
                },
            )
            .unwrap();
        // Minimal AppState wired to the mock billing provider.
        let state = AppState {
            prometheus: crate::get_test_prometheus_handle(),
            auth: crate::AuthConfig { hs256_secret: None },
            jobs: crate::jobs::JobStore::default(),
            audit: crate::AuditStore::default(),
            tenant_locks: crate::job_engine::TenantLocks::default(),
            config_locks: crate::job_engine::ConfigLocks::default(),
            http: reqwest::Client::new(),
            placement: crate::placement::PlacementStore::new(temp_dir().join("placement.json")),
            billing: store.clone(),
            billing_provider: Arc::new(MockProvider),
            billing_enforcement_enabled: true,
            config: crate::config_registry::ConfigRegistry::new(None, None),
            fleet_services: vec![],
            swarm: crate::swarm::SwarmStore::new(temp_dir().join("swarm.json")),
            docs: None,
        };
        // 2. Run reconciliation. MockProvider returns Active status.
        reconcile_once(&state).await;
        // 3. Verify state is now Active
        let resp = store.get_for_tenant(tenant_id);
        assert_eq!(resp.status, Some(SubscriptionStatus::Active));
        let _ = fs::remove_file(path);
    }
}

View File

@@ -0,0 +1,323 @@
use async_trait::async_trait;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc, time::Duration};
use thiserror::Error;
/// Configuration domains the control plane manages independently, each with
/// its own source, lock and job workflow.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConfigDomain {
    Routing,
    Placement,
}
impl ConfigDomain {
    /// Stable lowercase identifier for this domain, used as a lock key and in
    /// audit action names (matches the serde `snake_case` renaming).
    pub fn as_str(&self) -> &'static str {
        match self {
            ConfigDomain::Placement => "placement",
            ConfigDomain::Routing => "routing",
        }
    }
}
/// Errors produced by config sources and the registry.
#[derive(Debug, Error)]
pub enum ConfigRegistryError {
    /// The backing store (file, NATS KV, …) failed.
    #[error("source error: {0}")]
    Source(String),
    /// The stored bytes could not be decoded into the domain's schema.
    #[error("decode error: {0}")]
    Decode(String),
    /// No source is registered for the requested domain.
    #[error("domain not configured")]
    NotConfigured,
}
/// A decoded config value together with its provenance.
#[derive(Debug, Clone, Serialize)]
pub struct ConfigSnapshot<T> {
    /// Domain name (see [`ConfigDomain::as_str`]).
    pub domain: String,
    /// Source-specific revision the value was read at.
    pub revision: u64,
    /// The decoded configuration payload.
    pub value: T,
    /// Where the value came from.
    pub source: ConfigSourceInfo,
}
/// Diagnostic description of a config source's backing store.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ConfigSourceInfo {
    File { path: String },
    NatsKv { bucket: String, key: String },
    Fixed,
}
/// Backend-agnostic storage for one configuration domain.
///
/// Revisions are source-specific: NATS KV revisions, or a constant for
/// file/fixed sources.
#[async_trait]
pub trait ConfigSource: Send + Sync {
    /// Loads the current raw config plus its revision. A source may return
    /// `(None, 0)` when nothing has been written yet (e.g. an absent KV key).
    async fn load_bytes(&self) -> Result<(Option<Vec<u8>>, u64), ConfigRegistryError>;
    /// Writes a new raw config. `expected_revision` enables optimistic
    /// concurrency (compare-and-swap) where the backend supports it.
    async fn put_bytes(
        &self,
        expected_revision: Option<u64>,
        value: Vec<u8>,
    ) -> Result<u64, ConfigRegistryError>;
    /// Returns up to `limit` historical `(revision, bytes)` entries.
    async fn history_bytes(&self, limit: usize)
        -> Result<Vec<(u64, Vec<u8>)>, ConfigRegistryError>;
    /// Streams a unit item per change notification (empty stream when the
    /// backend cannot watch).
    async fn watch(
        &self,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), ConfigRegistryError>> + Send>>,
        ConfigRegistryError,
    >;
    /// Describes where this source reads from (for diagnostics).
    fn info(&self) -> ConfigSourceInfo;
}
/// Read-only config source with a payload fixed at construction time
/// (useful for tests and hard-wired defaults).
#[derive(Clone)]
pub struct FixedSource {
    // Shared so clones of the source are cheap.
    bytes: Arc<Vec<u8>>,
}
impl FixedSource {
    /// Wraps `bytes` as an immutable, cheaply-clonable config payload.
    pub fn new(bytes: Vec<u8>) -> Self {
        let bytes = Arc::new(bytes);
        Self { bytes }
    }
}
#[async_trait]
impl ConfigSource for FixedSource {
    /// Always returns the embedded bytes at a constant revision of 1.
    async fn load_bytes(&self) -> Result<(Option<Vec<u8>>, u64), ConfigRegistryError> {
        let payload = self.bytes.as_ref().clone();
        Ok((Some(payload), 1))
    }
    /// Writes are rejected: the payload is fixed at construction time.
    async fn put_bytes(
        &self,
        _expected_revision: Option<u64>,
        _value: Vec<u8>,
    ) -> Result<u64, ConfigRegistryError> {
        let msg = "fixed source is read-only".to_string();
        Err(ConfigRegistryError::Source(msg))
    }
    /// No history is kept for a fixed payload.
    async fn history_bytes(
        &self,
        _limit: usize,
    ) -> Result<Vec<(u64, Vec<u8>)>, ConfigRegistryError> {
        let msg = "fixed source has no history".to_string();
        Err(ConfigRegistryError::Source(msg))
    }
    /// The payload never changes, so the watch stream is empty.
    async fn watch(
        &self,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), ConfigRegistryError>> + Send>>,
        ConfigRegistryError,
    > {
        let stream = futures::stream::empty();
        Ok(Box::pin(stream))
    }
    fn info(&self) -> ConfigSourceInfo {
        ConfigSourceInfo::Fixed
    }
}
/// Config source backed by a single file on disk (dev/simple deployments).
#[derive(Clone)]
pub struct FileSource {
    path: PathBuf,
}
impl FileSource {
pub fn new(path: PathBuf) -> Self {
Self { path }
}
}
#[async_trait]
impl ConfigSource for FileSource {
    /// Reads the whole file; the revision is always 0 (files are unversioned).
    ///
    /// A missing file is reported as `Ok((None, 0))` — "not configured yet" —
    /// consistent with how `NatsKvSource` reports an absent key; any other
    /// I/O failure is a real source error.
    async fn load_bytes(&self) -> Result<(Option<Vec<u8>>, u64), ConfigRegistryError> {
        match tokio::fs::read(&self.path).await {
            Ok(raw) => Ok((Some(raw), 0)),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok((None, 0)),
            Err(e) => Err(ConfigRegistryError::Source(e.to_string())),
        }
    }
    /// Replaces the file atomically: write to `<path>.tmp`, then rename over
    /// the target so concurrent readers never observe a partial write.
    /// `expected_revision` is ignored — files cannot do compare-and-swap.
    async fn put_bytes(
        &self,
        _expected_revision: Option<u64>,
        value: Vec<u8>,
    ) -> Result<u64, ConfigRegistryError> {
        let tmp = self.path.with_extension("tmp");
        tokio::fs::write(&tmp, &value)
            .await
            .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
        tokio::fs::rename(&tmp, &self.path)
            .await
            .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
        Ok(0)
    }
    /// Files keep no version history.
    async fn history_bytes(
        &self,
        _limit: usize,
    ) -> Result<Vec<(u64, Vec<u8>)>, ConfigRegistryError> {
        Err(ConfigRegistryError::Source(
            "file source has no history".to_string(),
        ))
    }
    /// File changes are not watched (no inotify integration); empty stream.
    async fn watch(
        &self,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), ConfigRegistryError>> + Send>>,
        ConfigRegistryError,
    > {
        Ok(Box::pin(futures::stream::empty()))
    }
    fn info(&self) -> ConfigSourceInfo {
        ConfigSourceInfo::File {
            path: self.path.to_string_lossy().to_string(),
        }
    }
}
/// Config source backed by a single key in a NATS JetStream KV bucket,
/// giving revisions, history and change notifications for free.
#[derive(Clone)]
pub struct NatsKvSource {
    kv: async_nats::jetstream::kv::Store,
    bucket: String,
    key: String,
}
impl NatsKvSource {
    /// Connects to NATS (bounded by a 2s timeout so a down broker fails fast
    /// rather than hanging startup) and opens — or creates on first use —
    /// the JetStream KV bucket backing this source.
    pub async fn connect(
        nats_url: impl Into<String>,
        bucket: impl Into<String>,
        key: impl Into<String>,
    ) -> Result<Self, ConfigRegistryError> {
        let nats_url = nats_url.into();
        let bucket = bucket.into();
        let key = key.into();
        let client = tokio::time::timeout(Duration::from_secs(2), async_nats::connect(nats_url))
            .await
            .map_err(|_| ConfigRegistryError::Source("connect timeout".to_string()))?
            .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
        let jetstream = async_nats::jetstream::new(client);
        // Open the bucket if it exists; otherwise create it with defaults.
        // NOTE(review): any get_key_value error triggers creation, not just
        // "bucket not found" — confirm that is the intended fallback.
        let kv = match jetstream.get_key_value(&bucket).await {
            Ok(kv) => kv,
            Err(_) => jetstream
                .create_key_value(async_nats::jetstream::kv::Config {
                    bucket: bucket.clone(),
                    ..Default::default()
                })
                .await
                .map_err(|e| ConfigRegistryError::Source(e.to_string()))?,
        };
        Ok(Self { kv, bucket, key })
    }
}
#[async_trait]
impl ConfigSource for NatsKvSource {
    /// Reads the current KV entry; an absent key maps to `(None, 0)`.
    async fn load_bytes(&self) -> Result<(Option<Vec<u8>>, u64), ConfigRegistryError> {
        let entry = self
            .kv
            .entry(&self.key)
            .await
            .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
        Ok(match entry {
            Some(e) => (Some(e.value.to_vec()), e.revision),
            None => (None, 0),
        })
    }
    /// Compare-and-swap write when a positive expected revision is supplied;
    /// otherwise an unconditional put (also used for the very first write).
    async fn put_bytes(
        &self,
        expected_revision: Option<u64>,
        value: Vec<u8>,
    ) -> Result<u64, ConfigRegistryError> {
        let rev = match expected_revision {
            // `update` fails if the server-side revision differs.
            Some(expected) if expected > 0 => self
                .kv
                .update(&self.key, value.into(), expected)
                .await
                .map_err(|e| ConfigRegistryError::Source(e.to_string()))?,
            _ => self
                .kv
                .put(&self.key, value.into())
                .await
                .map_err(|e| ConfigRegistryError::Source(e.to_string()))?,
        };
        Ok(rev)
    }
    /// Streams KV history for the key, stopping after `limit` entries.
    async fn history_bytes(
        &self,
        limit: usize,
    ) -> Result<Vec<(u64, Vec<u8>)>, ConfigRegistryError> {
        let mut stream = self
            .kv
            .history(&self.key)
            .await
            .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
        let mut out = Vec::new();
        while let Some(item) = stream.next().await {
            let entry = item.map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
            out.push((entry.revision, entry.value.to_vec()));
            if out.len() >= limit {
                break;
            }
        }
        Ok(out)
    }
    /// Watches the key, emitting one unit per Put; deletes/purges are
    /// filtered out, stream errors are forwarded.
    async fn watch(
        &self,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), ConfigRegistryError>> + Send>>,
        ConfigRegistryError,
    > {
        let key = self.key.clone();
        let watch = self
            .kv
            .watch(&key)
            .await
            .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
        Ok(Box::pin(watch.filter_map(|entry| async move {
            match entry {
                Ok(entry) => match entry.operation {
                    async_nats::jetstream::kv::Operation::Put => Some(Ok(())),
                    async_nats::jetstream::kv::Operation::Delete
                    | async_nats::jetstream::kv::Operation::Purge => None,
                },
                Err(e) => Some(Err(ConfigRegistryError::Source(e.to_string()))),
            }
        })))
    }
    fn info(&self) -> ConfigSourceInfo {
        ConfigSourceInfo::NatsKv {
            bucket: self.bucket.clone(),
            key: self.key.clone(),
        }
    }
}
/// Holds the (optional) config source for each [`ConfigDomain`].
#[derive(Clone)]
pub struct ConfigRegistry {
    routing: Option<Arc<dyn ConfigSource>>,
    placement: Option<Arc<dyn ConfigSource>>,
}
impl ConfigRegistry {
pub fn new(
routing: Option<Arc<dyn ConfigSource>>,
placement: Option<Arc<dyn ConfigSource>>,
) -> Self {
Self { routing, placement }
}
pub fn source(&self, domain: ConfigDomain) -> Option<Arc<dyn ConfigSource>> {
match domain {
ConfigDomain::Routing => self.routing.clone(),
ConfigDomain::Placement => self.placement.clone(),
}
}
}

View File

@@ -0,0 +1,15 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Declarative routing layout: placement and sharding for aggregates,
/// projections and runners.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RoutingConfig {
    /// Revision of this document (as tracked by the config source).
    pub revision: u64,
    // Placement maps: presumably name -> target node/service; confirm the
    // value semantics against the placement consumers.
    pub aggregate_placement: HashMap<String, String>,
    pub projection_placement: HashMap<String, String>,
    pub runner_placement: HashMap<String, String>,
    // Shard maps: name -> ordered list of shard assignments.
    pub aggregate_shards: HashMap<String, Vec<String>>,
    pub projection_shards: HashMap<String, Vec<String>>,
    pub runner_shards: HashMap<String, Vec<String>>,
}

View File

@@ -0,0 +1,353 @@
use crate::auth::{Principal, has_permission};
use crate::{AppState, RequestIds};
use axum::{
Router,
body::Bytes,
extract::{Extension, Path, Query, State},
http::{HeaderMap, StatusCode, header},
response::IntoResponse,
routing::{get, post, put},
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
// Header that must carry the caller's tenant id; every handler checks it
// against the tenant id in the URL path (defense in depth).
const HEADER_TENANT_ID: &str = shared::HEADER_X_TENANT_ID;
/// Builds the tenant-docs router: list/upload/get/delete plus presigned
/// upload/download URLs, all scoped under `/tenants/{tenant_id}/docs`.
pub fn router() -> Router<AppState> {
    Router::new()
        .route("/tenants/{tenant_id}/docs", get(list_docs))
        .route(
            // The final segment captures the uploaded filename so it matches
            // upload_doc's (tenant_id, doc_type, doc_id, filename) Path tuple.
            // (Previously this segment was a broken literal and the 4-element
            // extractor could never match.)
            "/tenants/{tenant_id}/docs/{doc_type}/{doc_id}/{filename}",
            put(upload_doc),
        )
        .route(
            "/tenants/{tenant_id}/docs/object/{*key}",
            get(get_doc).delete(delete_doc),
        )
        .route(
            "/tenants/{tenant_id}/docs/presign/upload",
            post(presign_upload),
        )
        .route(
            "/tenants/{tenant_id}/docs/presign/download",
            post(presign_download),
        )
}
/// Requires the tenant-id header to be present, a valid UUID, and equal to
/// the tenant in the URL path. 400 for missing/unparseable, 403 on mismatch.
fn ensure_tenant_header(headers: &HeaderMap, tenant_id: Uuid) -> Result<(), StatusCode> {
    let raw = match headers.get(HEADER_TENANT_ID).and_then(|v| v.to_str().ok()) {
        Some(v) => v,
        None => return Err(StatusCode::BAD_REQUEST),
    };
    let claimed = Uuid::parse_str(raw).map_err(|_| StatusCode::BAD_REQUEST)?;
    if claimed == tenant_id {
        Ok(())
    } else {
        Err(StatusCode::FORBIDDEN)
    }
}
/// Gates docs endpoints on the tenant's `s3_docs_enabled` entitlement,
/// answering 402 when the plan does not include docs. The check is skipped
/// entirely when billing enforcement is globally disabled.
fn ensure_docs_enabled(state: &AppState, tenant_id: Uuid) -> Result<(), StatusCode> {
    if state.billing_enforcement_enabled {
        let entitlements = state.billing.get_for_tenant(tenant_id).entitlements;
        if !entitlements.s3_docs_enabled {
            return Err(StatusCode::PAYMENT_REQUIRED);
        }
    }
    Ok(())
}
/// Query parameters for `list_docs`.
#[derive(Debug, Deserialize)]
struct ListQuery {
    // Optional key prefix to narrow the listing within the tenant namespace.
    prefix: Option<String>,
}
/// Response body for `list_docs`.
#[derive(Debug, Serialize)]
struct ListResponse {
    objects: Vec<crate::s3_docs::DocObject>,
}
/// Lists a tenant's doc objects, optionally narrowed by a caller-supplied
/// prefix that is always anchored under the tenant's own namespace.
async fn list_docs(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path(tenant_id): Path<Uuid>,
    Query(q): Query<ListQuery>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    if let Err(s) = ensure_tenant_header(&headers, tenant_id) {
        return s.into_response();
    }
    if let Err(s) = ensure_docs_enabled(&state, tenant_id) {
        return s.into_response();
    }
    // Docs storage is optional; absent config means the feature is off.
    let store = match state.docs.as_ref() {
        Some(s) => s,
        None => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
    };
    let prefix = q.prefix.unwrap_or_default();
    let prefix = prefix.trim();
    // Reject path-traversal-looking prefixes outright.
    if prefix.contains("..") {
        return StatusCode::BAD_REQUEST.into_response();
    }
    // Anchor the listing under "<store prefix><tenant>/" so one tenant can
    // never list another tenant's keys.
    let base = format!("{}{}", store_prefix(store), tenant_id);
    let prefix = if prefix.is_empty() {
        format!("{base}/")
    } else {
        format!("{base}/{prefix}")
    };
    match store.list_for_tenant(&tenant_id.to_string(), &prefix).await {
        Ok(objects) => (StatusCode::OK, axum::Json(ListResponse { objects })).into_response(),
        // Upstream object-store failure surfaces as 502.
        Err(_) => StatusCode::BAD_GATEWAY.into_response(),
    }
}
/// Stores a document body under a key derived from (tenant, doc_type, doc_id,
/// filename) and returns the key plus the content's SHA-256.
async fn upload_doc(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path((tenant_id, doc_type, doc_id, filename)): Path<(Uuid, String, String, String)>,
    Extension(principal): Extension<Principal>,
    Extension(request_ids): Extension<RequestIds>,
    body: Bytes,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    if let Err(s) = ensure_tenant_header(&headers, tenant_id) {
        return s.into_response();
    }
    if let Err(s) = ensure_docs_enabled(&state, tenant_id) {
        return s.into_response();
    }
    let store = match state.docs.as_ref() {
        Some(s) => s,
        None => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
    };
    // Forward the caller's content type (if any) to the object store.
    let ct = headers
        .get(header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_string());
    // key_for validates/normalizes the path pieces; bad input -> 400.
    let key = match store.key_for(&tenant_id.to_string(), &doc_type, &doc_id, &filename) {
        Ok(k) => k,
        Err(_) => return StatusCode::BAD_REQUEST.into_response(),
    };
    let bytes = body.to_vec();
    // Hash before upload so the response can attest what was stored.
    let hash = crate::s3_docs::DocsStore::content_hash_sha256_hex(&bytes);
    if let Err(e) = store
        .put_for_tenant(&tenant_id.to_string(), &key, bytes, ct)
        .await
    {
        tracing::warn!(
            request_id = %request_ids.request_id,
            correlation_id = ?request_ids.correlation_id,
            error = %e,
            "docs upload failed"
        );
        return StatusCode::BAD_GATEWAY.into_response();
    }
    (
        StatusCode::OK,
        axum::Json(serde_json::json!({
            "key": key,
            "sha256": hash,
        })),
    )
        .into_response()
}
/// Streams a stored document back to the caller, preserving its content type
/// when one was recorded.
async fn get_doc(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path((tenant_id, key)): Path<(Uuid, String)>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    if let Err(s) = ensure_tenant_header(&headers, tenant_id) {
        return s.into_response();
    }
    if let Err(s) = ensure_docs_enabled(&state, tenant_id) {
        return s.into_response();
    }
    let store = match state.docs.as_ref() {
        Some(s) => s,
        None => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
    };
    // Tenant-namespace check. Include the trailing "/" delimiter (as
    // list_docs does) so only keys strictly inside "<prefix><tenant>/" pass —
    // a bare string-prefix match would also accept junk keys that merely
    // start with the tenant id.
    let base = format!("{}{}/", store_prefix(store), tenant_id);
    if !key.starts_with(&base) {
        return StatusCode::FORBIDDEN.into_response();
    }
    match store
        .get_bytes_for_tenant(&tenant_id.to_string(), &key)
        .await
    {
        Ok((bytes, ct)) => {
            let mut res = axum::response::Response::new(axum::body::Body::from(bytes));
            *res.status_mut() = StatusCode::OK;
            // Only set Content-Type when the stored value is a valid header.
            if let Some(ct) = ct
                && let Ok(v) = axum::http::HeaderValue::from_str(&ct)
            {
                res.headers_mut().insert(header::CONTENT_TYPE, v);
            }
            res
        }
        Err(_) => StatusCode::NOT_FOUND.into_response(),
    }
}
/// Deletes a stored document; answers 204 on success.
async fn delete_doc(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path((tenant_id, key)): Path<(Uuid, String)>,
    Extension(principal): Extension<Principal>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    if let Err(s) = ensure_tenant_header(&headers, tenant_id) {
        return s.into_response();
    }
    if let Err(s) = ensure_docs_enabled(&state, tenant_id) {
        return s.into_response();
    }
    let store = match state.docs.as_ref() {
        Some(s) => s,
        None => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
    };
    // Tenant-namespace check with the "/" delimiter included, matching
    // list_docs; see get_doc for why the delimiter matters.
    let base = format!("{}{}/", store_prefix(store), tenant_id);
    if !key.starts_with(&base) {
        return StatusCode::FORBIDDEN.into_response();
    }
    match store.delete_for_tenant(&tenant_id.to_string(), &key).await {
        Ok(_) => StatusCode::NO_CONTENT.into_response(),
        Err(_) => StatusCode::BAD_GATEWAY.into_response(),
    }
}
/// Request body for `presign_upload`.
#[derive(Debug, Deserialize)]
struct PresignUploadRequest {
    doc_type: String,
    // Generated server-side (random UUID) when omitted.
    doc_id: Option<String>,
    filename: String,
    content_type: Option<String>,
}
/// Issues a short-lived (300s) presigned PUT URL so the client can upload a
/// document directly to object storage, bypassing this API for the payload.
async fn presign_upload(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path(tenant_id): Path<Uuid>,
    Extension(principal): Extension<Principal>,
    axum::Json(body): axum::Json<PresignUploadRequest>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:write") {
        return StatusCode::FORBIDDEN.into_response();
    }
    if let Err(s) = ensure_tenant_header(&headers, tenant_id) {
        return s.into_response();
    }
    if let Err(s) = ensure_docs_enabled(&state, tenant_id) {
        return s.into_response();
    }
    let store = match state.docs.as_ref() {
        Some(s) => s,
        None => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
    };
    // Missing doc_id: mint one server-side so the key is still unique.
    let doc_id = body.doc_id.unwrap_or_else(|| Uuid::new_v4().to_string());
    // key_for validates/normalizes the path pieces; bad input -> 400.
    let key = match store.key_for(
        &tenant_id.to_string(),
        &body.doc_type,
        &doc_id,
        &body.filename,
    ) {
        Ok(k) => k,
        Err(_) => return StatusCode::BAD_REQUEST.into_response(),
    };
    match store
        .presign_put_for_tenant(
            &tenant_id.to_string(),
            &key,
            body.content_type,
            std::time::Duration::from_secs(300),
        )
        .await
    {
        Ok(url) => (
            StatusCode::OK,
            axum::Json(serde_json::json!({
                "method": "PUT",
                "url": url,
                "key": key,
            })),
        )
            .into_response(),
        Err(_) => StatusCode::BAD_GATEWAY.into_response(),
    }
}
/// Request body for `presign_download`: the full object key to fetch.
#[derive(Debug, Deserialize)]
struct PresignDownloadRequest {
    key: String,
}
/// Issues a short-lived (300s) presigned GET URL for an existing object key.
async fn presign_download(
    State(state): State<AppState>,
    headers: HeaderMap,
    Path(tenant_id): Path<Uuid>,
    Extension(principal): Extension<Principal>,
    axum::Json(body): axum::Json<PresignDownloadRequest>,
) -> impl IntoResponse {
    if !has_permission(&principal, "control:read") {
        return StatusCode::FORBIDDEN.into_response();
    }
    if let Err(s) = ensure_tenant_header(&headers, tenant_id) {
        return s.into_response();
    }
    if let Err(s) = ensure_docs_enabled(&state, tenant_id) {
        return s.into_response();
    }
    let store = match state.docs.as_ref() {
        Some(s) => s,
        None => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
    };
    // Tenant-namespace check with the "/" delimiter included, matching
    // list_docs; see get_doc for why the delimiter matters.
    let base = format!("{}{}/", store_prefix(store), tenant_id);
    if !body.key.starts_with(&base) {
        return StatusCode::FORBIDDEN.into_response();
    }
    match store
        .presign_get_for_tenant(
            &tenant_id.to_string(),
            &body.key,
            std::time::Duration::from_secs(300),
        )
        .await
    {
        Ok(url) => (
            StatusCode::OK,
            axum::Json(serde_json::json!({
                "method": "GET",
                "url": url,
                "key": body.key,
            })),
        )
            .into_response(),
        Err(_) => StatusCode::BAD_GATEWAY.into_response(),
    }
}
/// Object-key prefix configured on the docs store; tenant namespaces are
/// built as "<prefix><tenant_id>/…".
fn store_prefix(store: &crate::s3_docs::DocsStore) -> &str {
    store.prefix()
}

127
control/api/src/drift.rs Normal file
View File

@@ -0,0 +1,127 @@
use crate::{AppState, build_info::extract_build_info, fleet, swarm::SwarmService};
use serde::Serialize;
use std::collections::{BTreeMap, BTreeSet};
/// Category of divergence between desired and observed fleet state.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DriftKind {
    // Desired service not observed in Swarm.
    Missing,
    // Observed service that is not in the desired set.
    Extra,
    // Service present but failing health/readiness probes.
    Unhealthy,
    // Multiple build versions observed for the same service.
    VersionMismatch,
}
/// One concrete drift finding.
#[derive(Debug, Clone, Serialize)]
pub struct DriftItem {
    pub kind: DriftKind,
    pub service: String,
    // Kind-specific context (health flags, seen versions, …).
    pub details: serde_json::Value,
}
/// Drift report: per-kind counts plus the sorted list of findings.
#[derive(Debug, Clone, Serialize)]
pub struct DriftResponse {
    pub summary: BTreeMap<String, u64>,
    pub items: Vec<DriftItem>,
}
/// Computes drift between the desired fleet (configured services) and the
/// observed fleet (Swarm snapshot + live health/metrics probes).
pub async fn compute(state: &AppState) -> DriftResponse {
    let mut items: Vec<DriftItem> = Vec::new();
    // Desired service set: what the Control API was configured to observe.
    // (In production, this should evolve into "desired stacks + required services".)
    let desired: BTreeSet<String> = state
        .fleet_services
        .iter()
        .map(|s| s.name.clone())
        .collect();
    // Observed service set: what Swarm reports (dev: from file snapshot).
    let observed_services: Vec<SwarmService> = state.swarm.list_services();
    let observed: BTreeSet<String> = observed_services.iter().map(|s| s.name.clone()).collect();
    for missing in desired.difference(&observed) {
        items.push(DriftItem {
            kind: DriftKind::Missing,
            service: missing.clone(),
            details: serde_json::json!({ "expected": true }),
        });
    }
    for extra in observed.difference(&desired) {
        items.push(DriftItem {
            kind: DriftKind::Extra,
            service: extra.clone(),
            details: serde_json::json!({ "observed": true }),
        });
    }
    // Probe the fleet ONCE and reuse the snapshot for both the health and the
    // version checks below (previously it was fetched twice per drift run,
    // doubling probe traffic against every service).
    let snapshots = fleet::snapshot(&state.http, &state.fleet_services).await;
    // Health drift: based on fleet snapshot.
    for s in &snapshots {
        if !s.health_ok || !s.ready_ok {
            items.push(DriftItem {
                kind: DriftKind::Unhealthy,
                service: s.name.clone(),
                details: serde_json::json!({
                    "health_ok": s.health_ok,
                    "ready_ok": s.ready_ok,
                    "metrics_ok": s.metrics_ok,
                    "base_url": s.base_url,
                }),
            });
        }
    }
    // Version drift: compare build_info between services when present.
    // Desired is not yet explicit; for now we flag when multiple versions exist for same service.
    let mut versions_by_service: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
    for s in &snapshots {
        // Metrics fetch failures are treated as "no version info", not drift.
        if let Ok(metrics) = state
            .http
            .get(format!("{}/metrics", s.base_url))
            .send()
            .await
            && let Ok(body) = metrics.text().await
        {
            for bi in extract_build_info(&body) {
                versions_by_service
                    .entry(bi.service.clone())
                    .or_default()
                    .insert(format!("{}@{}", bi.version, bi.git_sha));
            }
        }
    }
    for (svc, vs) in versions_by_service {
        if vs.len() > 1 {
            items.push(DriftItem {
                kind: DriftKind::VersionMismatch,
                service: svc,
                details: serde_json::json!({ "seen": vs.into_iter().collect::<Vec<_>>() }),
            });
        }
    }
    // Stable report ordering: kind first, then service name.
    fn ord(k: &DriftKind) -> u8 {
        match k {
            DriftKind::Missing => 0,
            DriftKind::Extra => 1,
            DriftKind::Unhealthy => 2,
            DriftKind::VersionMismatch => 3,
        }
    }
    items.sort_by(|a, b| (ord(&a.kind), &a.service).cmp(&(ord(&b.kind), &b.service)));
    let mut summary: BTreeMap<String, u64> = BTreeMap::new();
    for item in &items {
        let k = match item.kind {
            DriftKind::Missing => "missing",
            DriftKind::Extra => "extra",
            DriftKind::Unhealthy => "unhealthy",
            DriftKind::VersionMismatch => "version_mismatch",
        };
        *summary.entry(k.to_string()).or_insert(0) += 1;
    }
    DriftResponse { summary, items }
}

View File

@@ -1,14 +1,19 @@
use crate::{
AppState, Principal,
audit::{AuditEvent, AuditStore},
config_registry::{ConfigDomain, ConfigRegistryError},
config_schemas::RoutingConfig,
fleet,
jobs::{Job, JobStatus, JobStep, JobStore},
placement::PlacementFile,
};
use std::{
collections::HashMap,
path::PathBuf,
sync::{Arc, Mutex},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use url::Url;
use uuid::Uuid;
#[derive(Clone, Default)]
@@ -34,20 +39,52 @@ impl TenantLocks {
}
}
/// One-writer-at-a-time lock per configuration domain, keyed by the domain's
/// string name and recording the owning job's id.
#[derive(Clone, Default)]
pub struct ConfigLocks {
    inner: Arc<Mutex<HashMap<String, Uuid>>>,
}
impl ConfigLocks {
    /// Attempts to acquire the lock for `domain` on behalf of `job_id`.
    /// Returns `false` (without blocking) if any job already holds it.
    pub fn try_lock(&self, domain: ConfigDomain, job_id: Uuid) -> bool {
        let mut map = self.inner.lock().expect("config locks poisoned");
        // Entry API: a single map lookup instead of contains_key + insert.
        match map.entry(domain.as_str().to_string()) {
            std::collections::hash_map::Entry::Occupied(_) => false,
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(job_id);
                true
            }
        }
    }
    /// Releases the lock for `domain`, but only if `job_id` is the current
    /// holder — a stale unlock from an older job is a no-op.
    pub fn unlock(&self, domain: ConfigDomain, job_id: Uuid) {
        let mut map = self.inner.lock().expect("config locks poisoned");
        let k = domain.as_str().to_string();
        if map.get(&k).copied() == Some(job_id) {
            map.remove(&k);
        }
    }
}
/// Orchestrates background jobs (drain/migrate/config workflows), recording
/// progress in `jobs` and audit events in `audit`, and serializing work via
/// per-tenant and per-config-domain locks.
#[derive(Clone)]
pub struct JobEngine {
    pub jobs: JobStore,
    pub audit: AuditStore,
    pub tenant_locks: TenantLocks,
    pub config_locks: ConfigLocks,
    // Upper bound applied to each individual job step.
    pub step_timeout: Duration,
}
impl JobEngine {
pub fn new(jobs: JobStore, audit: AuditStore, tenant_locks: TenantLocks) -> Self {
pub fn new(
jobs: JobStore,
audit: AuditStore,
tenant_locks: TenantLocks,
config_locks: ConfigLocks,
) -> Self {
Self {
jobs,
audit,
tenant_locks,
config_locks,
step_timeout: Duration::from_millis(500),
}
}
@@ -93,7 +130,7 @@ impl JobEngine {
let engine = self.clone();
tokio::spawn(async move {
engine
.run_job(state, inserted, Some(tenant_id), RunSpec::Drain)
.run_job(state, inserted, Some(tenant_id), None, RunSpec::Drain)
.await;
});
@@ -152,6 +189,7 @@ impl JobEngine {
state,
inserted,
Some(tenant_id),
None,
RunSpec::Migrate { runner_target },
)
.await;
@@ -160,7 +198,238 @@ impl JobEngine {
Ok(inserted)
}
async fn run_job(&self, state: AppState, job_id: Uuid, tenant_id: Option<Uuid>, spec: RunSpec) {
    /// Starts an async config-apply job for `domain`: validate, back up,
    /// apply (optionally CAS-guarded by `expected_revision`), reload, verify.
    /// Idempotent on `idempotency_key`; returns the job id immediately while
    /// the work runs in a spawned task.
    #[allow(clippy::too_many_arguments)]
    pub fn start_config_apply(
        &self,
        state: AppState,
        principal: &Principal,
        domain: ConfigDomain,
        reason: String,
        expected_revision: Option<u64>,
        value: serde_json::Value,
        idempotency_key: &str,
    ) -> Result<Uuid, StartJobError> {
        // Replay of a known idempotency key returns the original job.
        if let Some(existing) = self.jobs.get_idempotent(idempotency_key) {
            return Ok(existing);
        }
        let job_id = Uuid::new_v4();
        // One in-flight config job per domain.
        if !self.config_locks.try_lock(domain, job_id) {
            return Err(StartJobError::TenantLocked);
        }
        let now = now_ms();
        let job = Job {
            job_id,
            status: JobStatus::Pending,
            steps: vec![
                step("preflight"),
                step("validate_config"),
                step("backup_config"),
                step("apply_config"),
                step("reload_config"),
                step("verify_config"),
            ],
            error: None,
            created_at_ms: now,
            started_at_ms: None,
            finished_at_ms: None,
        };
        // NOTE(review): if insert_idempotent races and returns an id other
        // than `job_id`, the lock taken above (owned by `job_id`) would never
        // be released by run_job(inserted, …) — confirm insert semantics.
        let inserted = self.jobs.insert_idempotent(idempotency_key, job);
        self.audit.record(AuditEvent {
            ts_ms: now,
            principal_sub: principal.sub.clone(),
            action: format!("config.{}.apply", domain.as_str()),
            tenant_id: None,
            reason,
            job_id: Some(inserted),
        });
        let engine = self.clone();
        // run_job releases the config lock when the job finishes.
        tokio::spawn(async move {
            engine
                .run_job(
                    state,
                    inserted,
                    None,
                    Some(domain),
                    RunSpec::ConfigApply {
                        domain,
                        expected_revision,
                        value,
                    },
                )
                .await;
        });
        Ok(inserted)
    }
    /// Starts a validate-only config job for `domain` (single
    /// "validate_config" step); no config is written. Idempotent on
    /// `idempotency_key` and serialized via the domain lock.
    pub fn start_config_validate(
        &self,
        state: AppState,
        principal: &Principal,
        domain: ConfigDomain,
        reason: String,
        value: serde_json::Value,
        idempotency_key: &str,
    ) -> Result<Uuid, StartJobError> {
        if let Some(existing) = self.jobs.get_idempotent(idempotency_key) {
            return Ok(existing);
        }
        let job_id = Uuid::new_v4();
        if !self.config_locks.try_lock(domain, job_id) {
            return Err(StartJobError::TenantLocked);
        }
        let now = now_ms();
        let job = Job {
            job_id,
            status: JobStatus::Pending,
            steps: vec![step("validate_config")],
            error: None,
            created_at_ms: now,
            started_at_ms: None,
            finished_at_ms: None,
        };
        let inserted = self.jobs.insert_idempotent(idempotency_key, job);
        self.audit.record(AuditEvent {
            ts_ms: now,
            principal_sub: principal.sub.clone(),
            action: format!("config.{}.validate", domain.as_str()),
            tenant_id: None,
            reason,
            job_id: Some(inserted),
        });
        let engine = self.clone();
        tokio::spawn(async move {
            engine
                .run_job(
                    state,
                    inserted,
                    None,
                    Some(domain),
                    RunSpec::ConfigValidate { domain, value },
                )
                .await;
        });
        Ok(inserted)
    }
    /// Starts a rollback job for `domain`: restore the previous config, then
    /// reload and verify. Idempotent on `idempotency_key` and serialized via
    /// the domain lock.
    pub fn start_config_rollback(
        &self,
        state: AppState,
        principal: &Principal,
        domain: ConfigDomain,
        reason: String,
        idempotency_key: &str,
    ) -> Result<Uuid, StartJobError> {
        if let Some(existing) = self.jobs.get_idempotent(idempotency_key) {
            return Ok(existing);
        }
        let job_id = Uuid::new_v4();
        if !self.config_locks.try_lock(domain, job_id) {
            return Err(StartJobError::TenantLocked);
        }
        let now = now_ms();
        let job = Job {
            job_id,
            status: JobStatus::Pending,
            steps: vec![
                step("rollback_config"),
                step("reload_config"),
                step("verify_config"),
            ],
            error: None,
            created_at_ms: now,
            started_at_ms: None,
            finished_at_ms: None,
        };
        let inserted = self.jobs.insert_idempotent(idempotency_key, job);
        self.audit.record(AuditEvent {
            ts_ms: now,
            principal_sub: principal.sub.clone(),
            action: format!("config.{}.rollback", domain.as_str()),
            tenant_id: None,
            reason,
            job_id: Some(inserted),
        });
        let engine = self.clone();
        tokio::spawn(async move {
            engine
                .run_job(
                    state,
                    inserted,
                    None,
                    Some(domain),
                    RunSpec::ConfigRollback { domain },
                )
                .await;
        });
        Ok(inserted)
    }
    /// Starts a platform-wide verification job (preflight + platform_verify).
    /// Takes no tenant or config-domain lock — verification is read-only.
    /// Idempotent on `idempotency_key`.
    pub fn start_platform_verify(
        &self,
        state: AppState,
        principal: &Principal,
        reason: String,
        idempotency_key: &str,
    ) -> Result<Uuid, StartJobError> {
        if let Some(existing) = self.jobs.get_idempotent(idempotency_key) {
            return Ok(existing);
        }
        let job_id = Uuid::new_v4();
        let now = now_ms();
        let job = Job {
            job_id,
            status: JobStatus::Pending,
            steps: vec![step("preflight"), step("platform_verify")],
            error: None,
            created_at_ms: now,
            started_at_ms: None,
            finished_at_ms: None,
        };
        let inserted = self.jobs.insert_idempotent(idempotency_key, job);
        self.audit.record(AuditEvent {
            ts_ms: now,
            principal_sub: principal.sub.clone(),
            action: "platform.verify".to_string(),
            tenant_id: None,
            reason,
            job_id: Some(inserted),
        });
        let engine = self.clone();
        tokio::spawn(async move {
            engine
                .run_job(state, inserted, None, None, RunSpec::PlatformVerify)
                .await;
        });
        Ok(inserted)
    }
async fn run_job(
&self,
state: AppState,
job_id: Uuid,
tenant_id: Option<Uuid>,
config_domain: Option<ConfigDomain>,
spec: RunSpec,
) {
self.jobs.update(job_id, |j| {
j.status = JobStatus::Running;
j.started_at_ms = Some(now_ms());
@@ -265,6 +534,9 @@ impl JobEngine {
if let Some(tid) = tenant_id {
self.tenant_locks.unlock(tid, job_id);
}
if let Some(domain) = config_domain {
self.config_locks.unlock(domain, job_id);
}
}
}
@@ -276,7 +548,22 @@ pub enum StartJobError {
/// Work specification carried by a job; consumed by `run_step` to decide
/// what each named step actually does.
#[derive(Clone)]
enum RunSpec {
    /// Drain a tenant's workload without relocating it.
    Drain,
    /// Move a tenant's runner workload to `runner_target`.
    Migrate {
        runner_target: String,
    },
    /// Validate a candidate config document for `domain` without applying it.
    ConfigValidate {
        domain: ConfigDomain,
        value: serde_json::Value,
    },
    /// Apply `value` to `domain`; `expected_revision` (when set) guards the
    /// write with optimistic concurrency control.
    ConfigApply {
        domain: ConfigDomain,
        expected_revision: Option<u64>,
        value: serde_json::Value,
    },
    /// Restore the most recent backup for `domain`.
    ConfigRollback {
        domain: ConfigDomain,
    },
    /// Health/readiness verification across the whole fleet.
    PlatformVerify,
}
fn step(name: &str) -> JobStep {
@@ -316,9 +603,14 @@ async fn run_step(
"update_placement" => match spec {
RunSpec::Migrate { runner_target } => {
let tenant_id = tenant_id.ok_or_else(|| "missing tenant_id".to_string())?;
let entitlements = state.billing.get_for_tenant(tenant_id).entitlements;
state
.placement
.update_runner_target(tenant_id, runner_target.clone())
.update_runner_target(
tenant_id,
runner_target.clone(),
entitlements.max_runners as usize,
)
.map(|_| ())
}
_ => Ok(()),
@@ -343,6 +635,400 @@ async fn run_step(
}
_ => Ok(()),
},
"validate_config" => match spec {
RunSpec::ConfigValidate { domain, value }
| RunSpec::ConfigApply { domain, value, .. } => match domain {
ConfigDomain::Routing => {
let cfg = serde_json::from_value::<RoutingConfig>(value.clone())
.map_err(|e| format!("invalid routing config: {e}"))?;
validate_routing_semantic(&cfg)?;
Ok(())
}
ConfigDomain::Placement => {
let cfg = serde_json::from_value::<PlacementFile>(value.clone())
.map_err(|e| format!("invalid placement config: {e}"))?;
validate_placement_semantic(state, &cfg)?;
Ok(())
}
},
_ => Ok(()),
},
"backup_config" => match spec {
RunSpec::ConfigApply { domain, .. } => {
let Some(source) = state.config.source(*domain) else {
return Err("config domain not configured".to_string());
};
let (cur, _) = source
.load_bytes()
.await
.map_err(|e| format!("failed to load config: {e}"))?;
let cur = cur.unwrap_or_else(|| b"null".to_vec());
let backup_key_value = serde_json::json!({ "backup": serde_json::from_slice::<serde_json::Value>(&cur).unwrap_or(serde_json::Value::Null) });
let bytes =
serde_json::to_vec_pretty(&backup_key_value).map_err(|e| e.to_string())?;
let backup_source = backup_source_for(&source.info(), *domain)
.await
.map_err(|e| format!("failed to build backup source: {e}"))?;
let _ = backup_source
.put_bytes(None, bytes)
.await
.map_err(|e| format!("failed to write backup: {e}"))?;
Ok(())
}
_ => Ok(()),
},
"apply_config" => match spec {
RunSpec::ConfigApply {
domain,
expected_revision,
value,
} => {
let Some(source) = state.config.source(*domain) else {
return Err("config domain not configured".to_string());
};
let bytes =
serde_json::to_vec_pretty(value).map_err(|e| format!("encode error: {e}"))?;
let _ = source
.put_bytes(*expected_revision, bytes)
.await
.map_err(|e| format!("apply failed: {e}"))?;
Ok(())
}
_ => Ok(()),
},
"rollback_config" => match spec {
RunSpec::ConfigRollback { domain } => {
let Some(source) = state.config.source(*domain) else {
return Err("config domain not configured".to_string());
};
let backup_source = backup_source_for(&source.info(), *domain)
.await
.map_err(|e| format!("failed to build backup source: {e}"))?;
let (bytes, _) = backup_source
.load_bytes()
.await
.map_err(|e| format!("failed to load backup: {e}"))?;
let Some(bytes) = bytes else {
return Err("no backup available".to_string());
};
let v: serde_json::Value = serde_json::from_slice(&bytes)
.map_err(|e| format!("invalid backup json: {e}"))?;
let backup = v.get("backup").cloned().unwrap_or(serde_json::Value::Null);
let next =
serde_json::to_vec_pretty(&backup).map_err(|e| format!("encode error: {e}"))?;
let _ = source
.put_bytes(None, next)
.await
.map_err(|e| format!("rollback failed: {e}"))?;
Ok(())
}
_ => Ok(()),
},
"reload_config" => Ok(()),
"verify_config" => match spec {
RunSpec::ConfigValidate { domain, .. }
| RunSpec::ConfigApply { domain, .. }
| RunSpec::ConfigRollback { domain } => {
let Some(source) = state.config.source(*domain) else {
return Err("config domain not configured".to_string());
};
let (bytes, _) = source
.load_bytes()
.await
.map_err(|e| format!("failed to load config: {e}"))?;
let bytes = bytes.unwrap_or_else(|| b"null".to_vec());
let v: serde_json::Value = serde_json::from_slice(&bytes)
.map_err(|e| format!("invalid stored json: {e}"))?;
match domain {
ConfigDomain::Routing => {
let cfg = serde_json::from_value::<RoutingConfig>(v)
.map_err(|e| format!("invalid routing config: {e}"))?;
validate_routing_semantic(&cfg)?;
Ok(())
}
ConfigDomain::Placement => {
let cfg = serde_json::from_value::<PlacementFile>(v)
.map_err(|e| format!("invalid placement config: {e}"))?;
validate_placement_semantic(state, &cfg)?;
Ok(())
}
}
}
_ => Ok(()),
},
"platform_verify" => match spec {
RunSpec::PlatformVerify => {
let snapshots = fleet::snapshot(&state.http, &state.fleet_services).await;
let bad: Vec<_> = snapshots
.into_iter()
.filter(|s| !(s.health_ok && s.ready_ok))
.map(|s| {
format!(
"{} health_ok={} ready_ok={}",
s.name, s.health_ok, s.ready_ok
)
})
.collect();
if !bad.is_empty() {
return Err(format!("platform verify failed: {}", bad.join("; ")));
}
Ok(())
}
_ => Ok(()),
},
_ => Ok(()),
}
}
/// Derives the backup `ConfigSource` that sits next to a primary source.
///
/// File sources get a sibling `<stem>.<domain>.bak.json` path; NATS KV
/// sources get a `<key>.bak` key in the same bucket. Fixed sources have no
/// writable backing, so backups are unsupported.
async fn backup_source_for(
    info: &crate::config_registry::ConfigSourceInfo,
    domain: ConfigDomain,
) -> Result<Arc<dyn crate::config_registry::ConfigSource>, ConfigRegistryError> {
    use crate::config_registry::{ConfigSource, ConfigSourceInfo, FileSource, NatsKvSource};
    match info {
        ConfigSourceInfo::File { path } => {
            let backup_path =
                PathBuf::from(path).with_extension(format!("{}.bak.json", domain.as_str()));
            Ok(Arc::new(FileSource::new(backup_path)) as Arc<dyn ConfigSource>)
        }
        ConfigSourceInfo::NatsKv { bucket, key } => {
            // NOTE(review): backups use a dedicated URL env var — confirm
            // CONTROL_CONFIG_NATS_URL is set wherever NATS sources are used.
            let nats_url = std::env::var("CONTROL_CONFIG_NATS_URL").map_err(|_| {
                ConfigRegistryError::Source("missing CONTROL_CONFIG_NATS_URL".to_string())
            })?;
            let source = NatsKvSource::connect(nats_url, bucket.clone(), format!("{key}.bak"))
                .await
                .map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
            Ok(Arc::new(source) as Arc<dyn ConfigSource>)
        }
        ConfigSourceInfo::Fixed => Err(ConfigRegistryError::Source(
            "no backups for fixed source".to_string(),
        )),
    }
}
/// Semantic validation for a routing config: every shard must have at least
/// one well-formed http(s) endpoint with a host, and every placement entry
/// must reference a shard id that exists in the matching shard map.
fn validate_routing_semantic(cfg: &RoutingConfig) -> Result<(), String> {
    // Endpoint-level checks across all three shard maps.
    for (name, map) in [
        ("aggregate_shards", &cfg.aggregate_shards),
        ("projection_shards", &cfg.projection_shards),
        ("runner_shards", &cfg.runner_shards),
    ] {
        for (shard_id, endpoints) in map {
            if endpoints.is_empty() {
                return Err(format!("{name}[{shard_id}] has no endpoints"));
            }
            for ep in endpoints {
                let u = Url::parse(ep)
                    .map_err(|e| format!("{name}[{shard_id}] invalid endpoint {ep:?}: {e}"))?;
                if !matches!(u.scheme(), "http" | "https") {
                    return Err(format!(
                        "{name}[{shard_id}] endpoint {ep:?} must be http(s)"
                    ));
                }
                if u.host_str().is_none() {
                    return Err(format!(
                        "{name}[{shard_id}] endpoint {ep:?} must include host"
                    ));
                }
            }
        }
    }
    // Cross-reference checks: placements may only point at known shard ids.
    for (pname, pmap, shards) in [
        (
            "aggregate_placement",
            &cfg.aggregate_placement,
            &cfg.aggregate_shards,
        ),
        (
            "projection_placement",
            &cfg.projection_placement,
            &cfg.projection_shards,
        ),
        ("runner_placement", &cfg.runner_placement, &cfg.runner_shards),
    ] {
        for (tenant, shard_id) in pmap {
            if shard_id.trim().is_empty() {
                return Err(format!("{pname}[{tenant}] shard_id is empty"));
            }
            if !shards.contains_key(shard_id) {
                return Err(format!(
                    "{pname}[{tenant}] references missing shard_id {shard_id:?}"
                ));
            }
        }
    }
    Ok(())
}
fn validate_placement_semantic(state: &AppState, cfg: &PlacementFile) -> Result<(), String> {
if !state.billing_enforcement_enabled {
return Ok(());
}
let mut tenant_counts = std::collections::HashMap::new();
let kinds = [
("aggregate_placement", cfg.aggregate_placement.as_ref()),
("projection_placement", cfg.projection_placement.as_ref()),
("runner_placement", cfg.runner_placement.as_ref()),
];
for (kind_name, k) in kinds {
let Some(k) = k else { continue };
for p in &k.placements {
if p.targets.is_empty() {
return Err(format!("{kind_name} tenant {} has no targets", p.tenant_id));
}
if p.targets.iter().any(|t| t.trim().is_empty()) {
return Err(format!(
"{kind_name} tenant {} has empty target",
p.tenant_id
));
}
let entry = tenant_counts.entry(p.tenant_id).or_insert((0, 0)); // (deployments, runners)
if kind_name == "runner_placement" {
entry.1 += p.targets.len();
} else {
entry.0 += p.targets.len();
}
}
}
for (tenant_id, (deployments, runners)) in tenant_counts {
let entitlements = state.billing.get_for_tenant(tenant_id).entitlements;
if deployments > entitlements.max_deployments as usize {
return Err(format!(
"tenant {} exceeds max_deployments limit ({} > {})",
tenant_id, deployments, entitlements.max_deployments
));
}
if runners > entitlements.max_runners as usize {
return Err(format!(
"tenant {} exceeds max_runners limit ({} > {})",
tenant_id, runners, entitlements.max_runners
));
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::billing::{BillingStore, Plan, SubscriptionStatus, TenantBillingState};
    use crate::placement::{PlacementFile, PlacementKind, TenantPlacement};

    // Builds a minimal AppState around the given billing store so that
    // validate_placement_semantic can be exercised without a running server.
    // Billing enforcement is enabled so entitlement limits are checked.
    fn mock_state(billing: BillingStore) -> AppState {
        let handle = crate::get_test_prometheus_handle();
        let root = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        AppState {
            prometheus: handle,
            auth: crate::AuthConfig {
                hs256_secret: Some(b"secret".to_vec()),
            },
            jobs: JobStore::default(),
            audit: AuditStore::default(),
            tenant_locks: TenantLocks::default(),
            config_locks: ConfigLocks::default(),
            http: reqwest::Client::new(),
            placement: crate::placement::PlacementStore::new(
                std::env::temp_dir().join("placement.json"),
            ),
            billing,
            billing_provider: Arc::new(crate::billing::MockProvider),
            billing_enforcement_enabled: true,
            config: crate::config_registry::ConfigRegistry::new(None, None),
            fleet_services: vec![],
            swarm: crate::swarm::SwarmStore::new(root.join("swarm/dev.json")),
            docs: None,
        }
    }

    // Entitlement limits must track the tenant's plan: the default (free)
    // plan rejects a config that Pro accepts. The BillingStore clone shares
    // the same backing file as the one inside `state`, so updating it below
    // is visible to validate_placement_semantic.
    #[test]
    fn test_validate_placement_limits() {
        let tenant_id = Uuid::new_v4();
        // Unique temp file so parallel test runs don't share billing state.
        let billing_path =
            std::env::temp_dir().join(format!("billing-unit-{}.json", Uuid::new_v4()));
        let billing = BillingStore::new(billing_path.clone());
        let state = mock_state(billing.clone());
        // 1. Free plan (default): max_deployments=1, max_runners=1
        let cfg = PlacementFile {
            revision: Some("v1".to_string()),
            aggregate_placement: Some(PlacementKind {
                placements: vec![TenantPlacement {
                    tenant_id,
                    targets: vec!["a1".to_string()],
                }],
            }),
            projection_placement: Some(PlacementKind {
                placements: vec![TenantPlacement {
                    tenant_id,
                    targets: vec!["p1".to_string()],
                }],
            }),
            runner_placement: Some(PlacementKind {
                placements: vec![TenantPlacement {
                    tenant_id,
                    targets: vec!["r1".to_string()],
                }],
            }),
        };
        // aggregate(1) + projection(1) = 2 deployments. Limit is 1. Should fail.
        let err = validate_placement_semantic(&state, &cfg).unwrap_err();
        assert!(err.contains("exceeds max_deployments limit"));
        // 2. Reduce to 1 deployment
        let cfg2 = PlacementFile {
            revision: Some("v2".to_string()),
            aggregate_placement: Some(PlacementKind {
                placements: vec![TenantPlacement {
                    tenant_id,
                    targets: vec!["a1".to_string()],
                }],
            }),
            projection_placement: None,
            runner_placement: Some(PlacementKind {
                placements: vec![TenantPlacement {
                    tenant_id,
                    targets: vec!["r1".to_string()],
                }],
            }),
        };
        validate_placement_semantic(&state, &cfg2).unwrap();
        // 3. Upgrade to Pro: max_deployments=10, max_runners=10
        billing
            .update_tenant_state(
                tenant_id,
                TenantBillingState {
                    provider: "mock".to_string(),
                    provider_customer_id: None,
                    provider_subscription_id: None,
                    provider_checkout_session_id: None,
                    status: Some(SubscriptionStatus::Active),
                    plan: Some(Plan::Pro),
                    current_period_end: None,
                    cancel_at_period_end: None,
                    processed_webhook_event_ids: vec![],
                    updated_at: 100,
                },
            )
            .unwrap();
        // Now the first cfg should pass
        validate_placement_semantic(&state, &cfg).unwrap();
        // Best-effort cleanup of the temp billing file.
        let _ = std::fs::remove_file(billing_path);
    }
}

View File

@@ -1,14 +1,22 @@
mod admin;
mod audit;
mod auth;
pub mod billing;
mod build_info;
pub mod config_registry;
mod config_schemas;
mod deployments;
mod documents;
mod drift;
mod fleet;
mod job_engine;
mod jobs;
mod placement;
pub mod s3_docs;
mod swarm;
use std::sync::Arc;
pub use audit::AuditStore;
pub use auth::{AuthConfig, Principal};
use axum::{
@@ -20,8 +28,10 @@ use axum::{
routing::get,
};
pub use build_info::{BuildInfo, extract_build_info};
pub use config_registry::{ConfigDomain, ConfigRegistry};
pub use deployments::{DeployAnnotationArgs, GrafanaAnnotation, build_grafana_deploy_annotation};
pub use fleet::FleetService;
pub use job_engine::ConfigLocks;
pub use job_engine::TenantLocks;
pub use jobs::JobStore;
use metrics_exporter_prometheus::PrometheusHandle;
@@ -40,10 +50,16 @@ pub struct AppState {
pub jobs: JobStore,
pub audit: AuditStore,
pub tenant_locks: TenantLocks,
pub config_locks: ConfigLocks,
pub http: reqwest::Client,
pub placement: PlacementStore,
pub billing: billing::BillingStore,
pub billing_provider: Arc<dyn billing::BillingProvider>,
pub billing_enforcement_enabled: bool,
pub config: ConfigRegistry,
pub fleet_services: Vec<FleetService>,
pub swarm: SwarmStore,
pub docs: Option<s3_docs::DocsStore>,
}
#[derive(Clone, Debug)]
@@ -93,13 +109,18 @@ pub fn build_app(state: AppState) -> Router {
},
);
let admin =
admin::admin_router().layer(from_fn_with_state(state.clone(), auth::auth_middleware));
let admin = admin::admin_router()
.merge(documents::router())
.layer(from_fn_with_state(state.clone(), auth::auth_middleware));
Router::new()
.route("/health", get(health))
.route("/ready", get(ready))
.route("/metrics", get(metrics))
.route(
"/admin/v1/billing/webhooks/{provider}",
axum::routing::post(billing::webhook),
)
.nest("/admin/v1", admin)
.with_state(state)
.layer(trace)
@@ -167,25 +188,46 @@ async fn request_id_middleware(mut req: Request<axum::body::Body>, next: Next) -
res
}
#[cfg(test)]
static TEST_PROMETHEUS_HANDLE: std::sync::OnceLock<PrometheusHandle> = std::sync::OnceLock::new();

/// Returns a process-wide Prometheus handle for tests.
///
/// The global metrics recorder can only be installed once per process, so
/// the handle is cached in a `OnceLock` and shared by every test.
#[cfg(test)]
pub(crate) fn get_test_prometheus_handle() -> PrometheusHandle {
    let handle = TEST_PROMETHEUS_HANDLE.get_or_init(|| {
        match metrics_exporter_prometheus::PrometheusBuilder::new().install_recorder() {
            Ok(installed) => installed,
            // Another test may have installed a recorder already. The handle
            // built here is detached from the global recorder, which is fine
            // for tests that do not assert on metric output.
            Err(_) => metrics_exporter_prometheus::PrometheusBuilder::new()
                .build()
                .expect("failed to build prometheus recorder")
                .0
                .handle(),
        }
    });
    handle.clone()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config_registry::{FileSource, FixedSource};
use crate::jobs::JobStatus;
use axum::{
body::Body,
http::{Request, StatusCode, header},
};
use jsonwebtoken::{EncodingKey, Header, encode};
use metrics_exporter_prometheus::PrometheusBuilder;
use serde::Serialize;
use std::fs;
use std::path::PathBuf;
use std::sync::OnceLock;
use std::sync::Arc;
use tower::ServiceExt;
use uuid::Uuid;
static HANDLE: OnceLock<PrometheusHandle> = OnceLock::new();
#[derive(Serialize)]
struct TestClaims {
sub: String,
@@ -199,15 +241,10 @@ mod tests {
}
fn test_app_with_fleet(fleet_services: Vec<FleetService>) -> Router {
let handle = HANDLE
.get_or_init(|| {
PrometheusBuilder::new()
.install_recorder()
.expect("failed to install prometheus recorder")
})
.clone();
let handle = get_test_prometheus_handle();
let placement_path = temp_placement_file();
let root = repo_root();
build_app(AppState {
prometheus: handle,
@@ -217,10 +254,23 @@ mod tests {
jobs: JobStore::default(),
audit: AuditStore::default(),
tenant_locks: TenantLocks::default(),
config_locks: ConfigLocks::default(),
http: reqwest::Client::new(),
placement: PlacementStore::new(placement_path),
billing: crate::billing::BillingStore::new(
std::env::temp_dir().join(format!("billing-test-{}.json", Uuid::new_v4())),
),
billing_provider: Arc::new(crate::billing::MockProvider),
billing_enforcement_enabled: true,
config: ConfigRegistry::new(
Some(Arc::new(FileSource::new(
root.join("config/routing/dev.json"),
))),
Some(Arc::new(FixedSource::new(b"{}".to_vec()))),
),
fleet_services,
swarm: SwarmStore::new(repo_root().join("swarm/dev.json")),
docs: None,
})
}
@@ -234,14 +284,14 @@ mod tests {
fn temp_placement_file() -> PathBuf {
let root = repo_root();
let src = root.join("placement/dev.json");
let src = root.join("config/placement/dev.json");
let mut dst = std::env::temp_dir();
dst.push(format!(
"cloudlysis-control-placement-{}-{}.json",
std::process::id(),
Uuid::new_v4()
));
let raw = fs::read_to_string(src).expect("missing placement/dev.json");
let raw = fs::read_to_string(src).expect("missing config/placement/dev.json");
fs::write(&dst, raw).expect("failed to write temp placement file");
dst
}
@@ -689,4 +739,467 @@ mod tests {
&serde_json::json!(["preflight", "drain", "update_placement", "reload", "verify"])
);
}
/// Tenants with no billing record get a "not configured" response carrying
/// the default (free) entitlements.
#[tokio::test]
async fn billing_returns_not_configured_by_default() {
    let tenant_id = Uuid::new_v4();
    let auth = make_token(&["control:read"]);
    let req = Request::builder()
        .uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
        .header(header::AUTHORIZATION, format!("Bearer {auth}"))
        .header("x-tenant-id", tenant_id.to_string())
        .body(Body::empty())
        .unwrap();
    let res = test_app().oneshot(req).await.unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let raw = axum::body::to_bytes(res.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let body: serde_json::Value = serde_json::from_slice(&raw).unwrap();
    assert_eq!(body.get("configured").unwrap(), &serde_json::json!(false));
    // Free-plan default: a single deployment.
    assert_eq!(
        body.get("entitlements")
            .unwrap()
            .get("max_deployments")
            .unwrap(),
        &serde_json::json!(1)
    );
}
/// When a tenant has persisted billing state, the billing endpoint reports
/// it: configured=true, the plan name, and plan-derived entitlements.
#[tokio::test]
async fn billing_returns_configured_state() {
    let token = make_token(&["control:read"]);
    let tenant_id = Uuid::new_v4();
    let handle = get_test_prometheus_handle();
    // Unique temp file per run so parallel tests don't share billing state.
    let billing_path =
        std::env::temp_dir().join(format!("billing-test-cfg-{}.json", Uuid::new_v4()));
    let billing = crate::billing::BillingStore::new(billing_path.clone());
    billing
        .update_tenant_state(
            tenant_id,
            crate::billing::TenantBillingState {
                provider: "stripe".to_string(),
                provider_customer_id: Some("cus_123".to_string()),
                provider_subscription_id: Some("sub_123".to_string()),
                provider_checkout_session_id: None,
                status: Some(crate::billing::SubscriptionStatus::Active),
                plan: Some(crate::billing::Plan::Pro),
                current_period_end: Some("2026-04-30T00:00:00Z".to_string()),
                cancel_at_period_end: Some(false),
                processed_webhook_event_ids: vec![],
                updated_at: 1234567890,
            },
        )
        .unwrap();
    let root = repo_root();
    // Hand-build the app (instead of test_app()) so the pre-seeded billing
    // store above is the one the endpoint reads.
    let app = build_app(AppState {
        prometheus: handle,
        auth: AuthConfig {
            hs256_secret: Some(b"test_secret".to_vec()),
        },
        jobs: JobStore::default(),
        audit: AuditStore::default(),
        tenant_locks: TenantLocks::default(),
        config_locks: ConfigLocks::default(),
        http: reqwest::Client::new(),
        placement: PlacementStore::new(temp_placement_file()),
        billing,
        billing_provider: Arc::new(crate::billing::MockProvider),
        billing_enforcement_enabled: true,
        config: ConfigRegistry::new(
            Some(Arc::new(FileSource::new(
                root.join("config/routing/dev.json"),
            ))),
            Some(Arc::new(FixedSource::new(b"{}".to_vec()))),
        ),
        fleet_services: vec![],
        swarm: SwarmStore::new(repo_root().join("swarm/dev.json")),
        docs: None,
    });
    let res = app
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(v.get("configured").unwrap(), &serde_json::json!(true));
    assert_eq!(v.get("plan").unwrap(), &serde_json::json!("pro"));
    // Pro plan entitlement: max_deployments = 10.
    assert_eq!(
        v.get("entitlements")
            .unwrap()
            .get("max_deployments")
            .unwrap(),
        &serde_json::json!(10)
    );
    // Best-effort cleanup of the temp billing file.
    let _ = std::fs::remove_file(billing_path);
}
/// POSTing a checkout request against the default (mock) provider returns
/// the mock Stripe checkout URL for the tenant.
#[tokio::test]
async fn checkout_returns_mock_url() {
    let tenant_id = Uuid::new_v4();
    let auth = make_token(&["control:write"]);
    let payload = serde_json::json!({
        "plan": "pro",
        "return_path": "/custom-return"
    });
    let request = Request::builder()
        .uri(format!("/admin/v1/tenants/{tenant_id}/billing/checkout"))
        .method("POST")
        .header(header::AUTHORIZATION, format!("Bearer {auth}"))
        .header("x-tenant-id", tenant_id.to_string())
        .header(header::CONTENT_TYPE, "application/json")
        .body(Body::from(payload.to_string()))
        .unwrap();
    let response = test_app().oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let raw = axum::body::to_bytes(response.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&raw).unwrap();
    assert_eq!(
        parsed.get("url").unwrap(),
        &serde_json::json!(format!("https://mock.stripe.com/checkout/{}", tenant_id))
    );
}
/// Checkout is rejected with 409 CONFLICT when the tenant already holds an
/// active subscription.
#[tokio::test]
async fn checkout_fails_if_already_active() {
    let token = make_token(&["control:write"]);
    let tenant_id = Uuid::new_v4();
    // Setup app with active subscription
    // (unique temp file so parallel runs don't share billing state).
    let billing_path =
        std::env::temp_dir().join(format!("billing-test-active-{}.json", Uuid::new_v4()));
    let billing = crate::billing::BillingStore::new(billing_path.clone());
    billing
        .update_tenant_state(
            tenant_id,
            crate::billing::TenantBillingState {
                provider: "mock".to_string(),
                provider_customer_id: None,
                provider_subscription_id: None,
                provider_checkout_session_id: None,
                status: Some(crate::billing::SubscriptionStatus::Active),
                plan: Some(crate::billing::Plan::Pro),
                current_period_end: None,
                cancel_at_period_end: None,
                processed_webhook_event_ids: vec![],
                updated_at: 0,
            },
        )
        .unwrap();
    let handle = get_test_prometheus_handle();
    let root = repo_root();
    // Hand-build the app so the pre-seeded billing store above is used.
    let app = build_app(AppState {
        prometheus: handle,
        auth: AuthConfig {
            hs256_secret: Some(b"test_secret".to_vec()),
        },
        jobs: JobStore::default(),
        audit: AuditStore::default(),
        tenant_locks: TenantLocks::default(),
        config_locks: ConfigLocks::default(),
        http: reqwest::Client::new(),
        placement: PlacementStore::new(temp_placement_file()),
        billing,
        billing_provider: Arc::new(crate::billing::MockProvider),
        billing_enforcement_enabled: true,
        config: ConfigRegistry::new(
            Some(Arc::new(FileSource::new(
                root.join("config/routing/dev.json"),
            ))),
            Some(Arc::new(FixedSource::new(b"{}".to_vec()))),
        ),
        fleet_services: vec![],
        swarm: SwarmStore::new(repo_root().join("swarm/dev.json")),
        docs: None,
    });
    let body = serde_json::json!({ "plan": "pro" });
    let res = app
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing/checkout"))
                .method("POST")
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(body.to_string()))
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::CONFLICT);
    // Best-effort cleanup of the temp billing file.
    let _ = std::fs::remove_file(billing_path);
}
/// POSTing a portal request against the default (mock) provider returns the
/// mock Stripe portal URL for the tenant.
#[tokio::test]
async fn portal_returns_mock_url() {
    let tenant_id = Uuid::new_v4();
    let auth = make_token(&["control:write"]);
    let request = Request::builder()
        .uri(format!("/admin/v1/tenants/{tenant_id}/billing/portal"))
        .method("POST")
        .header(header::AUTHORIZATION, format!("Bearer {auth}"))
        .header("x-tenant-id", tenant_id.to_string())
        .body(Body::empty())
        .unwrap();
    let response = test_app().oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
    let raw = axum::body::to_bytes(response.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&raw).unwrap();
    assert_eq!(
        parsed.get("url").unwrap(),
        &serde_json::json!(format!("https://mock.stripe.com/portal/{}", tenant_id))
    );
}
/// A provider webhook updates the tenant's billing state, and re-delivering
/// the same event id is accepted without error (idempotent processing).
#[tokio::test]
async fn webhook_updates_state_idempotently() {
    let tenant_id = Uuid::new_v4();
    let event_id = "evt_123".to_string();
    let app = test_app();
    let event = crate::billing::BillingEvent::SubscriptionCreated {
        tenant_id,
        event_id: event_id.clone(),
        provider_customer_id: "cus_123".to_string(),
        provider_subscription_id: "sub_123".to_string(),
        status: crate::billing::SubscriptionStatus::Active,
        plan: crate::billing::Plan::Pro,
        current_period_end: "2026-04-30T00:00:00Z".to_string(),
        ts_ms: 1000,
    };
    let body = serde_json::to_string(&event).unwrap();
    // 1. Send webhook
    // (no Authorization header — the webhook route sits outside admin auth).
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri("/admin/v1/billing/webhooks/mock")
                .method("POST")
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(body.clone()))
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    // 2. Verify state
    let token = make_token(&["control:read"]);
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    let body_bytes = axum::body::to_bytes(res.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
    assert_eq!(v.get("configured").unwrap(), &serde_json::json!(true));
    assert_eq!(v.get("plan").unwrap(), &serde_json::json!("pro"));
    // 3. Send same webhook again (idempotency)
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri("/admin/v1/billing/webhooks/mock")
                .method("POST")
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(body))
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
}
/// Webhook events carry a timestamp; an event older than the last applied
/// one must not overwrite newer billing state (out-of-order delivery).
#[tokio::test]
async fn webhook_ignores_stale_events() {
    let tenant_id = Uuid::new_v4();
    let app = test_app();
    // 1. Send recent event (ts=2000)
    let event1 = crate::billing::BillingEvent::SubscriptionUpdated {
        tenant_id,
        event_id: "evt_new".to_string(),
        status: crate::billing::SubscriptionStatus::Active,
        plan: crate::billing::Plan::Enterprise,
        current_period_end: "2026-05-30T00:00:00Z".to_string(),
        cancel_at_period_end: false,
        ts_ms: 2000,
    };
    app.clone()
        .oneshot(
            Request::builder()
                .uri("/admin/v1/billing/webhooks/mock")
                .method("POST")
                .body(Body::from(serde_json::to_string(&event1).unwrap()))
                .unwrap(),
        )
        .await
        .unwrap();
    // 2. Send stale event (ts=1000)
    // — different event_id, so idempotency alone would not reject it.
    let event2 = crate::billing::BillingEvent::SubscriptionUpdated {
        tenant_id,
        event_id: "evt_old".to_string(),
        status: crate::billing::SubscriptionStatus::PastDue,
        plan: crate::billing::Plan::Pro,
        current_period_end: "2026-04-30T00:00:00Z".to_string(),
        cancel_at_period_end: false,
        ts_ms: 1000,
    };
    app.clone()
        .oneshot(
            Request::builder()
                .uri("/admin/v1/billing/webhooks/mock")
                .method("POST")
                .body(Body::from(serde_json::to_string(&event2).unwrap()))
                .unwrap(),
        )
        .await
        .unwrap();
    // 3. Verify state is still Enterprise
    let token = make_token(&["control:read"]);
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    let body_bytes = axum::body::to_bytes(res.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
    assert_eq!(v.get("plan").unwrap(), &serde_json::json!("enterprise"));
}
/// Document storage is gated on plan entitlements: free tenants get 402,
/// and after an upgrade the entitlement check passes (the request then hits
/// the docs backend, which is absent in tests).
#[tokio::test]
async fn s3_docs_requires_pro_plan() {
    let token = make_token(&["control:read", "control:write"]);
    let tenant_id = Uuid::new_v4();
    let app = test_app();
    // 1. Try to list docs (Free plan by default)
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/docs"))
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::PAYMENT_REQUIRED);
    // 2. Update to Pro plan via webhook
    let event = crate::billing::BillingEvent::SubscriptionCreated {
        tenant_id,
        event_id: "evt_pro".to_string(),
        provider_customer_id: "cus_pro".to_string(),
        provider_subscription_id: "sub_pro".to_string(),
        status: crate::billing::SubscriptionStatus::Active,
        plan: crate::billing::Plan::Pro,
        current_period_end: "2099-01-01T00:00:00Z".to_string(),
        ts_ms: 2000,
    };
    app.clone()
        .oneshot(
            Request::builder()
                .uri("/admin/v1/billing/webhooks/mock")
                .method("POST")
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(serde_json::to_string(&event).unwrap()))
                .unwrap(),
        )
        .await
        .unwrap();
    // 3. Try to list docs again (Should fail with 503 if S3 not configured in tests, or 200/502 if it is)
    // In test_app(), docs is None by default.
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/docs"))
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    // Since docs is None in test_app(), it returns SERVICE_UNAVAILABLE (503) AFTER passing the entitlement check.
    // If it was still PAYMENT_REQUIRED, it would return 402.
    assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
}
}

View File

@@ -1,6 +1,8 @@
use clap::Parser;
use metrics_exporter_prometheus::PrometheusBuilder;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tracing_subscriber::EnvFilter;
#[derive(Parser, Debug)]
@@ -33,16 +35,32 @@ async fn main() {
.build()
.expect("failed to build http client");
let placement_path = std::env::var("CONTROL_PLACEMENT_PATH")
let placement_path: PathBuf = std::env::var("CONTROL_PLACEMENT_PATH")
.ok()
.unwrap_or_else(|| "placement/dev.json".to_string())
.unwrap_or_else(|| "config/placement/dev.json".to_string())
.into();
let swarm_path = std::env::var("CONTROL_SWARM_STATE_PATH")
let billing_path: PathBuf = std::env::var("CONTROL_BILLING_STATE_PATH")
.ok()
.unwrap_or_else(|| "swarm/dev.json".to_string())
.unwrap_or_else(|| "billing/dev.json".to_string())
.into();
let routing_path: PathBuf = std::env::var("CONTROL_ROUTING_PATH")
.ok()
.unwrap_or_else(|| "config/routing/dev.json".to_string())
.into();
let swarm_mode = std::env::var("CONTROL_SWARM_MODE").ok();
let swarm = if swarm_mode.as_deref() == Some("docker") {
api::SwarmStore::new_docker_cli()
} else {
let swarm_path: PathBuf = std::env::var("CONTROL_SWARM_STATE_PATH")
.ok()
.unwrap_or_else(|| "swarm/dev.json".to_string())
.into();
api::SwarmStore::new(swarm_path)
};
let self_url = std::env::var("CONTROL_SELF_URL")
.ok()
.unwrap_or_else(|| "http://127.0.0.1:8080".to_string());
@@ -55,7 +73,70 @@ async fn main() {
fleet_services.extend(parse_fleet_services(&spec));
}
let app = api::build_app(api::AppState {
let docs_cfg =
api::s3_docs::DocsConfig::from_env().expect("missing S3 document storage configuration");
let docs = api::s3_docs::DocsStore::new(docs_cfg)
.await
.expect("failed to initialize S3 document storage client");
let config = {
let routing = if let (Ok(nats_url), Ok(bucket), Ok(key)) = (
std::env::var("CONTROL_ROUTING_NATS_URL"),
std::env::var("CONTROL_ROUTING_NATS_BUCKET"),
std::env::var("CONTROL_ROUTING_NATS_KEY"),
) {
Some(Arc::new(
api::config_registry::NatsKvSource::connect(nats_url, bucket, key)
.await
.expect("failed to connect to routing config nats kv"),
) as Arc<dyn api::config_registry::ConfigSource>)
} else {
Some(
Arc::new(api::config_registry::FileSource::new(routing_path))
as Arc<dyn api::config_registry::ConfigSource>,
)
};
let placement = if let (Ok(nats_url), Ok(bucket), Ok(key)) = (
std::env::var("CONTROL_PLACEMENT_NATS_URL"),
std::env::var("CONTROL_PLACEMENT_NATS_BUCKET"),
std::env::var("CONTROL_PLACEMENT_NATS_KEY"),
) {
Some(Arc::new(
api::config_registry::NatsKvSource::connect(nats_url, bucket, key)
.await
.expect("failed to connect to placement config nats kv"),
) as Arc<dyn api::config_registry::ConfigSource>)
} else {
Some(Arc::new(api::config_registry::FileSource::new(
placement_path.clone(),
))
as Arc<dyn api::config_registry::ConfigSource>)
};
api::ConfigRegistry::new(routing, placement)
};
let billing_provider: Arc<dyn api::billing::BillingProvider> =
match std::env::var("CONTROL_BILLING_PROVIDER").as_deref() {
Ok("stripe") => {
let secret_key = std::env::var("CONTROL_STRIPE_SECRET_KEY")
.expect("CONTROL_STRIPE_SECRET_KEY required for stripe provider");
let price_pro = std::env::var("CONTROL_STRIPE_PRICE_ID_PRO")
.expect("CONTROL_STRIPE_PRICE_ID_PRO required for stripe provider");
let price_enterprise = std::env::var("CONTROL_STRIPE_PRICE_ID_ENTERPRISE")
.expect("CONTROL_STRIPE_PRICE_ID_ENTERPRISE required for stripe provider");
Arc::new(api::billing::StripeProvider {
secret_key,
price_pro,
price_enterprise,
})
}
_ => Arc::new(api::billing::MockProvider),
};
let state = api::AppState {
prometheus: recorder,
auth: api::AuthConfig {
hs256_secret: std::env::var("CONTROL_GATEWAY_JWT_HS256_SECRET")
@@ -65,11 +146,25 @@ async fn main() {
jobs: api::JobStore::default(),
audit: api::AuditStore::default(),
tenant_locks: api::TenantLocks::default(),
config_locks: api::ConfigLocks::default(),
http,
placement: api::PlacementStore::new(placement_path),
billing: api::billing::BillingStore::new(billing_path),
billing_provider,
billing_enforcement_enabled: std::env::var("CONTROL_BILLING_ENFORCEMENT_ENABLED")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(false),
config,
fleet_services,
swarm: api::SwarmStore::new(swarm_path),
});
swarm,
docs: Some(docs),
};
// Spawn reconciliation loop
tokio::spawn(api::billing::run_reconciliation_loop(state.clone()));
let app = api::build_app(state);
let listener = tokio::net::TcpListener::bind(args.addr)
.await

View File

@@ -157,6 +157,7 @@ impl PlacementStore {
&self,
tenant_id: Uuid,
runner_target: String,
max_runners: usize,
) -> Result<String, String> {
let mut inner = self.inner.write().expect("placement lock poisoned");
inner.reload_if_changed();
@@ -178,8 +179,17 @@ impl PlacementStore {
.iter_mut()
.find(|p| p.tenant_id == tenant_id)
{
// If already at or above limit, and we are adding a NEW target (not replacing), it would fail.
// But here update_runner_target REPLACES the target list with a single target for now.
// If in the future we want to append, we check targets.len().
if 1 > max_runners {
return Err(format!("exceeds max_runners limit of {}", max_runners));
}
existing.targets = vec![runner_target];
} else {
if 1 > max_runners {
return Err(format!("exceeds max_runners limit of {}", max_runners));
}
runner.placements.push(TenantPlacement {
tenant_id,
targets: vec![runner_target],

508
control/api/src/s3_docs.rs Normal file
View File

@@ -0,0 +1,508 @@
use aws_config::Region;
use aws_credential_types::Credentials;
use aws_sdk_s3::presigning::PresigningConfig;
use aws_sdk_s3::types::BucketCannedAcl;
use aws_sdk_s3::{Client, config::Builder as S3ConfigBuilder};
use sha2::Digest;
use std::time::Duration;
/// Connection and layout settings for the S3-backed document store.
///
/// Populated from environment variables via [`DocsConfig::from_env`].
#[derive(Clone, Debug)]
pub struct DocsConfig {
    /// S3 API endpoint used for server-side requests (e.g. `http://minio:9000`).
    pub endpoint: String,
    /// Optional externally reachable endpoint used when presigning URLs for clients.
    pub public_endpoint: Option<String>,
    pub region: String,
    pub access_key_id: String,
    pub secret_access_key: String,
    /// Use path-style bucket addressing (default `true`; common for MinIO-style setups).
    pub force_path_style: bool,
    /// Allow plain-HTTP endpoints for local development; rejected for `https://` endpoints.
    pub insecure: bool,
    /// One or more bucket names; tenants are sharded across them deterministically.
    pub buckets: Vec<String>,
    /// Key prefix for all stored objects; normalized to always end with `/`.
    pub prefix: String,
}
impl DocsConfig {
    /// Build the configuration from environment variables.
    ///
    /// Each setting is read from a `CONTROL_S3_*` variable, falling back to the
    /// corresponding un-prefixed `S3_*` variable. Credentials may alternatively
    /// be supplied via `*_FILE` variables pointing at secret files.
    ///
    /// # Errors
    /// Returns a descriptive message when a required setting is missing, a
    /// secret file cannot be read, or `CONTROL_S3_INSECURE=true` is combined
    /// with an `https://` endpoint.
    pub fn from_env() -> Result<Self, String> {
        // Read an env var, treating unset or blank values as absent.
        fn get(name: &str) -> Option<String> {
            std::env::var(name)
                .ok()
                .map(|s| s.trim().to_string())
                .filter(|s| !s.is_empty())
        }
        // Prefer a secret file (path in `file_name`) over the plain variable.
        fn get_secret(name: &str, file_name: &str) -> Result<Option<String>, String> {
            if let Some(path) = get(file_name) {
                let raw = std::fs::read_to_string(path).map_err(|e| e.to_string())?;
                let v = raw.trim().to_string();
                if v.is_empty() {
                    return Ok(None);
                }
                return Ok(Some(v));
            }
            Ok(get(name))
        }
        let endpoint = get("CONTROL_S3_ENDPOINT")
            .or_else(|| get("S3_ENDPOINT"))
            .ok_or_else(|| "Missing CONTROL_S3_ENDPOINT".to_string())?;
        let public_endpoint =
            get("CONTROL_S3_PUBLIC_ENDPOINT").or_else(|| get("S3_PUBLIC_ENDPOINT"));
        let region = get("CONTROL_S3_REGION")
            .or_else(|| get("S3_REGION"))
            .unwrap_or_else(|| "us-east-1".to_string());
        // NOTE: file-read errors from the un-prefixed fallback lookups are
        // swallowed (`.ok().flatten()`); only the CONTROL_-prefixed secret
        // file can fail this function hard.
        let access_key_id =
            get_secret("CONTROL_S3_ACCESS_KEY_ID", "CONTROL_S3_ACCESS_KEY_ID_FILE")?
                .or_else(|| {
                    get_secret("S3_ACCESS_KEY_ID", "S3_ACCESS_KEY_ID_FILE")
                        .ok()
                        .flatten()
                })
                .ok_or_else(|| "Missing CONTROL_S3_ACCESS_KEY_ID".to_string())?;
        let secret_access_key = get_secret(
            "CONTROL_S3_SECRET_ACCESS_KEY",
            "CONTROL_S3_SECRET_ACCESS_KEY_FILE",
        )?
        .or_else(|| {
            get_secret("S3_SECRET_ACCESS_KEY", "S3_SECRET_ACCESS_KEY_FILE")
                .ok()
                .flatten()
        })
        .ok_or_else(|| "Missing CONTROL_S3_SECRET_ACCESS_KEY".to_string())?;
        // Boolean settings accept "true" or "1"; anything else reads as false.
        let force_path_style = get("CONTROL_S3_FORCE_PATH_STYLE")
            .or_else(|| get("S3_FORCE_PATH_STYLE"))
            .as_deref()
            .map(|v| v == "true" || v == "1")
            .unwrap_or(true);
        let insecure = get("CONTROL_S3_INSECURE")
            .or_else(|| get("S3_INSECURE"))
            .as_deref()
            .map(|v| v == "true" || v == "1")
            .unwrap_or(false);
        // Bucket list is comma-separated; blank entries are dropped.
        let bucket_raw = get("CONTROL_S3_BUCKET_DOCS")
            .or_else(|| get("S3_BUCKET_DOCS"))
            .ok_or_else(|| "Missing CONTROL_S3_BUCKET_DOCS".to_string())?;
        let buckets: Vec<String> = bucket_raw
            .split(',')
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty())
            .collect();
        if buckets.is_empty() {
            return Err("Missing CONTROL_S3_BUCKET_DOCS".to_string());
        }
        let prefix = get("CONTROL_S3_PREFIX_DOCS")
            .or_else(|| get("S3_PREFIX_DOCS"))
            .unwrap_or_else(|| "docs/".to_string());
        // Normalize the prefix so it always ends with a slash.
        let prefix = if prefix.ends_with('/') {
            prefix
        } else {
            format!("{prefix}/")
        };
        // SECURITY: `*_INSECURE=true` is intended for local MinIO setups that use plain HTTP.
        // We currently do not disable TLS certificate verification for HTTPS endpoints.
        if insecure && endpoint.trim_start().starts_with("https://") {
            return Err(
                "CONTROL_S3_INSECURE=true is not supported with https:// endpoints (TLS verification is not disabled). Use http:// for local MinIO, or set CONTROL_S3_INSECURE=false for production."
                    .to_string(),
            );
        }
        Ok(Self {
            endpoint,
            public_endpoint,
            region,
            access_key_id,
            secret_access_key,
            force_path_style,
            insecure,
            buckets,
            prefix,
        })
    }
}
/// Tenant-aware S3 document store.
///
/// Holds two SDK clients: `client` for server-side calls against the internal
/// endpoint, and `presign_client` configured with the public endpoint (when
/// set) so that presigned URLs handed to browsers are reachable externally.
#[derive(Clone)]
pub struct DocsStore {
    cfg: DocsConfig,
    client: Client,
    presign_client: Client,
}
impl DocsStore {
    /// Build the store from `cfg`, creating the server-side client and the
    /// presigning client (which targets `public_endpoint` when provided).
    ///
    /// # Errors
    /// Currently infallible in practice, but returns `Result` to keep the
    /// signature stable if client construction grows failure modes.
    pub async fn new(cfg: DocsConfig) -> Result<Self, String> {
        // Static credentials taken from config; no credential chain lookup.
        let creds = Credentials::new(
            cfg.access_key_id.clone(),
            cfg.secret_access_key.clone(),
            None,
            None,
            "static",
        );
        let shared = aws_config::from_env()
            .region(Region::new(cfg.region.clone()))
            .credentials_provider(creds.clone())
            .endpoint_url(cfg.endpoint.clone())
            .load()
            .await;
        let s3_conf = S3ConfigBuilder::from(&shared)
            .force_path_style(cfg.force_path_style)
            .build();
        let client = Client::from_conf(s3_conf);
        // Presigned URLs must embed the externally visible endpoint, if any.
        let presign_endpoint = cfg
            .public_endpoint
            .clone()
            .unwrap_or_else(|| cfg.endpoint.clone());
        let presign_shared = aws_config::from_env()
            .region(Region::new(cfg.region.clone()))
            .credentials_provider(creds)
            .endpoint_url(presign_endpoint)
            .load()
            .await;
        let presign_conf = S3ConfigBuilder::from(&presign_shared)
            .force_path_style(cfg.force_path_style)
            .build();
        let presign_client = Client::from_conf(presign_conf);
        Ok(Self {
            cfg,
            client,
            presign_client,
        })
    }

    /// Build the object key `<prefix><tenant_id>/<doc_type>/<doc_id>/<filename>`,
    /// validating every segment against separators and `..` sequences.
    ///
    /// # Errors
    /// Returns the validation message for the first offending segment.
    pub fn key_for(
        &self,
        tenant_id: &str,
        doc_type: &str,
        doc_id: &str,
        filename: &str,
    ) -> Result<String, String> {
        validate_segment("tenant_id", tenant_id)?;
        validate_segment("doc_type", doc_type)?;
        validate_segment("doc_id", doc_id)?;
        validate_filename(filename)?;
        Ok(format!(
            "{}{}/{}/{}/{}",
            self.cfg.prefix, tenant_id, doc_type, doc_id, filename
        ))
    }

    /// Configured key prefix (always `/`-terminated).
    pub fn prefix(&self) -> &str {
        self.cfg.prefix.as_str()
    }

    /// All configured bucket names.
    pub fn buckets(&self) -> &[String] {
        self.cfg.buckets.as_slice()
    }

    /// Pick the bucket for a tenant: SHA-256 of the tenant id, first 8 bytes
    /// as a big-endian u64, modulo the bucket count.
    fn bucket_for_tenant(&self, tenant_id: &str) -> &str {
        // Deterministic sharding across buckets. Note: if the bucket list changes, the mapping changes.
        // For production, set the full planned bucket set up-front (e.g. `-0,-1,-2`) to keep mapping stable.
        let n = self.cfg.buckets.len();
        if n == 1 {
            return self.cfg.buckets[0].as_str();
        }
        let mut hasher = sha2::Sha256::new();
        hasher.update(tenant_id.as_bytes());
        let digest = hasher.finalize();
        let mut b = [0u8; 8];
        b.copy_from_slice(&digest[..8]);
        let v = u64::from_be_bytes(b);
        let idx = (v as usize) % n;
        self.cfg.buckets[idx].as_str()
    }

    /// Lowercase hex SHA-256 digest of `bytes`, used as a content hash.
    pub fn content_hash_sha256_hex(bytes: &[u8]) -> String {
        let mut hasher = sha2::Sha256::new();
        hasher.update(bytes);
        let digest = hasher.finalize();
        let mut out = String::with_capacity(digest.len() * 2);
        for b in digest {
            use std::fmt::Write;
            let _ = write!(&mut out, "{:02x}", b);
        }
        out
    }

    /// Upload `bytes` to the tenant's bucket under `key`, with an optional
    /// `Content-Type`. Errors are stringified SDK errors.
    pub async fn put_for_tenant(
        &self,
        tenant_id: &str,
        key: &str,
        bytes: Vec<u8>,
        content_type: Option<String>,
    ) -> Result<(), String> {
        let mut req = self
            .client
            .put_object()
            .bucket(self.bucket_for_tenant(tenant_id))
            .key(key)
            .body(aws_sdk_s3::primitives::ByteStream::from(bytes));
        if let Some(ct) = content_type {
            req = req.content_type(ct);
        }
        req.send().await.map_err(|e| e.to_string())?;
        Ok(())
    }

    /// Download the object at `key` from the tenant's bucket, returning the
    /// body bytes and the stored content type (if any).
    pub async fn get_bytes_for_tenant(
        &self,
        tenant_id: &str,
        key: &str,
    ) -> Result<(Vec<u8>, Option<String>), String> {
        let out = self
            .client
            .get_object()
            .bucket(self.bucket_for_tenant(tenant_id))
            .key(key)
            .send()
            .await
            .map_err(|e| e.to_string())?;
        let ct = out.content_type().map(|s| s.to_string());
        // Collect the streaming body fully into memory.
        let bytes = out
            .body
            .collect()
            .await
            .map_err(|e| e.to_string())?
            .into_bytes()
            .to_vec();
        Ok((bytes, ct))
    }

    /// Delete the object at `key` from the tenant's bucket.
    pub async fn delete_for_tenant(&self, tenant_id: &str, key: &str) -> Result<(), String> {
        self.client
            .delete_object()
            .bucket(self.bucket_for_tenant(tenant_id))
            .key(key)
            .send()
            .await
            .map_err(|e| e.to_string())?;
        Ok(())
    }

    /// List objects under `prefix` in the tenant's bucket.
    ///
    /// NOTE(review): uses a single `list_objects_v2` call without pagination,
    /// so results beyond the first page are not returned — confirm acceptable.
    pub async fn list_for_tenant(
        &self,
        tenant_id: &str,
        prefix: &str,
    ) -> Result<Vec<DocObject>, String> {
        let out = self
            .client
            .list_objects_v2()
            .bucket(self.bucket_for_tenant(tenant_id))
            .prefix(prefix)
            .send()
            .await
            .map_err(|e| e.to_string())?;
        let mut items = Vec::new();
        for o in out.contents() {
            if let Some(key) = o.key() {
                items.push(DocObject {
                    key: key.to_string(),
                    size: o.size().unwrap_or(0),
                    last_modified: o.last_modified().map(|d| d.to_string()),
                });
            }
        }
        Ok(items)
    }

    /// Create any configured buckets that do not already exist (private ACL).
    /// A successful `head_bucket` means the bucket is present and is skipped.
    pub async fn ensure_buckets_exist(&self) -> Result<(), String> {
        for bucket in &self.cfg.buckets {
            let head = self.client.head_bucket().bucket(bucket).send().await;
            if head.is_ok() {
                continue;
            }
            self.client
                .create_bucket()
                .bucket(bucket)
                .acl(BucketCannedAcl::Private)
                .send()
                .await
                .map_err(|e| e.to_string())?;
        }
        Ok(())
    }

    /// Presign a PUT URL for `key` in the tenant's bucket, valid for
    /// `expires`. If `content_type` is set, the upload must match it.
    pub async fn presign_put_for_tenant(
        &self,
        tenant_id: &str,
        key: &str,
        content_type: Option<String>,
        expires: Duration,
    ) -> Result<String, String> {
        let mut req = self
            .presign_client
            .put_object()
            .bucket(self.bucket_for_tenant(tenant_id))
            .key(key);
        if let Some(ct) = content_type {
            req = req.content_type(ct);
        }
        let presigned = req
            .presigned(PresigningConfig::expires_in(expires).map_err(|e| e.to_string())?)
            .await
            .map_err(|e| e.to_string())?;
        Ok(presigned.uri().to_string())
    }

    /// Presign a GET URL for `key` in the tenant's bucket, valid for `expires`.
    pub async fn presign_get_for_tenant(
        &self,
        tenant_id: &str,
        key: &str,
        expires: Duration,
    ) -> Result<String, String> {
        let req = self
            .presign_client
            .get_object()
            .bucket(self.bucket_for_tenant(tenant_id))
            .key(key);
        let presigned = req
            .presigned(PresigningConfig::expires_in(expires).map_err(|e| e.to_string())?)
            .await
            .map_err(|e| e.to_string())?;
        Ok(presigned.uri().to_string())
    }
}
/// A stored document as reported by a bucket listing.
#[derive(Clone, Debug, serde::Serialize)]
pub struct DocObject {
    /// Full object key (includes the configured prefix).
    pub key: String,
    /// Object size in bytes as reported by S3 (0 when unknown).
    pub size: i64,
    /// Last-modified timestamp, stringified from the SDK's DateTime, if present.
    pub last_modified: Option<String>,
}
/// Validate a single S3 key segment: it must be non-empty, at most 128 bytes,
/// and free of path separators (`/`, `\`) and `..` sequences.
///
/// The error message names the offending field via `name`.
fn validate_segment(name: &str, value: &str) -> Result<(), String> {
    if value.is_empty() {
        return Err(format!("{name} is required"));
    }
    if value.len() > 128 {
        return Err(format!("{name} too long"));
    }
    // Both separator characters and ".." get the same rejection message.
    let has_separator = value.contains('/') || value.contains('\\');
    let has_dot_dot = value.contains("..");
    if has_separator || has_dot_dot {
        return Err(format!("{name} contains invalid characters"));
    }
    Ok(())
}
/// Validate a document filename using the same rules as any other key segment.
fn validate_filename(value: &str) -> Result<(), String> {
    validate_segment("filename", value)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Serialize the env-mutating tests below: they share process environment.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
        LOCK.get_or_init(|| std::sync::Mutex::new(()))
            .lock()
            .unwrap()
    }

    // Happy path: all CONTROL_S3_* vars set -> parsed config matches.
    #[test]
    fn config_from_env_parses_expected_fields() {
        let _guard = env_lock();
        // set_var/remove_var are unsafe in edition 2024; access is serialized
        // by env_lock above.
        unsafe {
            std::env::set_var("CONTROL_S3_ENDPOINT", "http://minio:9000");
            std::env::set_var("CONTROL_S3_REGION", "us-east-1");
            std::env::set_var("CONTROL_S3_ACCESS_KEY_ID", "minioadmin");
            std::env::set_var("CONTROL_S3_SECRET_ACCESS_KEY", "minioadmin");
            std::env::set_var("CONTROL_S3_BUCKET_DOCS", "cloudlysis-docs");
            std::env::set_var("CONTROL_S3_PREFIX_DOCS", "docs/");
            std::env::set_var("CONTROL_S3_FORCE_PATH_STYLE", "true");
            std::env::set_var("CONTROL_S3_INSECURE", "true");
        }
        let cfg = DocsConfig::from_env().unwrap();
        assert_eq!(cfg.endpoint, "http://minio:9000");
        assert_eq!(cfg.buckets, vec!["cloudlysis-docs".to_string()]);
        assert_eq!(cfg.prefix, "docs/");
        assert!(cfg.force_path_style);
        assert!(cfg.insecure);
        // Clean up so later tests see a pristine environment.
        unsafe {
            std::env::remove_var("CONTROL_S3_ENDPOINT");
            std::env::remove_var("CONTROL_S3_REGION");
            std::env::remove_var("CONTROL_S3_ACCESS_KEY_ID");
            std::env::remove_var("CONTROL_S3_SECRET_ACCESS_KEY");
            std::env::remove_var("CONTROL_S3_BUCKET_DOCS");
            std::env::remove_var("CONTROL_S3_PREFIX_DOCS");
            std::env::remove_var("CONTROL_S3_FORCE_PATH_STYLE");
            std::env::remove_var("CONTROL_S3_INSECURE");
        }
    }

    // `insecure` with an https endpoint must be rejected with a clear message.
    #[test]
    fn config_rejects_insecure_with_https_endpoint() {
        let _guard = env_lock();
        unsafe {
            std::env::set_var("CONTROL_S3_ENDPOINT", "https://s3.example.com");
            std::env::set_var("CONTROL_S3_ACCESS_KEY_ID", "a");
            std::env::set_var("CONTROL_S3_SECRET_ACCESS_KEY", "b");
            std::env::set_var(
                "CONTROL_S3_BUCKET_DOCS",
                "cloudlysis-docs-0,cloudlysis-docs-1",
            );
            std::env::set_var("CONTROL_S3_INSECURE", "true");
        }
        let err = DocsConfig::from_env().unwrap_err();
        assert!(
            err.contains("CONTROL_S3_INSECURE=true") && err.contains("https://"),
            "unexpected error: {err}"
        );
        unsafe {
            std::env::remove_var("CONTROL_S3_ENDPOINT");
            std::env::remove_var("CONTROL_S3_ACCESS_KEY_ID");
            std::env::remove_var("CONTROL_S3_SECRET_ACCESS_KEY");
            std::env::remove_var("CONTROL_S3_BUCKET_DOCS");
            std::env::remove_var("CONTROL_S3_INSECURE");
        }
    }

    // The key layout `<prefix><tenant>/<type>/<id>/<file>` must not change:
    // existing stored objects depend on it.
    #[tokio::test]
    async fn key_scheme_is_stable() {
        let cfg = DocsConfig {
            endpoint: "http://minio:9000".to_string(),
            public_endpoint: None,
            region: "us-east-1".to_string(),
            access_key_id: "x".to_string(),
            secret_access_key: "y".to_string(),
            force_path_style: true,
            insecure: true,
            buckets: vec![
                "cloudlysis-docs-0".to_string(),
                "cloudlysis-docs-1".to_string(),
            ],
            prefix: "docs/".to_string(),
        };
        let store = DocsStore::new(cfg).await.unwrap();
        let key = store
            .key_for("tenant-a", "deployments", "v1", "bundle.tar.gz")
            .unwrap();
        assert_eq!(key, "docs/tenant-a/deployments/v1/bundle.tar.gz");
    }

    // Separators and traversal sequences in any segment are rejected.
    #[tokio::test]
    async fn key_scheme_rejects_invalid_segments() {
        let cfg = DocsConfig {
            endpoint: "http://minio:9000".to_string(),
            public_endpoint: None,
            region: "us-east-1".to_string(),
            access_key_id: "x".to_string(),
            secret_access_key: "y".to_string(),
            force_path_style: true,
            insecure: true,
            buckets: vec!["cloudlysis-docs".to_string()],
            prefix: "docs/".to_string(),
        };
        let store = DocsStore::new(cfg).await.unwrap();
        assert!(store.key_for("t/a", "x", "y", "z").is_err());
        assert!(store.key_for("t", "x", "../y", "z").is_err());
        assert!(store.key_for("t", "x", "y", "a/b").is_err());
    }
}

View File

@@ -28,31 +28,49 @@ pub struct SwarmStateFile {
#[derive(Clone)]
pub struct SwarmStore {
path: std::path::PathBuf,
inner: SwarmStoreInner,
}
#[derive(Clone)]
enum SwarmStoreInner {
File { path: std::path::PathBuf },
DockerCli,
}
impl SwarmStore {
pub fn new(path: std::path::PathBuf) -> Self {
Self { path }
Self {
inner: SwarmStoreInner::File { path },
}
}
pub fn new_docker_cli() -> Self {
Self {
inner: SwarmStoreInner::DockerCli,
}
}
pub fn list_services(&self) -> Vec<SwarmService> {
self.load().map(|s| s.services).unwrap_or_default()
match &self.inner {
SwarmStoreInner::File { path } => {
load_state(path).map(|s| s.services).unwrap_or_default()
}
SwarmStoreInner::DockerCli => list_services_docker_cli().unwrap_or_default(),
}
}
pub fn list_tasks(&self, service_name: &str) -> Vec<SwarmTask> {
self.load()
.map(|s| {
s.tasks
.into_iter()
.filter(|t| t.service == service_name)
.collect()
})
.unwrap_or_default()
}
fn load(&self) -> Option<SwarmStateFile> {
load_state(&self.path)
match &self.inner {
SwarmStoreInner::File { path } => load_state(path)
.map(|s| {
s.tasks
.into_iter()
.filter(|t| t.service == service_name)
.collect()
})
.unwrap_or_default(),
SwarmStoreInner::DockerCli => list_tasks_docker_cli(service_name).unwrap_or_default(),
}
}
}
@@ -60,3 +78,120 @@ fn load_state(path: &Path) -> Option<SwarmStateFile> {
let raw = fs::read_to_string(path).ok()?;
serde_json::from_str(&raw).ok()
}
/// Shell out to `docker service ls` and parse its JSON-lines output into
/// `SwarmService` records.
///
/// # Errors
/// Returns a human-readable message when docker cannot be executed, exits
/// non-zero, or emits a row that is not valid JSON.
fn list_services_docker_cli() -> Result<Vec<SwarmService>, String> {
    #[derive(Deserialize)]
    struct ServiceRow {
        #[serde(rename = "Name")]
        name: String,
        #[serde(rename = "Image")]
        image: Option<String>,
        #[serde(rename = "Mode")]
        mode: Option<String>,
        #[serde(rename = "Replicas")]
        replicas: Option<String>,
        #[serde(rename = "UpdatedAt")]
        updated_at: Option<String>,
    }

    let output = std::process::Command::new("docker")
        .args(["service", "ls", "--format", "{{json .}}"])
        .output()
        .map_err(|e| format!("docker exec failed: {e}"))?;
    if !output.status.success() {
        return Err(format!(
            "docker service ls failed: {}",
            String::from_utf8_lossy(&output.stderr)
        ));
    }
    // One JSON object per line; blank lines are skipped. The first bad row
    // aborts the whole listing.
    let stdout = String::from_utf8_lossy(&output.stdout);
    stdout
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .map(|line| {
            let row: ServiceRow =
                serde_json::from_str(line).map_err(|e| format!("invalid json row: {e}"))?;
            Ok(SwarmService {
                name: row.name,
                image: row.image,
                mode: row.mode,
                replicas: row.replicas,
                updated_at: row.updated_at,
            })
        })
        .collect()
}
/// Shell out to `docker service ps <service>` and parse its JSON-lines output
/// into `SwarmTask` records.
///
/// The owning service name is recovered from the task name (the part before
/// the first `.`); when the task name is absent or has no dot, the queried
/// `service_name` is used as a fallback.
///
/// # Errors
/// Returns a human-readable message when docker cannot be executed, exits
/// non-zero, or emits a row that is not valid JSON.
fn list_tasks_docker_cli(service_name: &str) -> Result<Vec<SwarmTask>, String> {
    #[derive(Deserialize)]
    struct TaskRow {
        #[serde(rename = "ID")]
        id: String,
        #[serde(rename = "Name")]
        name: Option<String>,
        #[serde(rename = "Node")]
        node: Option<String>,
        #[serde(rename = "DesiredState")]
        desired_state: Option<String>,
        #[serde(rename = "CurrentState")]
        current_state: Option<String>,
        #[serde(rename = "Error")]
        error: Option<String>,
    }

    let output = std::process::Command::new("docker")
        .args([
            "service",
            "ps",
            service_name,
            "--no-trunc",
            "--format",
            "{{json .}}",
        ])
        .output()
        .map_err(|e| format!("docker exec failed: {e}"))?;
    if !output.status.success() {
        return Err(format!(
            "docker service ps failed: {}",
            String::from_utf8_lossy(&output.stderr)
        ));
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    let mut tasks = Vec::new();
    for raw_line in stdout.lines() {
        let line = raw_line.trim();
        if line.is_empty() {
            continue;
        }
        let row: TaskRow =
            serde_json::from_str(line).map_err(|e| format!("invalid json row: {e}"))?;
        // Task names look like "<service>.<slot>"; keep the service part.
        let service = match row.name.as_deref().and_then(|n| n.split_once('.')) {
            Some((svc, _)) => svc.to_string(),
            None => service_name.to_string(),
        };
        tasks.push(SwarmTask {
            id: row.id,
            service,
            node: row.node,
            desired_state: row.desired_state,
            current_state: row.current_state,
            error: row.error,
        });
    }
    Ok(tasks)
}
#[cfg(test)]
mod tests {
    use super::*;

    // The on-disk swarm state fixture format must deserialize via serde.
    #[test]
    fn state_file_parses() {
        let raw = r#"{"services":[{"name":"a","image":null,"mode":null,"replicas":null,"updated_at":null}],"tasks":[]}"#;
        let parsed: SwarmStateFile = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.services.len(), 1);
    }
}

View File

@@ -0,0 +1,174 @@
use api::{
AppState, AuditStore, AuthConfig, ConfigLocks, JobStore, PlacementStore, SwarmStore,
TenantLocks, billing::BillingStore, config_registry::ConfigRegistry,
};
use axum::{
Router,
body::Body,
http::{Request, StatusCode, header},
};
use jsonwebtoken::{EncodingKey, Header, encode};
use metrics_exporter_prometheus::PrometheusBuilder;
use serde::Serialize;
use std::{
path::PathBuf,
sync::{Arc, OnceLock},
};
use tower::ServiceExt;
use uuid::Uuid;
/// Production billing smoke tests are opt-in: they run only when
/// `CONTROL_TEST_BILLING_PROD=1` is set in the environment.
fn prod_enabled() -> bool {
    matches!(
        std::env::var("CONTROL_TEST_BILLING_PROD").as_deref(),
        Ok("1")
    )
}
static HANDLE: OnceLock<metrics_exporter_prometheus::PrometheusHandle> = OnceLock::new();
/// Resolve the repository root: two directory levels above this crate's
/// manifest directory (the api crate lives at `<root>/control/api`).
fn repo_root() -> PathBuf {
    let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    manifest_dir
        .ancestors()
        .nth(2)
        .expect("api crate should live under repo root")
        .to_path_buf()
}
/// Minimal JWT claim set matching what the gateway auth layer expects.
#[derive(Serialize)]
struct TestClaims {
    sub: String,
    session_id: String,
    permissions: Vec<String>,
    exp: usize,
}
/// Mint a short-lived (60s) HS256 bearer token carrying `perms` for use
/// against the test app.
fn make_token(secret: &[u8], perms: &[&str]) -> String {
    let now_secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs();
    let claims = TestClaims {
        sub: "user_1".to_string(),
        session_id: "sess_1".to_string(),
        permissions: perms.iter().map(|p| p.to_string()).collect(),
        exp: (now_secs + 60) as usize,
    };
    encode(&Header::default(), &claims, &EncodingKey::from_secret(secret)).unwrap()
}
/// Assemble a Router wired like production, but with a fixed HS256 test
/// secret, repo-local placement/swarm fixtures, and a temp-file billing
/// store. The billing provider honors `CONTROL_BILLING_PROVIDER` so the
/// smoke test can target real Stripe when configured.
fn test_app() -> Router {
    // The prometheus recorder is process-global; install it at most once.
    let handle = HANDLE
        .get_or_init(|| {
            PrometheusBuilder::new()
                .install_recorder()
                .expect("failed to install prometheus recorder")
        })
        .clone();
    let provider_type =
        std::env::var("CONTROL_BILLING_PROVIDER").unwrap_or_else(|_| "mock".to_string());
    // Stripe credentials default to empty strings here: the smoke test only
    // asserts the request reaches the provider, not that it succeeds.
    let billing_provider: Arc<dyn api::billing::BillingProvider> = match provider_type.as_str() {
        "stripe" => Arc::new(api::billing::StripeProvider {
            secret_key: std::env::var("CONTROL_STRIPE_SECRET_KEY").unwrap_or_default(),
            price_pro: std::env::var("CONTROL_STRIPE_PRICE_ID_PRO").unwrap_or_default(),
            price_enterprise: std::env::var("CONTROL_STRIPE_PRICE_ID_ENTERPRISE")
                .unwrap_or_default(),
        }),
        _ => Arc::new(api::billing::MockProvider),
    };
    api::build_app(AppState {
        prometheus: handle,
        auth: AuthConfig {
            hs256_secret: Some(b"test_secret".to_vec()),
        },
        jobs: JobStore::default(),
        audit: AuditStore::default(),
        tenant_locks: TenantLocks::default(),
        config_locks: ConfigLocks::default(),
        http: reqwest::Client::new(),
        placement: PlacementStore::new(repo_root().join("config/placement/dev.json")),
        billing: BillingStore::new(std::env::temp_dir().join("billing-prod-smoke.json")),
        billing_provider,
        billing_enforcement_enabled: true,
        config: ConfigRegistry::new(None, None),
        fleet_services: vec![],
        swarm: SwarmStore::new(repo_root().join("swarm/dev.json")),
        docs: None,
    })
}
/// End-to-end billing smoke test against a real provider. Opt-in via
/// `CONTROL_TEST_BILLING_PROD=1`; exercises read, checkout and portal
/// endpoints for a fresh tenant.
#[tokio::test]
async fn billing_production_smoke_test() {
    if !prod_enabled() {
        eprintln!("skipping: set CONTROL_TEST_BILLING_PROD=1 to enable production smoke tests");
        return;
    }
    let app = test_app();
    let token = make_token(b"test_secret", &["control:read", "control:write"]);
    let tenant_id = Uuid::new_v4();
    // 1. Verify GET billing works (empty initially)
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    // 2. Verify Checkout session generation
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing/checkout"))
                .method("POST")
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(
                    serde_json::json!({
                        "plan": "pro",
                        "return_path": "/billing"
                    })
                    .to_string(),
                ))
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    // A checkout response must carry a redirect URL.
    let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert!(v.get("url").and_then(|u| u.as_str()).is_some());
    // 3. Verify Portal session generation (may fail if tenant has no stripe customer id yet, which is expected for fresh tenant)
    let res = app
        .clone()
        .oneshot(
            Request::builder()
                .uri(format!("/admin/v1/tenants/{tenant_id}/billing/portal"))
                .method("POST")
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("x-tenant-id", tenant_id.to_string())
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    // For smoke test, we just want to see it reached the provider and didn't crash
    assert!(res.status() == StatusCode::OK || res.status() == StatusCode::INTERNAL_SERVER_ERROR);
}

View File

@@ -0,0 +1,250 @@
use api::{
AppState, AuditStore, AuthConfig, ConfigLocks, ConfigRegistry, JobStore, PlacementStore,
SwarmStore, TenantLocks, config_registry::NatsKvSource,
};
use axum::{
Router,
body::Body,
http::{Request, StatusCode, header},
};
use jsonwebtoken::{EncodingKey, Header, encode};
use metrics_exporter_prometheus::PrometheusBuilder;
use serde::Serialize;
use std::{path::PathBuf, sync::OnceLock, time::Duration};
use tower::ServiceExt;
use uuid::Uuid;
/// NATS-backed config tests are opt-in: they require `CONTROL_TEST_NATS=1`
/// plus a `CONTROL_TEST_NATS_URL` to be present in the environment.
fn enabled() -> bool {
    let flag_on = matches!(std::env::var("CONTROL_TEST_NATS").as_deref(), Ok("1"));
    flag_on && std::env::var("CONTROL_TEST_NATS_URL").is_ok()
}
/// Minimal JWT claim set matching what the gateway auth layer expects.
#[derive(Serialize)]
struct TestClaims {
    sub: String,
    session_id: String,
    permissions: Vec<String>,
    exp: usize,
}
/// Mint a short-lived (60s) HS256 bearer token carrying `perms` for use
/// against the test app.
fn make_token(secret: &[u8], perms: &[&str]) -> String {
    // Expiry: now + 60 seconds, expressed as a unix timestamp.
    let exp = (std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs()
        + 60) as usize;
    encode(
        &Header::default(),
        &TestClaims {
            sub: "user_1".to_string(),
            session_id: "sess_1".to_string(),
            permissions: perms.iter().map(|p| (*p).to_string()).collect(),
            exp,
        },
        &EncodingKey::from_secret(secret),
    )
    .unwrap()
}
static HANDLE: OnceLock<metrics_exporter_prometheus::PrometheusHandle> = OnceLock::new();
/// Resolve the repository root: two directory levels above this crate's
/// manifest directory (the api crate lives at `<root>/control/api`).
fn repo_root() -> PathBuf {
    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .parent()
        .and_then(|p| p.parent())
        .expect("api crate should live under repo root")
        .to_path_buf()
}
/// Poll the job endpoint until the job leaves `pending`/`running`, or until a
/// 2-second deadline elapses; returns the last observed job JSON either way
/// (callers assert on its `status`).
async fn wait_done(app: Router, job_id: Uuid, token: &str) -> serde_json::Value {
    let start = tokio::time::Instant::now();
    loop {
        let res = app
            .clone()
            .oneshot(
                Request::builder()
                    .uri(format!("/admin/v1/jobs/{job_id}"))
                    .header(header::AUTHORIZATION, format!("Bearer {token}"))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(res.status(), StatusCode::OK);
        let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
            .await
            .unwrap();
        let job: serde_json::Value = serde_json::from_slice(&body).unwrap();
        let status = job
            .get("status")
            .and_then(|v| v.as_str())
            .unwrap_or("unknown");
        // Any terminal (non-pending/running) status ends the wait.
        if status != "pending" && status != "running" {
            return job;
        }
        // Deadline: hand the last snapshot back and let the caller fail the assert.
        if start.elapsed() > Duration::from_secs(2) {
            return job;
        }
        tokio::time::sleep(Duration::from_millis(25)).await;
    }
}
/// Round-trip routing config through NATS KV via the jobs API: apply a new
/// routing document, read it back, then roll it back. Opt-in via
/// `CONTROL_TEST_NATS=1` plus `CONTROL_TEST_NATS_URL`.
#[tokio::test]
async fn config_jobs_with_nats_kv_are_env_gated() {
    if !enabled() {
        eprintln!(
            "skipping: set CONTROL_TEST_NATS=1 and CONTROL_TEST_NATS_URL=nats://... to enable nats config tests"
        );
        return;
    }
    let nats_url = std::env::var("CONTROL_TEST_NATS_URL").unwrap();
    // set_var is unsafe in edition 2024; this test runs in its own process
    // section gated by the env flag above.
    unsafe {
        std::env::set_var("CONTROL_CONFIG_NATS_URL", &nats_url);
    }
    // Unique bucket/keys per run so repeated executions don't collide.
    let bucket = format!("cloudlysis-test-config-{}", Uuid::new_v4());
    let routing_key = format!("routing/{}", Uuid::new_v4());
    let placement_key = format!("placement/{}", Uuid::new_v4());
    let routing_src = NatsKvSource::connect(nats_url.clone(), bucket.clone(), routing_key)
        .await
        .expect("connect routing kv");
    let placement_src = NatsKvSource::connect(nats_url.clone(), bucket.clone(), placement_key)
        .await
        .expect("connect placement kv");
    let config = ConfigRegistry::new(
        Some(std::sync::Arc::new(routing_src)),
        Some(std::sync::Arc::new(placement_src)),
    );
    let secret = b"test_secret".to_vec();
    let token = make_token(&secret, &["control:write", "control:read"]);
    // The prometheus recorder is process-global; install it at most once.
    let handle = HANDLE
        .get_or_init(|| {
            PrometheusBuilder::new()
                .install_recorder()
                .expect("failed to install prometheus recorder")
        })
        .clone();
    let app = api::build_app(AppState {
        prometheus: handle,
        auth: AuthConfig {
            hs256_secret: Some(secret),
        },
        jobs: JobStore::default(),
        audit: AuditStore::default(),
        tenant_locks: TenantLocks::default(),
        config_locks: ConfigLocks::default(),
        http: reqwest::Client::new(),
        placement: PlacementStore::new(repo_root().join("config/placement/dev.json")),
        billing: api::billing::BillingStore::new(std::env::temp_dir().join("billing-test.json")),
        billing_provider: std::sync::Arc::new(api::billing::MockProvider),
        billing_enforcement_enabled: false,
        config,
        fleet_services: vec![],
        swarm: SwarmStore::new(repo_root().join("swarm/dev.json")),
        docs: None,
    });
    // Minimal but schema-valid routing document for tenant "t1".
    let routing_value = serde_json::json!({
        "revision": 1,
        "aggregate_placement": { "t1": "local" },
        "projection_placement": { "t1": "local" },
        "runner_placement": { "t1": "local" },
        "aggregate_shards": { "local": ["http://aggregate:50051"] },
        "projection_shards": { "local": ["http://projection:8080"] },
        "runner_shards": { "local": ["http://runner:8080"] }
    });
    // Step 1: submit a config-apply job and wait for it to succeed.
    let apply = app
        .clone()
        .oneshot(
            Request::builder()
                .uri("/admin/v1/jobs/config/apply")
                .method("POST")
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("idempotency-key", format!("k-{}", Uuid::new_v4()))
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(
                    serde_json::json!({
                        "domain": "routing",
                        "expected_revision": null,
                        "reason": "test apply",
                        "value": routing_value
                    })
                    .to_string(),
                ))
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(apply.status(), StatusCode::OK);
    let body = axum::body::to_bytes(apply.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
    let job_id = Uuid::parse_str(v.get("job_id").unwrap().as_str().unwrap()).unwrap();
    let job = wait_done(app.clone(), job_id, &token).await;
    assert_eq!(
        job.get("status").and_then(|v| v.as_str()),
        Some("succeeded")
    );
    // Step 2: read the applied config back and check domain + revision bump.
    let get = app
        .clone()
        .oneshot(
            Request::builder()
                .uri("/admin/v1/config/routing")
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(get.status(), StatusCode::OK);
    let body = axum::body::to_bytes(get.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let got: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(got.get("domain").unwrap().as_str().unwrap(), "routing");
    assert!(got.get("revision").unwrap().as_u64().unwrap_or(0) > 0);
    // Step 3: submit a rollback job and wait for it to succeed.
    let rollback = app
        .clone()
        .oneshot(
            Request::builder()
                .uri("/admin/v1/jobs/config/rollback")
                .method("POST")
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .header("idempotency-key", format!("k-{}", Uuid::new_v4()))
                .header(header::CONTENT_TYPE, "application/json")
                .body(Body::from(
                    serde_json::json!({
                        "domain": "routing",
                        "reason": "test rollback"
                    })
                    .to_string(),
                ))
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(rollback.status(), StatusCode::OK);
    let body = axum::body::to_bytes(rollback.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
    let rb_id = Uuid::parse_str(v.get("job_id").unwrap().as_str().unwrap()).unwrap();
    let rb_job = wait_done(app.clone(), rb_id, &token).await;
    assert_eq!(
        rb_job.get("status").and_then(|v| v.as_str()),
        Some("succeeded")
    );
}

View File

@@ -0,0 +1,157 @@
use jsonwebtoken::{EncodingKey, Header, encode};
use reqwest::StatusCode;
use serde::Serialize;
use serde_json::json;
use std::time::Duration;
use uuid::Uuid;
/// Minimal JWT claim set matching what the gateway auth layer expects.
#[derive(Serialize)]
struct TestClaims {
    sub: String,
    session_id: String,
    permissions: Vec<String>,
    exp: usize,
}
/// Mint a short-lived (5 min) HS256 bearer token carrying `perms` for the
/// env-gated smoke test against a running deployment.
fn make_token(secret: &[u8], perms: &[&str]) -> String {
    // Expiry: now + 300 seconds — longer than the in-process tests because
    // this token is used against a remote service over the network.
    let exp = (std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs()
        + 300) as usize;
    encode(
        &Header::default(),
        &TestClaims {
            sub: "smoke".to_string(),
            session_id: "smoke".to_string(),
            permissions: perms.iter().map(|p| (*p).to_string()).collect(),
            exp,
        },
        &EncodingKey::from_secret(secret),
    )
    .unwrap()
}
/// Environment-gated smoke test against a *running* control API deployment:
/// health check, presigned upload, direct S3 PUT, listing, presigned
/// download, and byte-for-byte comparison of the roundtripped payload.
/// Skipped unless CONTROL_TEST_SMOKE=1; target comes from CONTROL_TEST_BASE_URL.
#[tokio::test]
async fn control_api_docs_smoke_is_env_gated() {
    let enabled = std::env::var("CONTROL_TEST_SMOKE").ok();
    if enabled.as_deref() != Some("1") {
        eprintln!("skipping: set CONTROL_TEST_SMOKE=1 to enable env smoke tests");
        return;
    }
    let base_url =
        std::env::var("CONTROL_TEST_BASE_URL").expect("CONTROL_TEST_BASE_URL is required");
    // Normalize so path concatenation below never produces "//".
    let base_url = base_url.trim_end_matches('/').to_string();
    // Either provide a token directly, or provide secret+perms to mint one.
    let token = if let Ok(t) = std::env::var("CONTROL_TEST_TOKEN") {
        t
    } else {
        let secret = std::env::var("CONTROL_TEST_JWT_SECRET")
            .expect("CONTROL_TEST_TOKEN or CONTROL_TEST_JWT_SECRET is required");
        make_token(secret.as_bytes(), &["control:read", "control:write"])
    };
    // A random tenant id keeps repeated smoke runs from colliding.
    let tenant_id = std::env::var("CONTROL_TEST_TENANT_ID")
        .ok()
        .unwrap_or_else(|| Uuid::new_v4().to_string());
    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(15))
        .build()
        .unwrap();
    // Health.
    let health = http
        .get(format!("{base_url}/health"))
        .send()
        .await
        .expect("health request failed");
    assert!(health.status().is_success(), "health not ok");
    // Presign upload.
    let doc_id = Uuid::new_v4().to_string();
    let filename = "smoke.txt";
    let presign_up = http
        .post(format!(
            "{base_url}/admin/v1/tenants/{tenant_id}/docs/presign/upload"
        ))
        .header("authorization", format!("Bearer {token}"))
        .header("x-tenant-id", &tenant_id)
        .json(&json!({
            "doc_type": "deployments",
            "doc_id": doc_id,
            "filename": filename,
            "content_type": "text/plain",
        }))
        .send()
        .await
        .expect("presign upload failed");
    assert!(
        presign_up.status().is_success(),
        "presign upload not ok: {}",
        presign_up.status()
    );
    let up_json: serde_json::Value = presign_up.json().await.unwrap();
    let put_url = up_json.get("url").and_then(|v| v.as_str()).unwrap();
    let key = up_json
        .get("key")
        .and_then(|v| v.as_str())
        .unwrap()
        .to_string();
    // PUT bytes to S3 directly (the presigned URL bypasses the control API).
    let payload = b"hello-smoke".to_vec();
    let put = http
        .put(put_url)
        .header("content-type", "text/plain")
        .body(payload.clone())
        .send()
        .await
        .expect("s3 put failed");
    assert!(put.status().is_success(), "s3 put not ok: {}", put.status());
    // List should include key.
    let list = http
        .get(format!(
            "{base_url}/admin/v1/tenants/{tenant_id}/docs?prefix=deployments/"
        ))
        .header("authorization", format!("Bearer {token}"))
        .header("x-tenant-id", &tenant_id)
        .send()
        .await
        .expect("list failed");
    assert!(list.status().is_success(), "list not ok");
    let list_json: serde_json::Value = list.json().await.unwrap();
    let objects = list_json.get("objects").and_then(|v| v.as_array()).unwrap();
    assert!(
        objects
            .iter()
            .any(|o| o.get("key").and_then(|k| k.as_str()) == Some(key.as_str())),
        "expected list to include presigned upload key"
    );
    // Presign download and fetch bytes.
    let presign_down = http
        .post(format!(
            "{base_url}/admin/v1/tenants/{tenant_id}/docs/presign/download"
        ))
        .header("authorization", format!("Bearer {token}"))
        .header("x-tenant-id", &tenant_id)
        .json(&json!({ "key": key }))
        .send()
        .await
        .expect("presign download failed");
    assert!(
        presign_down.status().is_success(),
        "presign download not ok"
    );
    let down_json: serde_json::Value = presign_down.json().await.unwrap();
    let get_url = down_json.get("url").and_then(|v| v.as_str()).unwrap();
    let got = http.get(get_url).send().await.expect("s3 get failed");
    assert_eq!(got.status(), StatusCode::OK);
    // The roundtripped bytes must match what was uploaded.
    let got_bytes = got.bytes().await.unwrap().to_vec();
    assert_eq!(got_bytes, payload);
}

View File

@@ -11,7 +11,7 @@ fn repo_root() -> PathBuf {
#[test]
fn docker_compose_files_parse_and_include_required_services() {
let root = repo_root();
let compose = fs::read_to_string(root.join("observability/docker-compose.yml")).unwrap();
let compose = fs::read_to_string(root.join("docker-compose.yml")).unwrap();
let v: serde_yaml::Value = serde_yaml::from_str(&compose).unwrap();
let services = v
@@ -19,7 +19,15 @@ fn docker_compose_files_parse_and_include_required_services() {
.and_then(|x| x.as_mapping())
.expect("missing services");
for required in ["grafana", "victoria-metrics", "vmagent", "loki", "tempo"] {
// Core + optional observability services are all declared in one compose file.
for required in [
"grafana",
"victoria-metrics",
"vmagent",
"loki",
"tempo",
"mailhog",
] {
assert!(
services.contains_key(serde_yaml::Value::String(required.to_string())),
"missing service {required}"
@@ -28,17 +36,19 @@ fn docker_compose_files_parse_and_include_required_services() {
}
#[tokio::test]
#[ignore]
async fn docker_compose_config_validation_is_gated_and_fast() {
let enabled = std::env::var("CONTROL_TEST_DOCKER").ok();
assert_eq!(enabled.as_deref(), Some("1"));
if enabled.as_deref() != Some("1") {
eprintln!("skipping: set CONTROL_TEST_DOCKER=1 to enable docker compose validation");
return;
}
let root = repo_root();
let compose = root.join("observability/docker-compose.yml");
let compose = root.join("docker-compose.yml");
let cmd = tokio::process::Command::new("docker")
.args(["compose", "-f"])
.arg(compose)
.arg(&compose)
.args(["config"])
.output();
@@ -52,4 +62,22 @@ async fn docker_compose_config_validation_is_gated_and_fast() {
"docker compose config failed: {}",
String::from_utf8_lossy(&out.stderr)
);
// Validate full-stack profile wiring too.
let cmd = tokio::process::Command::new("docker")
.args(["compose", "-f"])
.arg(&compose)
.args(["--profile", "observability", "config"])
.output();
let out = tokio::time::timeout(Duration::from_secs(10), cmd)
.await
.expect("docker compose config (observability profile) timed out")
.expect("failed to run docker compose config (observability profile)");
assert!(
out.status.success(),
"docker compose config (observability profile) failed: {}",
String::from_utf8_lossy(&out.stderr)
);
}

View File

@@ -1,6 +1,9 @@
/// Documents the gating convention for docker-backed integration tests:
/// they do nothing unless CONTROL_TEST_DOCKER=1 is exported.
#[test]
#[ignore]
fn docker_integration_tests_are_gated() {
    let enabled = std::env::var("CONTROL_TEST_DOCKER").ok();
    if enabled.as_deref() != Some("1") {
        eprintln!("skipping: set CONTROL_TEST_DOCKER=1 to enable docker integration tests");
        return;
    }
    // Past the guard `enabled` is necessarily Some("1"); the former trailing
    // `assert_eq!(enabled.as_deref(), Some("1"))` could never fail and was removed.
}

View File

@@ -0,0 +1,169 @@
use jsonwebtoken::{EncodingKey, Header, encode};
use reqwest::header::{HeaderMap, HeaderValue};
use serde::Serialize;
use std::{path::PathBuf, process::Command, time::Duration};
use uuid::Uuid;
/// The repository root sits two directories above this crate's manifest dir.
fn repo_root() -> PathBuf {
    std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .ancestors()
        .nth(2)
        .expect("api crate should live under repo root")
        .to_path_buf()
}
/// Docker-backed tests run only when CONTROL_TEST_DOCKER=1 is exported.
fn docker_enabled() -> bool {
    matches!(std::env::var("CONTROL_TEST_DOCKER"), Ok(v) if v.trim() == "1")
}
/// Path to the repo-level docker-compose.yml used by these tests.
fn compose_file() -> PathBuf {
    let mut path = repo_root();
    path.push("docker-compose.yml");
    path
}
/// Minimal JWT claim set accepted by the control API's HS256 auth layer.
#[derive(Serialize)]
struct TestClaims {
    sub: String,              // subject (user identifier)
    session_id: String,       // opaque session identifier
    permissions: Vec<String>, // permission strings, e.g. "control:write"
    exp: usize,               // expiry as unix seconds
}
/// Creates an HS256-signed JWT for the fixed test identity (`user_1`/`sess_1`)
/// carrying `perms`, expiring five minutes from now.
fn make_token(secret: &[u8], perms: &[&str]) -> String {
    let expires_at = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs()
        + 300;
    let claims = TestClaims {
        sub: "user_1".to_string(),
        session_id: "sess_1".to_string(),
        permissions: perms.iter().map(|p| p.to_string()).collect(),
        exp: expires_at as usize,
    };
    encode(&Header::default(), &claims, &EncodingKey::from_secret(secret)).unwrap()
}
/// End-to-end docs roundtrip through the composed control-api proxy endpoints:
/// upload, list, download, and byte-for-byte comparison.
/// Gated behind CONTROL_TEST_DOCKER=1; brings up the `control-api` compose
/// service and tears the stack down best-effort at the end.
#[tokio::test]
async fn documents_upload_list_download_roundtrip_via_control_api_compose() {
    if !docker_enabled() {
        eprintln!("skipping: set CONTROL_TEST_DOCKER=1 to enable docker compose tests");
        return;
    }
    // Must match docker-compose.yml CONTROL_GATEWAY_JWT_HS256_SECRET.
    let jwt_secret = b"dev_secret";
    let token = make_token(jwt_secret, &["control:read", "control:write"]);
    let compose = compose_file();
    let up = Command::new("docker")
        .args(["compose", "-f"])
        .arg(&compose)
        .args(["up", "-d", "control-api"])
        .status()
        .expect("failed to run docker compose up control-api");
    assert!(up.success(), "docker compose up control-api failed");
    // Wait for control-api to be reachable (port publish is in compose).
    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(10))
        .build()
        .unwrap();
    let base = "http://127.0.0.1:38080";
    let health_deadline = tokio::time::Instant::now() + Duration::from_secs(30);
    loop {
        if tokio::time::Instant::now() > health_deadline {
            panic!("control-api did not become healthy in time");
        }
        match http.get(format!("{base}/health")).send().await {
            Ok(res) if res.status().is_success() => break,
            _ => tokio::time::sleep(Duration::from_millis(250)).await,
        }
    }
    let tenant_id = Uuid::new_v4().to_string();
    let doc_type = "deployments";
    let doc_id = Uuid::new_v4().to_string();
    let filename = "hello.txt";
    let bytes = b"hello-docs".to_vec();
    let mut headers = HeaderMap::new();
    headers.insert(
        "authorization",
        HeaderValue::from_str(&format!("Bearer {token}")).unwrap(),
    );
    headers.insert("x-tenant-id", HeaderValue::from_str(&tenant_id).unwrap());
    // Upload (proxy endpoint). BUGFIX: the upload path must end with the
    // file's name — `filename` was declared but never interpolated into the URL.
    let put_url =
        format!("{base}/admin/v1/tenants/{tenant_id}/docs/{doc_type}/{doc_id}/{filename}");
    let put = http
        .put(&put_url)
        .headers(headers.clone())
        .header("content-type", "text/plain")
        .body(bytes.clone())
        .send()
        .await
        .expect("upload request failed");
    assert!(
        put.status().is_success(),
        "upload failed: {}",
        put.text().await.unwrap_or_default()
    );
    let put_json: serde_json::Value = put.json().await.expect("invalid upload json");
    let key = put_json
        .get("key")
        .and_then(|v| v.as_str())
        .expect("missing key")
        .to_string();
    // List should include the key.
    let list_url = format!("{base}/admin/v1/tenants/{tenant_id}/docs?prefix={doc_type}/");
    let list = http
        .get(&list_url)
        .headers(headers.clone())
        .send()
        .await
        .expect("list request failed");
    assert!(list.status().is_success(), "list failed");
    let list_json: serde_json::Value = list.json().await.expect("invalid list json");
    let objects = list_json
        .get("objects")
        .and_then(|v| v.as_array())
        .expect("missing objects");
    assert!(
        objects
            .iter()
            .any(|o| o.get("key").and_then(|k| k.as_str()) == Some(key.as_str())),
        "expected list to include uploaded key"
    );
    // Download (proxy endpoint) returns same bytes. The key may contain '/'
    // so it is percent-encoded into a single path segment.
    let get_url = format!(
        "{base}/admin/v1/tenants/{tenant_id}/docs/object/{}",
        urlencoding::encode(&key)
    );
    let got = http
        .get(&get_url)
        .headers(headers.clone())
        .send()
        .await
        .expect("download request failed");
    assert!(got.status().is_success(), "download failed");
    let got_bytes = got.bytes().await.expect("download bytes failed").to_vec();
    assert_eq!(got_bytes, bytes);
    // Best-effort cleanup.
    let _ = Command::new("docker")
        .args(["compose", "-f"])
        .arg(&compose)
        .args(["down", "-v"])
        .status();
}

View File

@@ -0,0 +1,123 @@
use api::{
AppState, AuditStore, AuthConfig, ConfigLocks, ConfigRegistry, JobStore, PlacementStore,
SwarmStore, TenantLocks,
};
use axum::{
Router,
body::Body,
http::{Request, StatusCode, header},
};
use jsonwebtoken::{EncodingKey, Header, encode};
use metrics_exporter_prometheus::PrometheusBuilder;
use serde::Serialize;
use std::{fs, path::PathBuf, sync::OnceLock};
use tower::ServiceExt;
static HANDLE: OnceLock<metrics_exporter_prometheus::PrometheusHandle> = OnceLock::new();
/// Walks two levels up from the crate manifest dir to reach the repo root.
fn repo_root() -> PathBuf {
    let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    let root = manifest_dir
        .parent()
        .and_then(|workspace| workspace.parent())
        .expect("api crate should live under repo root");
    root.to_path_buf()
}
/// Minimal JWT claim set accepted by the control API's HS256 auth layer.
#[derive(Serialize)]
struct TestClaims {
    sub: String,              // subject (user identifier)
    session_id: String,       // opaque session identifier
    permissions: Vec<String>, // permission strings, e.g. "control:read"
    exp: usize,               // expiry as unix seconds
}
/// Mints a one-minute HS256 token for the fixed test identity, signed with
/// the in-process test secret (must match the app's `AuthConfig`).
fn make_token(perms: &[&str]) -> String {
    let now_secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs();
    let claims = TestClaims {
        sub: "user_1".to_string(),
        session_id: "sess_1".to_string(),
        permissions: perms.iter().map(|p| p.to_string()).collect(),
        exp: (now_secs + 60) as usize,
    };
    encode(
        &Header::default(),
        &claims,
        &EncodingKey::from_secret(b"test_secret"),
    )
    .unwrap()
}
/// Writes `raw` to a uniquely named temp file and returns its path.
/// PID + random UUID in the name keep parallel test runs from colliding.
fn temp_swarm_file(raw: &str) -> PathBuf {
    let name = format!(
        "cloudlysis-control-swarm-{}-{}.json",
        std::process::id(),
        uuid::Uuid::new_v4()
    );
    let dst = std::env::temp_dir().join(name);
    fs::write(&dst, raw).expect("failed to write temp swarm file");
    dst
}
/// Builds an in-process control-API router whose swarm observations come from
/// the JSON file at `swarm_path`; every other store is an isolated default.
/// The Prometheus recorder is process-global, so it is installed once via
/// `OnceLock` and cloned for each app instance.
fn test_app_with_swarm(swarm_path: PathBuf) -> Router {
    let handle = HANDLE
        .get_or_init(|| {
            PrometheusBuilder::new()
                .install_recorder()
                .expect("failed to install prometheus recorder")
        })
        .clone();
    api::build_app(AppState {
        prometheus: handle,
        // Must match the secret used by make_token above.
        auth: AuthConfig {
            hs256_secret: Some(b"test_secret".to_vec()),
        },
        jobs: JobStore::default(),
        audit: AuditStore::default(),
        tenant_locks: TenantLocks::default(),
        config_locks: ConfigLocks::default(),
        http: reqwest::Client::new(),
        placement: PlacementStore::new(repo_root().join("config/placement/dev.json")),
        billing: api::billing::BillingStore::new(
            std::env::temp_dir().join("billing-drift-test.json"),
        ),
        billing_provider: std::sync::Arc::new(api::billing::MockProvider),
        // Drift tests don't exercise billing; keep enforcement off.
        billing_enforcement_enabled: false,
        config: ConfigRegistry::new(None, None),
        fleet_services: vec![],
        swarm: SwarmStore::new(swarm_path),
        docs: None,
    })
}
/// A service observed in the swarm file but absent from the desired set must
/// be reported by /admin/v1/platform/drift with kind == "extra".
#[tokio::test]
async fn drift_marks_extra_services_vs_desired_observation_set() {
    // Observation set contains exactly one service not declared anywhere else.
    let swarm = temp_swarm_file(
        r#"{ "services": [{"name":"extra-1","image":null,"mode":null,"replicas":null,"updated_at":null}], "tasks": [] }"#,
    );
    let app = test_app_with_swarm(swarm);
    let token = make_token(&["control:read"]);
    let res = app
        .oneshot(
            Request::builder()
                .uri("/admin/v1/platform/drift")
                .header(header::AUTHORIZATION, format!("Bearer {token}"))
                .body(Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), StatusCode::OK);
    // 1 MiB cap is generous for a drift report.
    let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
    let items = v.get("items").and_then(|x| x.as_array()).unwrap();
    assert!(items.iter().any(|i| {
        i.get("kind").and_then(|k| k.as_str()) == Some("extra")
            && i.get("service").and_then(|s| s.as_str()) == Some("extra-1")
    }));
}

View File

@@ -0,0 +1,137 @@
/// Docker-gated drift check against a *real* local Swarm: creates a throwaway
/// swarm service, then asserts the in-process drift endpoint (configured for
/// docker-cli observation) reports it as "extra". Skips — rather than fails —
/// whenever docker/swarm is unavailable, so dev machines and CI stay green.
#[tokio::test]
async fn platform_drift_docker_test_is_gated() {
    use tower::ServiceExt;
    let enabled = std::env::var("CONTROL_TEST_DOCKER").ok();
    if enabled.as_deref() != Some("1") {
        eprintln!("skipping: set CONTROL_TEST_DOCKER=1 to enable docker drift tests");
        return;
    }
    // We only run the "real" drift check when Swarm is available locally.
    // If Swarm isn't active, we skip to keep CI/dev machines happy.
    let info = std::process::Command::new("docker")
        .args(["info", "--format", "{{.Swarm.LocalNodeState}}"])
        .output();
    let Ok(info) = info else {
        eprintln!("skipping: docker not available");
        return;
    };
    if !info.status.success() {
        eprintln!("skipping: docker info failed");
        return;
    }
    let state = String::from_utf8_lossy(&info.stdout).trim().to_string();
    if state != "active" {
        eprintln!("skipping: docker swarm not active (LocalNodeState={state})");
        return;
    }
    // Create a short-lived service so drift can see an "extra" observed service.
    let name = format!("cloudlysis-drift-extra-{}", uuid::Uuid::new_v4());
    let create = std::process::Command::new("docker")
        .args([
            "service",
            "create",
            "--name",
            &name,
            "--restart-condition",
            "none",
            "busybox:1.36",
            "sh",
            "-c",
            "sleep 60",
        ])
        .output()
        .expect("docker service create");
    if !create.status.success() {
        eprintln!("skipping: failed to create swarm service (maybe permissions?)");
        return;
    }
    // Ensure cleanup even if assertion fails: Drop runs `docker service rm`.
    struct Cleanup(String);
    impl Drop for Cleanup {
        fn drop(&mut self) {
            let _ = std::process::Command::new("docker")
                .args(["service", "rm", &self.0])
                .output();
        }
    }
    let _cleanup = Cleanup(name.clone());
    // Now call drift via a minimal in-process app configured for docker-cli swarm observation.
    let handle = metrics_exporter_prometheus::PrometheusBuilder::new()
        .install_recorder()
        .expect("failed to install prometheus recorder");
    let app = api::build_app(api::AppState {
        prometheus: handle,
        auth: api::AuthConfig {
            hs256_secret: Some(b"test_secret".to_vec()),
        },
        jobs: api::JobStore::default(),
        audit: api::AuditStore::default(),
        tenant_locks: api::TenantLocks::default(),
        config_locks: api::ConfigLocks::default(),
        http: reqwest::Client::new(),
        placement: api::PlacementStore::new(
            std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
                .parent()
                .and_then(|p| p.parent())
                .unwrap()
                .join("config/placement/dev.json"),
        ),
        billing: api::billing::BillingStore::new(
            std::env::temp_dir().join("billing-drift-test.json"),
        ),
        billing_provider: std::sync::Arc::new(api::billing::MockProvider),
        billing_enforcement_enabled: false,
        config: api::ConfigRegistry::new(None, None),
        fleet_services: vec![],
        // Observe the real local Swarm through the docker CLI.
        swarm: api::SwarmStore::new_docker_cli(),
        docs: None,
    });
    // Auth token (control:read), expiring in 60 seconds.
    let exp = (std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs()
        + 60) as usize;
    let token = jsonwebtoken::encode(
        &jsonwebtoken::Header::default(),
        &serde_json::json!({
            "sub": "user_1",
            "session_id": "sess_1",
            "permissions": ["control:read"],
            "exp": exp
        }),
        &jsonwebtoken::EncodingKey::from_secret(b"test_secret"),
    )
    .unwrap();
    let res = app
        .oneshot(
            axum::http::Request::builder()
                .uri("/admin/v1/platform/drift")
                .header(axum::http::header::AUTHORIZATION, format!("Bearer {token}"))
                .body(axum::body::Body::empty())
                .unwrap(),
        )
        .await
        .unwrap();
    assert_eq!(res.status(), axum::http::StatusCode::OK);
    let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
        .await
        .unwrap();
    let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
    let items = v.get("items").and_then(|x| x.as_array()).unwrap();
    // The throwaway service we created must show up as an "extra" drift item.
    assert!(
        items.iter().any(|i| {
            i.get("kind").and_then(|k| k.as_str()) == Some("extra")
                && i.get("service").and_then(|s| s.as_str()) == Some(name.as_str())
        }),
        "expected drift to include extra service {name}, got: {v}"
    );
}

View File

@@ -0,0 +1,77 @@
use std::{path::PathBuf, process::Command, time::Duration};
/// The repository root sits two directories above this crate's manifest dir.
fn repo_root() -> PathBuf {
    std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .ancestors()
        .nth(2)
        .expect("api crate should live under repo root")
        .to_path_buf()
}
/// Docker-backed compose tests run only when CONTROL_TEST_DOCKER=1 is exported.
fn docker_enabled() -> bool {
    match std::env::var("CONTROL_TEST_DOCKER") {
        Ok(value) => value.trim() == "1",
        Err(_) => false,
    }
}
/// Path to the repo-level docker-compose.yml used by these tests.
fn compose_file() -> PathBuf {
    let mut path = repo_root();
    path.push("docker-compose.yml");
    path
}
/// Docker-gated check that the composed MinIO exposes the three docs buckets
/// and that the bundled credentials work, verified by running `mc` inside the
/// compose network (so `minio:9000` resolves).
#[test]
fn minio_docs_bucket_exists_and_credentials_work_in_compose_network() {
    if !docker_enabled() {
        eprintln!("skipping: set CONTROL_TEST_DOCKER=1 to enable docker compose tests");
        return;
    }
    let compose = compose_file();
    let up = Command::new("docker")
        .args(["compose", "-f"])
        .arg(&compose)
        .args(["up", "-d", "minio"])
        .status()
        .expect("failed to run docker compose up minio");
    assert!(up.success(), "docker compose up minio failed");
    // The `minio-init` service runs `mc` inside the compose network.
    let out = Command::new("docker")
        .args(["compose", "-f"])
        .arg(&compose)
        .args([
            "run",
            "--rm",
            "minio-init",
            "/bin/sh",
            "-lc",
            "mc alias set local http://minio:9000 minioadmin minioadmin && mc ls local/cloudlysis-docs-0 && mc ls local/cloudlysis-docs-1 && mc ls local/cloudlysis-docs-2",
        ])
        .output()
        .expect("failed to run docker compose run minio-init");
    // Best-effort cleanup (keep it short; other docker tests may reuse this env).
    // Cleanup happens before the assertions so a failed check doesn't leak containers.
    let _ = Command::new("docker")
        .args(["compose", "-f"])
        .arg(&compose)
        .args(["down", "-v"])
        .status();
    assert!(
        out.status.success(),
        "minio-init bucket check failed: {}",
        String::from_utf8_lossy(&out.stderr)
    );
    // `mc ls` prints at least one line when the bucket exists (even if empty it prints the bucket line).
    let stdout = String::from_utf8_lossy(&out.stdout);
    assert!(
        stdout.contains("cloudlysis-docs-0")
            && stdout.contains("cloudlysis-docs-1")
            && stdout.contains("cloudlysis-docs-2"),
        "expected mc ls output to mention bucket: {stdout}"
    );
    // Avoid tests hanging due to docker flakiness.
    // NOTE(review): a 10 ms sleep cannot plausibly prevent hangs — confirm the
    // original intent; this looks like it could be removed.
    std::thread::sleep(Duration::from_millis(10));
}

View File

@@ -8,6 +8,20 @@ fn repo_root() -> PathBuf {
.to_path_buf()
}
/// Both S3-backed observability config variants must at least parse as YAML.
#[test]
fn loki_and_tempo_s3_config_variants_are_syntactically_valid() {
    let root = repo_root();
    let candidates = [
        root.join("observability/loki/config.s3.yml"),
        root.join("observability/tempo/config.s3.yml"),
    ];
    for file in candidates {
        let raw = fs::read_to_string(&file).unwrap_or_else(|e| panic!("{file:?}: {e}"));
        if let Err(e) = serde_yaml::from_str::<serde_yaml::Value>(&raw) {
            panic!("{file:?}: {e}");
        }
    }
}
#[test]
fn grafana_provisioning_files_are_syntactically_valid() {
let root = repo_root();

View File

@@ -0,0 +1,218 @@
use reqwest::StatusCode;
use serde_json::json;
use std::{
net::TcpStream,
path::PathBuf,
process::Command,
time::{Duration, Instant},
};
/// Walks two levels up from the crate manifest dir to reach the repo root.
fn repo_root() -> PathBuf {
    let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    let root = manifest_dir
        .parent()
        .and_then(|workspace| workspace.parent())
        .expect("api crate should live under repo root");
    root.to_path_buf()
}
/// Docker-backed observability tests run only when CONTROL_TEST_DOCKER=1 is exported.
fn docker_enabled() -> bool {
    matches!(std::env::var("CONTROL_TEST_DOCKER"), Ok(v) if v.trim() == "1")
}
/// Polls `addr` until a TCP connect succeeds or `timeout` elapses.
///
/// Returns `true` as soon as a connection is established, `false` on timeout.
/// Panics if `addr` is not a valid socket address.
fn wait_for_tcp(addr: &str, timeout: Duration) -> bool {
    // Parse once up front — the address is loop-invariant, so re-parsing
    // (and re-expecting) it on every poll iteration was wasted work.
    let sock_addr: std::net::SocketAddr = addr.parse().expect("invalid socket addr");
    let start = Instant::now();
    while start.elapsed() < timeout {
        if TcpStream::connect_timeout(&sock_addr, Duration::from_secs(1)).is_ok() {
            return true;
        }
        std::thread::sleep(Duration::from_millis(250));
    }
    false
}
/// Recursively lists `bucket` via `mc` running inside the compose network,
/// so the client can resolve the `minio:9000` service name.
///
/// Takes `&Path` rather than `&PathBuf` (clippy `ptr_arg`); existing call
/// sites passing `&PathBuf` keep working through deref coercion.
fn mc_ls_bucket(compose: &std::path::Path, bucket: &str) -> std::process::Output {
    Command::new("docker")
        .args(["compose", "-f"])
        .arg(compose)
        .args([
            "run",
            "--rm",
            "minio-init",
            "/bin/sh",
            "-lc",
            &format!(
                "mc alias set local http://minio:9000 minioadmin minioadmin >/dev/null && mc ls --recursive local/{bucket}"
            ),
        ])
        .output()
        .expect("failed to run mc ls")
}
/// Docker-gated end-to-end check of the S3-backed observability profile:
/// brings up the base + observability + S3-override compose files, ingests
/// one log line (Loki) and one span (Tempo via Zipkin), confirms both are
/// queryable again, and confirms both backends actually wrote objects into
/// their MinIO buckets.
#[tokio::test]
async fn loki_and_tempo_write_objects_to_minio_in_s3_mode() {
    if !docker_enabled() {
        eprintln!("skipping: set CONTROL_TEST_DOCKER=1 to enable docker tests");
        return;
    }
    let root = repo_root();
    let base = root.join("docker-compose.yml");
    let obs = root.join("observability/docker-compose.yml");
    let obs_s3 = root.join("observability/docker-compose.s3.yml");
    // Layer all three compose files: base stack, observability, S3 overrides.
    let up = Command::new("docker")
        .args(["compose", "-f"])
        .arg(&base)
        .args(["-f"])
        .arg(&obs)
        .args(["-f"])
        .arg(&obs_s3)
        .args(["up", "-d"])
        .status()
        .expect("failed to run docker compose up");
    assert!(up.success(), "docker compose up failed");
    // Loki (3100), Tempo (3200), Zipkin ingest (9411), MinIO (9000).
    let reachable = wait_for_tcp("127.0.0.1:3100", Duration::from_secs(45))
        && wait_for_tcp("127.0.0.1:3200", Duration::from_secs(45))
        && wait_for_tcp("127.0.0.1:9411", Duration::from_secs(45))
        && wait_for_tcp("127.0.0.1:9000", Duration::from_secs(45));
    assert!(reachable, "loki/tempo/minio ports not reachable in time");
    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(10))
        .build()
        .unwrap();
    // Push one log line into Loki (timestamps are nanosecond strings).
    let ts_ns = (std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_nanos())
    .to_string();
    let push = http
        .post("http://127.0.0.1:3100/loki/api/v1/push")
        .json(&json!({
            "streams": [{
                "stream": { "app": "cloudlysis-test" },
                "values": [[ts_ns, "hello from test"]]
            }]
        }))
        .send()
        .await
        .expect("loki push request failed");
    // Loki acknowledges a successful push with 204 No Content.
    assert!(
        push.status() == StatusCode::NO_CONTENT,
        "unexpected loki push status: {}",
        push.status()
    );
    // Emit one trace span via Zipkin v2.
    let zipkin = http
        .post("http://127.0.0.1:9411/api/v2/spans")
        .json(&json!([{
            "traceId": "463ac35c9f6413ad48485a3953bb6124",
            "id": "a2fb4a1d1a96d312",
            "name": "test-span",
            "timestamp": 1700000000000000u64,
            "duration": 1000u64,
            "localEndpoint": { "serviceName": "cloudlysis-test" }
        }]))
        .send()
        .await
        .expect("zipkin post failed");
    assert!(
        zipkin.status().is_success(),
        "zipkin ingest failed: {}",
        zipkin.status()
    );
    // Query Loki back to ensure the line is retrievable (not just accepted).
    // Loki may need a short delay to index.
    let loki_deadline = Instant::now() + Duration::from_secs(30);
    let mut loki_ok = false;
    while Instant::now() < loki_deadline && !loki_ok {
        let q = http
            .get("http://127.0.0.1:3100/loki/api/v1/query")
            .query(&[("query", r#"{app="cloudlysis-test"}"#)])
            .send()
            .await
            .expect("loki query failed");
        if q.status().is_success() {
            let v: serde_json::Value = q.json().await.expect("invalid loki query json");
            // We only need to see any non-empty result.
            let has = v
                .get("data")
                .and_then(|d| d.get("result"))
                .and_then(|r| r.as_array())
                .is_some_and(|a| !a.is_empty());
            if has {
                loki_ok = true;
                break;
            }
        }
        tokio::time::sleep(Duration::from_millis(500)).await;
    }
    // Query Tempo back by trace id (Zipkin traceId used above).
    let tempo_deadline = Instant::now() + Duration::from_secs(30);
    let mut tempo_ok = false;
    while Instant::now() < tempo_deadline && !tempo_ok {
        let res = http
            .get("http://127.0.0.1:3200/api/traces/463ac35c9f6413ad48485a3953bb6124")
            .send()
            .await
            .expect("tempo get trace failed");
        if res.status().is_success() {
            tempo_ok = true;
            break;
        }
        tokio::time::sleep(Duration::from_millis(500)).await;
    }
    // Poll buckets until at least one object appears — proves S3 mode is live.
    let deadline = Instant::now() + Duration::from_secs(45);
    let mut loki_has_objects = false;
    let mut tempo_has_objects = false;
    while Instant::now() < deadline && (!loki_has_objects || !tempo_has_objects) {
        let loki_out = mc_ls_bucket(&base, "cloudlysis-loki");
        if loki_out.status.success() && !loki_out.stdout.is_empty() {
            loki_has_objects = true;
        }
        let tempo_out = mc_ls_bucket(&base, "cloudlysis-tempo");
        if tempo_out.status.success() && !tempo_out.stdout.is_empty() {
            tempo_has_objects = true;
        }
        if !loki_has_objects || !tempo_has_objects {
            tokio::time::sleep(Duration::from_millis(500)).await;
        }
    }
    // Tear down before asserting so failures don't leak the stack.
    let _ = Command::new("docker")
        .args(["compose", "-f"])
        .arg(&base)
        .args(["-f"])
        .arg(&obs)
        .args(["-f"])
        .arg(&obs_s3)
        .args(["down", "-v"])
        .status();
    assert!(loki_has_objects, "expected Loki to write objects to MinIO");
    assert!(
        tempo_has_objects,
        "expected Tempo to write objects to MinIO"
    );
    assert!(loki_ok, "expected Loki query to return a result");
    assert!(tempo_ok, "expected Tempo to return the ingested trace");
}

View File

@@ -30,10 +30,12 @@ fn wait_for_tcp(addr: &str, timeout: Duration) -> bool {
}
#[test]
#[ignore]
fn observability_stack_reaches_healthy_state_fast() {
let enabled = std::env::var("CONTROL_TEST_DOCKER").ok();
assert_eq!(enabled.as_deref(), Some("1"));
if enabled.as_deref() != Some("1") {
eprintln!("skipping: set CONTROL_TEST_DOCKER=1 to enable docker observability smoke test");
return;
}
let root = repo_root();
let compose = root.join("observability/docker-compose.yml");

View File

@@ -0,0 +1,116 @@
use api::s3_docs::{DocsConfig, DocsStore};
use uuid::Uuid;
/// True when every S3 connection env var is present and non-blank.
/// Gates the integration tests below without requiring `-- --ignored`:
/// CI/local runs opt in simply by exporting the S3 env vars.
fn s3_env_ready() -> bool {
    const REQUIRED: [&str; 4] = [
        "CONTROL_S3_ENDPOINT",
        "CONTROL_S3_ACCESS_KEY_ID",
        "CONTROL_S3_SECRET_ACCESS_KEY",
        "CONTROL_S3_BUCKET_DOCS",
    ];
    REQUIRED
        .into_iter()
        .all(|key| matches!(std::env::var(key), Ok(v) if !v.trim().is_empty()))
}
/// Env-gated S3 integration test: put a document for a fresh tenant, read it
/// back, see it in a prefix listing, then delete it.
#[tokio::test]
async fn s3_docs_roundtrip_put_get_list_delete() {
    if !s3_env_ready() {
        eprintln!("skipping: missing S3 env (see S3_PLAN.md)");
        return;
    }
    let cfg = DocsConfig::from_env().expect("missing S3 env (see S3_PLAN.md)");
    let store = DocsStore::new(cfg)
        .await
        .expect("failed to init docs store");
    // Fresh tenant/doc ids so reruns never collide with leftover objects.
    let tenant_id = Uuid::new_v4().to_string();
    let doc_type = "test";
    let doc_id = Uuid::new_v4().to_string();
    let filename = "hello.txt";
    let key = store
        .key_for(&tenant_id, doc_type, &doc_id, filename)
        .expect("invalid key");
    store
        .put_for_tenant(
            &tenant_id,
            &key,
            b"hello".to_vec(),
            Some("text/plain".to_string()),
        )
        .await
        .expect("put failed");
    let (bytes, _ct) = store
        .get_bytes_for_tenant(&tenant_id, &key)
        .await
        .expect("get failed");
    assert_eq!(bytes, b"hello");
    // List under the tenant's own prefix; the uploaded key must appear.
    let prefix = format!("{}{}", store.prefix(), tenant_id);
    let objects = store
        .list_for_tenant(&tenant_id, &format!("{prefix}/"))
        .await
        .expect("list failed");
    assert!(objects.iter().any(|o| o.key == key));
    // Clean up the object so repeated runs don't accumulate garbage.
    store
        .delete_for_tenant(&tenant_id, &key)
        .await
        .expect("delete failed");
}
/// Env-gated S3 integration test: an object written under tenant A's prefix
/// must be visible when listing A's prefix and invisible under tenant B's.
#[tokio::test]
async fn s3_docs_tenant_prefix_isolation() {
    if !s3_env_ready() {
        eprintln!("skipping: missing S3 env (see S3_PLAN.md)");
        return;
    }
    let cfg = DocsConfig::from_env().expect("missing S3 env (see S3_PLAN.md)");
    let store = DocsStore::new(cfg)
        .await
        .expect("failed to init docs store");
    // Two fresh tenants; only tenant A receives an object.
    let tenant_a = Uuid::new_v4().to_string();
    let tenant_b = Uuid::new_v4().to_string();
    let doc_type = "test";
    let doc_id = Uuid::new_v4().to_string();
    let filename = "hello.txt";
    let key_a = store
        .key_for(&tenant_a, doc_type, &doc_id, filename)
        .expect("invalid key");
    store
        .put_for_tenant(
            &tenant_a,
            &key_a,
            b"hello-a".to_vec(),
            Some("text/plain".to_string()),
        )
        .await
        .expect("put failed");
    let prefix_a = format!("{}{tenant_a}/", store.prefix());
    let prefix_b = format!("{}{tenant_b}/", store.prefix());
    let objects_a = store
        .list_for_tenant(&tenant_a, &prefix_a)
        .await
        .expect("list a failed");
    let objects_b = store
        .list_for_tenant(&tenant_b, &prefix_b)
        .await
        .expect("list b failed");
    // A sees its object; B's prefix listing must not leak it.
    assert!(objects_a.iter().any(|o| o.key == key_a));
    assert!(!objects_b.iter().any(|o| o.key == key_a));
    // Clean up so repeated runs don't accumulate garbage.
    store
        .delete_for_tenant(&tenant_a, &key_a)
        .await
        .expect("delete failed");
}

View File

@@ -0,0 +1,36 @@
use std::{path::PathBuf, process::Command};
/// The repository root sits two directories above this crate's manifest dir.
fn repo_root() -> PathBuf {
    std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .ancestors()
        .nth(2)
        .expect("api crate should live under repo root")
        .to_path_buf()
}
/// aws-cli permission checks run only when CONTROL_TEST_AWSCLI=1 is exported.
fn is_enabled() -> bool {
    matches!(std::env::var("CONTROL_TEST_AWSCLI"), Ok(v) if v.trim() == "1")
}
/// Env-gated wrapper around the repo's shell-based S3 permission check.
/// Delegates the real work to docker/scripts/s3_verify_docs.sh, which needs
/// the aws CLI and S3_* env vars; on failure both script streams are surfaced.
#[test]
fn s3_docs_permissions_can_be_verified_with_aws_cli() {
    if !is_enabled() {
        eprintln!("skipping: set CONTROL_TEST_AWSCLI=1 to enable aws-cli S3 permission checks");
        return;
    }
    let script = repo_root().join("docker/scripts/s3_verify_docs.sh");
    let out = Command::new("sh")
        .arg(script)
        .output()
        .expect("failed to run s3_verify_docs.sh (requires aws cli and S3_* env)");
    assert!(
        out.status.success(),
        "s3 verify script failed: {}\n{}",
        String::from_utf8_lossy(&out.stdout),
        String::from_utf8_lossy(&out.stderr)
    );
}

View File

@@ -13,6 +13,7 @@ fn stack_files_parse_as_yaml() {
let root = repo_root();
for file in [
root.join("swarm/stacks/control-plane.yml"),
root.join("swarm/stacks/control-plane-prod.yml"),
root.join("swarm/stacks/observability.yml"),
] {
let raw = fs::read_to_string(&file).unwrap();
@@ -38,3 +39,36 @@ fn control_plane_stack_has_required_services() {
);
}
}
/// The prod stack must ship control-api/control-ui, must NOT bundle MinIO,
/// and must reference its S3 credentials as pre-created external secrets.
#[test]
fn control_plane_prod_stack_has_control_api_and_external_s3_secrets() {
    let raw = fs::read_to_string(repo_root().join("swarm/stacks/control-plane-prod.yml")).unwrap();
    let stack: serde_yaml::Value = serde_yaml::from_str(&raw).unwrap();
    let services = stack
        .get("services")
        .and_then(|x| x.as_mapping())
        .expect("missing services");
    let has_service =
        |name: &str| services.contains_key(serde_yaml::Value::String(name.to_string()));
    assert!(has_service("control-api"));
    assert!(has_service("control-ui"));
    assert!(!has_service("minio"), "prod stack must not bundle MinIO");
    let secrets = stack
        .get("secrets")
        .and_then(|x| x.as_mapping())
        .expect("missing secrets");
    for name in ["control_s3_access_key_id", "control_s3_secret_access_key"] {
        let entry = secrets
            .get(serde_yaml::Value::String(name.to_string()))
            .unwrap_or_else(|| panic!("missing secret {name}"));
        // `external: true` means the secret is provisioned outside the stack.
        let external = entry
            .get(serde_yaml::Value::String("external".to_string()))
            .and_then(|x| x.as_bool())
            .unwrap_or(false);
        assert!(external, "secret {name} must be external: true");
    }
}