feat(billing): implement tenant subscription entitlements system (milestones 0-6)
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
use crate::{
|
||||
AppState, RequestIds,
|
||||
auth::{Principal, has_permission},
|
||||
fleet,
|
||||
config_registry::{ConfigDomain, ConfigRegistryError},
|
||||
config_schemas::RoutingConfig,
|
||||
drift, fleet,
|
||||
job_engine::{JobEngine, StartJobError},
|
||||
jobs::{Job, JobStatus, JobStep},
|
||||
placement::{PlacementResponse, ServiceKind},
|
||||
@@ -15,7 +17,9 @@ use axum::{
|
||||
routing::{get, post},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use sha2::Digest;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
const HEADER_IDEMPOTENCY_KEY: &str = "idempotency-key";
|
||||
@@ -25,21 +29,125 @@ pub fn admin_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/whoami", get(whoami))
|
||||
.route("/platform/info", get(platform_info))
|
||||
.route("/platform/drift", get(platform_drift))
|
||||
.route("/fleet/snapshot", get(fleet_snapshot))
|
||||
.route("/tenants", get(list_tenants))
|
||||
.route("/placement/{kind}", get(get_placement))
|
||||
.route("/config", get(list_config))
|
||||
.route("/config/{domain}", get(get_config))
|
||||
.route("/config/{domain}/history", get(get_config_history))
|
||||
.route("/jobs/platform/verify", post(start_platform_verify))
|
||||
.route("/jobs/config/validate", post(start_config_validate))
|
||||
.route("/jobs/config/apply", post(start_config_apply))
|
||||
.route("/jobs/config/rollback", post(start_config_rollback))
|
||||
.route("/tenants/echo", get(tenant_echo))
|
||||
.route(
|
||||
"/tenants/{tenant_id}/billing",
|
||||
get(crate::billing::get_billing),
|
||||
)
|
||||
.route(
|
||||
"/tenants/{tenant_id}/billing/checkout",
|
||||
post(crate::billing::checkout),
|
||||
)
|
||||
.route(
|
||||
"/tenants/{tenant_id}/billing/portal",
|
||||
post(crate::billing::portal),
|
||||
)
|
||||
.route("/jobs/echo", post(create_echo_job))
|
||||
.route("/jobs/{job_id}", get(get_job))
|
||||
.route("/jobs/{job_id}/cancel", post(cancel_job))
|
||||
.route("/jobs/tenant/drain", post(start_tenant_drain))
|
||||
.route("/jobs/tenant/migrate", post(start_tenant_migrate))
|
||||
.route("/plan/tenant/migrate", post(plan_tenant_migrate))
|
||||
.route("/plan/config/apply", post(plan_config_apply))
|
||||
.route("/audit", get(list_audit))
|
||||
.route("/swarm/services", get(list_swarm_services))
|
||||
.route("/swarm/services/{name}/tasks", get(list_swarm_tasks))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PlatformVerifyRequest {
|
||||
reason: String,
|
||||
}
|
||||
|
||||
async fn start_platform_verify(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Extension(principal): Extension<Principal>,
|
||||
Json(body): Json<PlatformVerifyRequest>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:write") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
let key = headers
|
||||
.get(HEADER_IDEMPOTENCY_KEY)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or(StatusCode::BAD_REQUEST);
|
||||
let key = match key {
|
||||
Ok(k) if !k.is_empty() => k,
|
||||
_ => return StatusCode::BAD_REQUEST.into_response(),
|
||||
};
|
||||
|
||||
let engine = JobEngine::new(
|
||||
state.jobs.clone(),
|
||||
state.audit.clone(),
|
||||
state.tenant_locks.clone(),
|
||||
state.config_locks.clone(),
|
||||
);
|
||||
let job_id = match engine.start_platform_verify(state.clone(), &principal, body.reason, key) {
|
||||
Ok(id) => id,
|
||||
Err(StartJobError::TenantLocked) => return StatusCode::CONFLICT.into_response(),
|
||||
};
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({ "job_id": job_id })),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn get_config_history(
|
||||
State(state): State<AppState>,
|
||||
Path(domain): Path<String>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:read") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
let domain = match domain.as_str() {
|
||||
"routing" => ConfigDomain::Routing,
|
||||
"placement" => ConfigDomain::Placement,
|
||||
_ => return StatusCode::NOT_FOUND.into_response(),
|
||||
};
|
||||
let Some(source) = state.config.source(domain) else {
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
};
|
||||
|
||||
let rows = match source.history_bytes(50).await {
|
||||
Ok(items) => items
|
||||
.into_iter()
|
||||
.filter_map(|(rev, bytes)| {
|
||||
let v = serde_json::from_slice::<serde_json::Value>(&bytes).ok()?;
|
||||
Some(serde_json::json!({
|
||||
"revision": rev,
|
||||
"sha256": sha256_hex(&bytes),
|
||||
"value": v
|
||||
}))
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
Err(ConfigRegistryError::Source(_)) => return StatusCode::BAD_GATEWAY.into_response(),
|
||||
Err(_) => return StatusCode::NOT_IMPLEMENTED.into_response(),
|
||||
};
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({ "domain": domain.as_str(), "items": rows })),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn whoami(Extension(principal): Extension<Principal>) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:read") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
@@ -70,6 +178,18 @@ async fn platform_info(Extension(principal): Extension<Principal>) -> impl IntoR
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn platform_drift(
|
||||
State(state): State<AppState>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:read") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
let r = drift::compute(&state).await;
|
||||
(StatusCode::OK, Json(r)).into_response()
|
||||
}
|
||||
|
||||
async fn fleet_snapshot(
|
||||
State(state): State<AppState>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
@@ -109,6 +229,434 @@ async fn get_placement(
|
||||
(StatusCode::OK, Json(resp)).into_response()
|
||||
}
|
||||
|
||||
async fn list_config(
|
||||
State(state): State<AppState>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:read") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
let domains: Vec<&'static str> = [ConfigDomain::Routing, ConfigDomain::Placement]
|
||||
.into_iter()
|
||||
.filter(|d| state.config.source(*d).is_some())
|
||||
.map(|d| d.as_str())
|
||||
.collect();
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({ "domains": domains })),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn get_config(
|
||||
State(state): State<AppState>,
|
||||
Path(domain): Path<String>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:read") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
let domain = match domain.as_str() {
|
||||
"routing" => ConfigDomain::Routing,
|
||||
"placement" => ConfigDomain::Placement,
|
||||
_ => return StatusCode::NOT_FOUND.into_response(),
|
||||
};
|
||||
|
||||
let Some(source) = state.config.source(domain) else {
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
};
|
||||
|
||||
let loaded = source.load_bytes().await;
|
||||
let (bytes, revision) = match loaded {
|
||||
Ok(x) => x,
|
||||
Err(ConfigRegistryError::Source(_)) => return StatusCode::BAD_GATEWAY.into_response(),
|
||||
Err(ConfigRegistryError::Decode(_)) => return StatusCode::BAD_REQUEST.into_response(),
|
||||
Err(ConfigRegistryError::NotConfigured) => return StatusCode::NOT_FOUND.into_response(),
|
||||
};
|
||||
|
||||
let json_value = match bytes {
|
||||
Some(ref b) => match serde_json::from_slice::<serde_json::Value>(b) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({ "error": format!("invalid json: {e}") })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
},
|
||||
None => serde_json::Value::Null,
|
||||
};
|
||||
|
||||
let sha256 = bytes.as_deref().map(sha256_hex);
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({
|
||||
"domain": domain.as_str(),
|
||||
"revision": revision,
|
||||
"sha256": sha256,
|
||||
"source": source.info(),
|
||||
"value": json_value,
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ConfigApplyRequest {
|
||||
domain: String,
|
||||
expected_revision: Option<u64>,
|
||||
reason: String,
|
||||
value: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ConfigValidateRequest {
|
||||
domain: String,
|
||||
reason: String,
|
||||
value: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ConfigRollbackRequest {
|
||||
domain: String,
|
||||
reason: String,
|
||||
}
|
||||
|
||||
fn parse_domain(domain: &str) -> Option<ConfigDomain> {
|
||||
match domain {
|
||||
"routing" => Some(ConfigDomain::Routing),
|
||||
"placement" => Some(ConfigDomain::Placement),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_config_validate(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Extension(principal): Extension<Principal>,
|
||||
Json(body): Json<ConfigValidateRequest>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:write") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
let key = headers
|
||||
.get(HEADER_IDEMPOTENCY_KEY)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or(StatusCode::BAD_REQUEST);
|
||||
let key = match key {
|
||||
Ok(k) if !k.is_empty() => k,
|
||||
_ => return StatusCode::BAD_REQUEST.into_response(),
|
||||
};
|
||||
|
||||
let Some(domain) = parse_domain(body.domain.as_str()) else {
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
};
|
||||
|
||||
let engine = JobEngine::new(
|
||||
state.jobs.clone(),
|
||||
state.audit.clone(),
|
||||
state.tenant_locks.clone(),
|
||||
state.config_locks.clone(),
|
||||
);
|
||||
let job_id = match engine.start_config_validate(
|
||||
state.clone(),
|
||||
&principal,
|
||||
domain,
|
||||
body.reason,
|
||||
body.value,
|
||||
key,
|
||||
) {
|
||||
Ok(id) => id,
|
||||
Err(StartJobError::TenantLocked) => return StatusCode::CONFLICT.into_response(),
|
||||
};
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({ "job_id": job_id })),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn start_config_apply(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Extension(principal): Extension<Principal>,
|
||||
Json(body): Json<ConfigApplyRequest>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:write") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
let key = headers
|
||||
.get(HEADER_IDEMPOTENCY_KEY)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or(StatusCode::BAD_REQUEST);
|
||||
let key = match key {
|
||||
Ok(k) if !k.is_empty() => k,
|
||||
_ => return StatusCode::BAD_REQUEST.into_response(),
|
||||
};
|
||||
|
||||
let Some(domain) = parse_domain(body.domain.as_str()) else {
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
};
|
||||
|
||||
let engine = JobEngine::new(
|
||||
state.jobs.clone(),
|
||||
state.audit.clone(),
|
||||
state.tenant_locks.clone(),
|
||||
state.config_locks.clone(),
|
||||
);
|
||||
let job_id = match engine.start_config_apply(
|
||||
state.clone(),
|
||||
&principal,
|
||||
domain,
|
||||
body.reason,
|
||||
body.expected_revision,
|
||||
body.value,
|
||||
key,
|
||||
) {
|
||||
Ok(id) => id,
|
||||
Err(StartJobError::TenantLocked) => return StatusCode::CONFLICT.into_response(),
|
||||
};
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({ "job_id": job_id })),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn start_config_rollback(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Extension(principal): Extension<Principal>,
|
||||
Json(body): Json<ConfigRollbackRequest>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:write") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
let key = headers
|
||||
.get(HEADER_IDEMPOTENCY_KEY)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or(StatusCode::BAD_REQUEST);
|
||||
let key = match key {
|
||||
Ok(k) if !k.is_empty() => k,
|
||||
_ => return StatusCode::BAD_REQUEST.into_response(),
|
||||
};
|
||||
|
||||
let Some(domain) = parse_domain(body.domain.as_str()) else {
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
};
|
||||
|
||||
let engine = JobEngine::new(
|
||||
state.jobs.clone(),
|
||||
state.audit.clone(),
|
||||
state.tenant_locks.clone(),
|
||||
state.config_locks.clone(),
|
||||
);
|
||||
let job_id =
|
||||
match engine.start_config_rollback(state.clone(), &principal, domain, body.reason, key) {
|
||||
Ok(id) => id,
|
||||
Err(StartJobError::TenantLocked) => return StatusCode::CONFLICT.into_response(),
|
||||
};
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({ "job_id": job_id })),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ConfigPlanApplyRequest {
|
||||
domain: String,
|
||||
value: serde_json::Value,
|
||||
}
|
||||
|
||||
async fn plan_config_apply(
|
||||
State(state): State<AppState>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
Json(body): Json<ConfigPlanApplyRequest>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:write") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
let domain = match body.domain.as_str() {
|
||||
"routing" => ConfigDomain::Routing,
|
||||
"placement" => ConfigDomain::Placement,
|
||||
_ => return StatusCode::BAD_REQUEST.into_response(),
|
||||
};
|
||||
let Some(source) = state.config.source(domain) else {
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
};
|
||||
|
||||
// Validate proposed config (schema + semantics).
|
||||
let validate_res: Result<(), String> = match domain {
|
||||
ConfigDomain::Routing => {
|
||||
let cfg = match serde_json::from_value::<RoutingConfig>(body.value.clone()) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
validate_routing_semantics(&cfg)
|
||||
}
|
||||
ConfigDomain::Placement => {
|
||||
let cfg =
|
||||
match serde_json::from_value::<crate::placement::PlacementFile>(body.value.clone())
|
||||
{
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
validate_placement_semantics(&cfg)
|
||||
}
|
||||
};
|
||||
if let Err(e) = validate_res {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({ "error": e })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let (cur_bytes, cur_rev) = match source.load_bytes().await {
|
||||
Ok(x) => x,
|
||||
Err(_) => return StatusCode::BAD_GATEWAY.into_response(),
|
||||
};
|
||||
let cur_value = cur_bytes
|
||||
.as_deref()
|
||||
.and_then(|b| serde_json::from_slice::<serde_json::Value>(b).ok())
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
|
||||
let before = serde_json::to_string_pretty(&cur_value).unwrap_or_default();
|
||||
let after = serde_json::to_string_pretty(&body.value).unwrap_or_default();
|
||||
|
||||
let changed = cur_value != body.value;
|
||||
let impacted_services: Vec<&'static str> = match domain {
|
||||
ConfigDomain::Routing => vec!["gateway"],
|
||||
ConfigDomain::Placement => vec!["gateway", "control-api"],
|
||||
};
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({
|
||||
"domain": domain.as_str(),
|
||||
"current_revision": cur_rev,
|
||||
"changed": changed,
|
||||
"impacted_services": impacted_services,
|
||||
"diff": {
|
||||
"before": before,
|
||||
"after": after,
|
||||
}
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn sha256_hex(bytes: &[u8]) -> String {
|
||||
let mut h = sha2::Sha256::new();
|
||||
h.update(bytes);
|
||||
hex::encode(h.finalize())
|
||||
}
|
||||
|
||||
fn validate_routing_semantics(cfg: &RoutingConfig) -> Result<(), String> {
|
||||
let shard_maps = [
|
||||
("aggregate_shards", &cfg.aggregate_shards),
|
||||
("projection_shards", &cfg.projection_shards),
|
||||
("runner_shards", &cfg.runner_shards),
|
||||
];
|
||||
for (name, map) in shard_maps {
|
||||
for (shard_id, endpoints) in map {
|
||||
if endpoints.is_empty() {
|
||||
return Err(format!("{name}[{shard_id}] has no endpoints"));
|
||||
}
|
||||
for ep in endpoints {
|
||||
let u = Url::parse(ep)
|
||||
.map_err(|e| format!("{name}[{shard_id}] invalid endpoint {ep:?}: {e}"))?;
|
||||
if u.scheme() != "http" && u.scheme() != "https" {
|
||||
return Err(format!(
|
||||
"{name}[{shard_id}] endpoint {ep:?} must be http(s)"
|
||||
));
|
||||
}
|
||||
if u.host_str().is_none() {
|
||||
return Err(format!(
|
||||
"{name}[{shard_id}] endpoint {ep:?} must include host"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let placements = [
|
||||
(
|
||||
"aggregate_placement",
|
||||
&cfg.aggregate_placement,
|
||||
&cfg.aggregate_shards,
|
||||
),
|
||||
(
|
||||
"projection_placement",
|
||||
&cfg.projection_placement,
|
||||
&cfg.projection_shards,
|
||||
),
|
||||
(
|
||||
"runner_placement",
|
||||
&cfg.runner_placement,
|
||||
&cfg.runner_shards,
|
||||
),
|
||||
];
|
||||
for (pname, pmap, shards) in placements {
|
||||
for (tenant, shard_id) in pmap {
|
||||
if shard_id.trim().is_empty() {
|
||||
return Err(format!("{pname}[{tenant}] shard_id is empty"));
|
||||
}
|
||||
if !shards.contains_key(shard_id) {
|
||||
return Err(format!(
|
||||
"{pname}[{tenant}] references missing shard_id {shard_id:?}"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_placement_semantics(cfg: &crate::placement::PlacementFile) -> Result<(), String> {
|
||||
let kinds = [
|
||||
("aggregate_placement", cfg.aggregate_placement.as_ref()),
|
||||
("projection_placement", cfg.projection_placement.as_ref()),
|
||||
("runner_placement", cfg.runner_placement.as_ref()),
|
||||
];
|
||||
for (kind, k) in kinds {
|
||||
let Some(k) = k else { continue };
|
||||
for p in &k.placements {
|
||||
if p.targets.is_empty() {
|
||||
return Err(format!("{kind} tenant {} has no targets", p.tenant_id));
|
||||
}
|
||||
if p.targets.iter().any(|t| t.trim().is_empty()) {
|
||||
return Err(format!("{kind} tenant {} has empty target", p.tenant_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_tenants(
|
||||
State(state): State<AppState>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
@@ -256,6 +804,7 @@ async fn start_tenant_drain(
|
||||
state.jobs.clone(),
|
||||
state.audit.clone(),
|
||||
state.tenant_locks.clone(),
|
||||
state.config_locks.clone(),
|
||||
);
|
||||
let job_id = match engine.start_tenant_drain(
|
||||
state.clone(),
|
||||
@@ -298,6 +847,7 @@ async fn start_tenant_migrate(
|
||||
state.jobs.clone(),
|
||||
state.audit.clone(),
|
||||
state.tenant_locks.clone(),
|
||||
state.config_locks.clone(),
|
||||
);
|
||||
let job_id = match engine.start_tenant_migrate(
|
||||
state.clone(),
|
||||
|
||||
904
control/api/src/billing.rs
Normal file
904
control/api/src/billing.rs
Normal file
@@ -0,0 +1,904 @@
|
||||
use crate::{
|
||||
AppState,
|
||||
auth::{Principal, has_permission},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
Json,
|
||||
extract::{Extension, Path, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
fs,
|
||||
path::PathBuf,
|
||||
sync::{Arc, RwLock},
|
||||
time::SystemTime,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
const HEADER_TENANT_ID: &str = shared::HEADER_X_TENANT_ID;
|
||||
|
||||
fn verify_tenant_isolation(headers: &HeaderMap, path_tenant_id: Uuid) -> Result<(), StatusCode> {
|
||||
let header_tenant_id = headers
|
||||
.get(HEADER_TENANT_ID)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or(StatusCode::BAD_REQUEST)
|
||||
.and_then(|s| Uuid::parse_str(s).map_err(|_| StatusCode::BAD_REQUEST))?;
|
||||
|
||||
if header_tenant_id != path_tenant_id {
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Plan {
|
||||
Free,
|
||||
Pro,
|
||||
Enterprise,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SubscriptionStatus {
|
||||
Trialing,
|
||||
Active,
|
||||
PastDue,
|
||||
Paused,
|
||||
Canceled,
|
||||
Incomplete,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Entitlements {
|
||||
pub max_deployments: u32,
|
||||
pub max_runners: u32,
|
||||
pub s3_docs_enabled: bool,
|
||||
pub support_tier: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum BillingEvent {
|
||||
SubscriptionCreated {
|
||||
tenant_id: Uuid,
|
||||
event_id: String,
|
||||
provider_customer_id: String,
|
||||
provider_subscription_id: String,
|
||||
status: SubscriptionStatus,
|
||||
plan: Plan,
|
||||
current_period_end: String,
|
||||
ts_ms: u64,
|
||||
},
|
||||
SubscriptionUpdated {
|
||||
tenant_id: Uuid,
|
||||
event_id: String,
|
||||
status: SubscriptionStatus,
|
||||
plan: Plan,
|
||||
current_period_end: String,
|
||||
cancel_at_period_end: bool,
|
||||
ts_ms: u64,
|
||||
},
|
||||
SubscriptionDeleted {
|
||||
tenant_id: Uuid,
|
||||
event_id: String,
|
||||
ts_ms: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl BillingEvent {
|
||||
pub fn tenant_id(&self) -> Uuid {
|
||||
match self {
|
||||
Self::SubscriptionCreated { tenant_id, .. } => *tenant_id,
|
||||
Self::SubscriptionUpdated { tenant_id, .. } => *tenant_id,
|
||||
Self::SubscriptionDeleted { tenant_id, .. } => *tenant_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn event_id(&self) -> &str {
|
||||
match self {
|
||||
Self::SubscriptionCreated { event_id, .. } => event_id,
|
||||
Self::SubscriptionUpdated { event_id, .. } => event_id,
|
||||
Self::SubscriptionDeleted { event_id, .. } => event_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ts_ms(&self) -> u64 {
|
||||
match self {
|
||||
Self::SubscriptionCreated { ts_ms, .. } => *ts_ms,
|
||||
Self::SubscriptionUpdated { ts_ms, .. } => *ts_ms,
|
||||
Self::SubscriptionDeleted { ts_ms, .. } => *ts_ms,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Entitlements {
|
||||
pub fn derive(plan: Option<&Plan>, status: Option<&SubscriptionStatus>) -> Self {
|
||||
let is_active = matches!(
|
||||
status,
|
||||
Some(SubscriptionStatus::Trialing | SubscriptionStatus::Active)
|
||||
);
|
||||
|
||||
if !is_active {
|
||||
return Self {
|
||||
max_deployments: 1,
|
||||
max_runners: 1,
|
||||
s3_docs_enabled: false,
|
||||
support_tier: "community".to_string(),
|
||||
};
|
||||
}
|
||||
|
||||
match plan.unwrap_or(&Plan::Free) {
|
||||
Plan::Free => Self {
|
||||
max_deployments: 3,
|
||||
max_runners: 1,
|
||||
s3_docs_enabled: false,
|
||||
support_tier: "community".to_string(),
|
||||
},
|
||||
Plan::Pro => Self {
|
||||
max_deployments: 10,
|
||||
max_runners: 5,
|
||||
s3_docs_enabled: true,
|
||||
support_tier: "standard".to_string(),
|
||||
},
|
||||
Plan::Enterprise => Self {
|
||||
max_deployments: 1000,
|
||||
max_runners: 50,
|
||||
s3_docs_enabled: true,
|
||||
support_tier: "priority".to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TenantBillingState {
|
||||
pub provider: String,
|
||||
pub provider_customer_id: Option<String>,
|
||||
pub provider_subscription_id: Option<String>,
|
||||
pub provider_checkout_session_id: Option<String>,
|
||||
pub status: Option<SubscriptionStatus>,
|
||||
pub plan: Option<Plan>,
|
||||
pub current_period_end: Option<String>,
|
||||
pub cancel_at_period_end: Option<bool>,
|
||||
pub processed_webhook_event_ids: Vec<String>,
|
||||
pub updated_at: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct BillingStateFile {
|
||||
pub revision: Option<String>,
|
||||
pub tenants: BTreeMap<Uuid, TenantBillingState>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct BillingResponse {
|
||||
pub configured: bool,
|
||||
pub provider: Option<String>,
|
||||
pub plan: Option<Plan>,
|
||||
pub status: Option<SubscriptionStatus>,
|
||||
pub current_period_end: Option<String>,
|
||||
pub cancel_at_period_end: Option<bool>,
|
||||
pub entitlements: Entitlements,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct BillingStore {
|
||||
inner: Arc<RwLock<Inner>>,
|
||||
}
|
||||
|
||||
struct Inner {
|
||||
path: PathBuf,
|
||||
last_modified: Option<SystemTime>,
|
||||
cached: Option<BillingStateFile>,
|
||||
}
|
||||
|
||||
impl BillingStore {
|
||||
pub fn new(path: PathBuf) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(RwLock::new(Inner {
|
||||
path,
|
||||
last_modified: None,
|
||||
cached: None,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_for_tenant(&self, tenant_id: Uuid) -> BillingResponse {
|
||||
let mut inner = self.inner.write().expect("billing lock poisoned");
|
||||
inner.reload_if_changed();
|
||||
|
||||
if let Some(state) = inner
|
||||
.cached
|
||||
.as_ref()
|
||||
.and_then(|file| file.tenants.get(&tenant_id))
|
||||
{
|
||||
return BillingResponse {
|
||||
configured: true,
|
||||
provider: Some(state.provider.clone()),
|
||||
plan: state.plan.clone(),
|
||||
status: state.status.clone(),
|
||||
current_period_end: state.current_period_end.clone(),
|
||||
cancel_at_period_end: state.cancel_at_period_end,
|
||||
entitlements: Entitlements::derive(state.plan.as_ref(), state.status.as_ref()),
|
||||
};
|
||||
}
|
||||
|
||||
BillingResponse {
|
||||
configured: false,
|
||||
provider: None,
|
||||
plan: None,
|
||||
status: None,
|
||||
current_period_end: None,
|
||||
cancel_at_period_end: None,
|
||||
entitlements: Entitlements::derive(None, None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_all_tenant_ids(&self) -> Vec<Uuid> {
|
||||
let mut inner = self.inner.write().expect("billing lock poisoned");
|
||||
inner.reload_if_changed();
|
||||
|
||||
inner
|
||||
.cached
|
||||
.as_ref()
|
||||
.map(|f| f.tenants.keys().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn get_subscription_id(&self, tenant_id: Uuid) -> Option<String> {
|
||||
let mut inner = self.inner.write().expect("billing lock poisoned");
|
||||
inner.reload_if_changed();
|
||||
|
||||
inner
|
||||
.cached
|
||||
.as_ref()
|
||||
.and_then(|f| f.tenants.get(&tenant_id))
|
||||
.and_then(|s| s.provider_subscription_id.clone())
|
||||
}
|
||||
|
||||
pub fn apply_event(&self, event: BillingEvent) -> Result<(), String> {
|
||||
let mut inner = self.inner.write().expect("billing lock poisoned");
|
||||
inner.reload_if_changed();
|
||||
|
||||
let mut file = inner.cached.clone().unwrap_or(BillingStateFile {
|
||||
revision: Some("dev".to_string()),
|
||||
tenants: BTreeMap::new(),
|
||||
});
|
||||
|
||||
let tenant_id = event.tenant_id();
|
||||
let event_id = event.event_id().to_string();
|
||||
let ts_ms = event.ts_ms();
|
||||
|
||||
let state = file.tenants.entry(tenant_id).or_insert(TenantBillingState {
|
||||
provider: "unknown".to_string(), // Will be updated by Created event
|
||||
provider_customer_id: None,
|
||||
provider_subscription_id: None,
|
||||
provider_checkout_session_id: None,
|
||||
status: None,
|
||||
plan: None,
|
||||
current_period_end: None,
|
||||
cancel_at_period_end: None,
|
||||
processed_webhook_event_ids: vec![],
|
||||
updated_at: 0,
|
||||
});
|
||||
|
||||
// Deduplication
|
||||
if state.processed_webhook_event_ids.contains(&event_id) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Monotonicity check
|
||||
if state.updated_at > ts_ms {
|
||||
state.processed_webhook_event_ids.push(event_id);
|
||||
state.processed_webhook_event_ids.truncate(50);
|
||||
inner.save(file)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match event {
|
||||
BillingEvent::SubscriptionCreated {
|
||||
provider_customer_id,
|
||||
provider_subscription_id,
|
||||
status,
|
||||
plan,
|
||||
current_period_end,
|
||||
..
|
||||
} => {
|
||||
state.provider_customer_id = Some(provider_customer_id);
|
||||
state.provider_subscription_id = Some(provider_subscription_id);
|
||||
state.status = Some(status);
|
||||
state.plan = Some(plan);
|
||||
state.current_period_end = Some(current_period_end);
|
||||
}
|
||||
BillingEvent::SubscriptionUpdated {
|
||||
status,
|
||||
plan,
|
||||
current_period_end,
|
||||
cancel_at_period_end,
|
||||
..
|
||||
} => {
|
||||
state.status = Some(status);
|
||||
state.plan = Some(plan);
|
||||
state.current_period_end = Some(current_period_end);
|
||||
state.cancel_at_period_end = Some(cancel_at_period_end);
|
||||
}
|
||||
BillingEvent::SubscriptionDeleted { .. } => {
|
||||
state.status = Some(SubscriptionStatus::Canceled);
|
||||
}
|
||||
}
|
||||
|
||||
state.updated_at = ts_ms;
|
||||
state.processed_webhook_event_ids.push(event_id);
|
||||
state.processed_webhook_event_ids.truncate(50);
|
||||
|
||||
inner.save(file)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn update_tenant_state(
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
state: TenantBillingState,
|
||||
) -> Result<String, String> {
|
||||
let mut inner = self.inner.write().expect("billing lock poisoned");
|
||||
inner.reload_if_changed();
|
||||
|
||||
let mut file = inner.cached.clone().unwrap_or(BillingStateFile {
|
||||
revision: Some("dev".to_string()),
|
||||
tenants: BTreeMap::new(),
|
||||
});
|
||||
|
||||
file.tenants.insert(tenant_id, state);
|
||||
inner.save(file)
|
||||
}
|
||||
}
|
||||
|
||||
impl Inner {
|
||||
fn save(&mut self, mut file: BillingStateFile) -> Result<String, String> {
|
||||
let revision = format!("rev-{}", Uuid::new_v4());
|
||||
file.revision = Some(revision.clone());
|
||||
|
||||
let raw = serde_json::to_string_pretty(&file).map_err(|e| e.to_string())?;
|
||||
let tmp = self.path.with_extension("json.tmp");
|
||||
fs::write(&tmp, raw).map_err(|e| e.to_string())?;
|
||||
fs::rename(&tmp, &self.path).map_err(|e| e.to_string())?;
|
||||
|
||||
self.last_modified = None;
|
||||
self.cached = Some(file);
|
||||
|
||||
Ok(revision)
|
||||
}
|
||||
|
||||
fn reload_if_changed(&mut self) {
|
||||
let meta = fs::metadata(&self.path).ok();
|
||||
let modified = meta.and_then(|m| m.modified().ok());
|
||||
|
||||
if self.cached.is_some() && modified.is_some() && modified == self.last_modified {
|
||||
return;
|
||||
}
|
||||
|
||||
self.last_modified = modified;
|
||||
let p = &self.path;
|
||||
self.cached = fs::read_to_string(p)
|
||||
.ok()
|
||||
.and_then(|raw| serde_json::from_str(&raw).ok());
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_billing(
|
||||
State(state): State<AppState>,
|
||||
Path(tenant_id): Path<Uuid>,
|
||||
headers: HeaderMap,
|
||||
Extension(principal): Extension<Principal>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:read") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
if let Err(status) = verify_tenant_isolation(&headers, tenant_id) {
|
||||
return status.into_response();
|
||||
}
|
||||
|
||||
let resp = state.billing.get_for_tenant(tenant_id);
|
||||
(StatusCode::OK, Json(resp)).into_response()
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CheckoutRequest {
|
||||
pub plan: Plan,
|
||||
pub return_path: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn checkout(
|
||||
State(state): State<AppState>,
|
||||
Path(tenant_id): Path<Uuid>,
|
||||
headers: HeaderMap,
|
||||
Extension(principal): Extension<Principal>,
|
||||
Json(body): Json<CheckoutRequest>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:write") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
if let Err(status) = verify_tenant_isolation(&headers, tenant_id) {
|
||||
return status.into_response();
|
||||
}
|
||||
|
||||
// Check if subscription already exists and is active/trialing
|
||||
let current = state.billing.get_for_tenant(tenant_id);
|
||||
if current.configured
|
||||
&& matches!(
|
||||
current.status,
|
||||
Some(SubscriptionStatus::Active | SubscriptionStatus::Trialing)
|
||||
)
|
||||
{
|
||||
return (
|
||||
StatusCode::CONFLICT,
|
||||
Json(serde_json::json!({ "error": "tenant already has an active subscription" })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
// Construct full return URL
|
||||
// TODO: Validate return_path against ALLOWED_RETURN_ORIGINS if provided
|
||||
let return_url = body.return_path.unwrap_or_else(|| "/billing".to_string());
|
||||
|
||||
match state
|
||||
.billing_provider
|
||||
.create_checkout_session(tenant_id, body.plan, return_url)
|
||||
.await
|
||||
{
|
||||
Ok(url) => (StatusCode::OK, Json(serde_json::json!({ "url": url }))).into_response(),
|
||||
Err(e) => {
|
||||
let err_msg = e.to_string();
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": err_msg })),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn portal(
|
||||
State(state): State<AppState>,
|
||||
Path(tenant_id): Path<Uuid>,
|
||||
headers: HeaderMap,
|
||||
Extension(principal): Extension<Principal>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:write") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
if let Err(status) = verify_tenant_isolation(&headers, tenant_id) {
|
||||
return status.into_response();
|
||||
}
|
||||
|
||||
let return_url = "/billing".to_string();
|
||||
match state
|
||||
.billing_provider
|
||||
.create_portal_session(tenant_id, return_url)
|
||||
.await
|
||||
{
|
||||
Ok(url) => (StatusCode::OK, Json(serde_json::json!({ "url": url }))).into_response(),
|
||||
Err(e) => {
|
||||
let err_msg = e.to_string();
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": err_msg })),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn webhook(
|
||||
State(state): State<AppState>,
|
||||
Path(_provider): Path<String>,
|
||||
headers: HeaderMap,
|
||||
body: axum::body::Bytes,
|
||||
) -> impl IntoResponse {
|
||||
// Note: We don't require auth here as this is a public endpoint called by the provider.
|
||||
// Security is handled via signature verification in the provider trait.
|
||||
|
||||
match state.billing_provider.verify_webhook(&body, &headers).await {
|
||||
Ok(event) => {
|
||||
metrics::counter!("billing_webhook_requests_total", "status" => "success").increment(1);
|
||||
if let Err(e) = state.billing.apply_event(event) {
|
||||
tracing::error!(error = %e, "failed to apply billing event from webhook");
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
StatusCode::OK.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
metrics::counter!("billing_webhook_requests_total", "status" => "error").increment(1);
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_reconciliation_loop(state: AppState) {
|
||||
let interval_secs = std::env::var("CONTROL_BILLING_RECONCILE_INTERVAL_SECS")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(3600);
|
||||
|
||||
tracing::info!(interval_secs, "starting billing reconciliation loop");
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
|
||||
|
||||
tracing::info!("starting billing reconciliation run");
|
||||
reconcile_once(&state).await;
|
||||
|
||||
// Update tenant status gauges
|
||||
// Note: This is an expensive operation if there are many tenants,
|
||||
// but for reconciliation it's fine once per hour.
|
||||
update_billing_gauges(&state);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn reconcile_once(state: &AppState) {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let tenant_ids = state.billing.get_all_tenant_ids();
|
||||
let mut success = 0;
|
||||
let mut error = 0;
|
||||
let mut skipped = 0;
|
||||
|
||||
for tenant_id in tenant_ids {
|
||||
let sub_id = state.billing.get_subscription_id(tenant_id);
|
||||
if let Some(subscription_id) = sub_id {
|
||||
match state
|
||||
.billing_provider
|
||||
.fetch_subscription(tenant_id, &subscription_id)
|
||||
.await
|
||||
{
|
||||
Ok(event) => {
|
||||
if let Err(e) = state.billing.apply_event(event) {
|
||||
tracing::error!(?tenant_id, error = %e, "failed to apply reconciled billing event");
|
||||
error += 1;
|
||||
} else {
|
||||
success += 1;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(?tenant_id, error = %e, "failed to fetch subscription for reconciliation");
|
||||
error += 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
skipped += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
metrics::counter!("billing_reconciliation_runs_total", "result" => "done").increment(1);
|
||||
metrics::histogram!("billing_reconciliation_duration_ms").record(elapsed.as_millis() as f64);
|
||||
|
||||
tracing::info!(
|
||||
success,
|
||||
error,
|
||||
skipped,
|
||||
duration_ms = elapsed.as_millis(),
|
||||
"billing reconciliation run complete"
|
||||
);
|
||||
}
|
||||
|
||||
fn update_billing_gauges(state: &AppState) {
|
||||
let tenant_ids = state.billing.get_all_tenant_ids();
|
||||
let mut counts: BTreeMap<(String, String), u64> = BTreeMap::new();
|
||||
|
||||
for tenant_id in tenant_ids {
|
||||
let resp = state.billing.get_for_tenant(tenant_id);
|
||||
let plan = match resp.plan {
|
||||
Some(Plan::Free) => "free",
|
||||
Some(Plan::Pro) => "pro",
|
||||
Some(Plan::Enterprise) => "enterprise",
|
||||
None => "none",
|
||||
}
|
||||
.to_string();
|
||||
|
||||
let status = match resp.status {
|
||||
Some(SubscriptionStatus::Active) => "active",
|
||||
Some(SubscriptionStatus::Trialing) => "trialing",
|
||||
Some(SubscriptionStatus::PastDue) => "past_due",
|
||||
Some(SubscriptionStatus::Paused) => "paused",
|
||||
Some(SubscriptionStatus::Canceled) => "canceled",
|
||||
Some(SubscriptionStatus::Incomplete) => "incomplete",
|
||||
None => "none",
|
||||
}
|
||||
.to_string();
|
||||
|
||||
*counts.entry((plan, status)).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
for ((plan, status), count) in counts {
|
||||
metrics::gauge!("billing_tenant_status_count", "plan" => plan, "status" => status)
|
||||
.set(count as f64);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum BillingError {
|
||||
#[error("provider error: {0}")]
|
||||
Provider(String),
|
||||
#[error("invalid configuration: {0}")]
|
||||
Config(String),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait BillingProvider: Send + Sync {
|
||||
async fn create_checkout_session(
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
plan: Plan,
|
||||
return_url: String,
|
||||
) -> Result<String, BillingError>;
|
||||
|
||||
async fn create_portal_session(
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
return_url: String,
|
||||
) -> Result<String, BillingError>;
|
||||
|
||||
async fn verify_webhook(
|
||||
&self,
|
||||
payload: &[u8],
|
||||
headers: &HeaderMap,
|
||||
) -> Result<BillingEvent, BillingError>;
|
||||
|
||||
async fn fetch_subscription(
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
subscription_id: &str,
|
||||
) -> Result<BillingEvent, BillingError>;
|
||||
}
|
||||
|
||||
/// Stripe-backed [`BillingProvider`] configuration.
pub struct StripeProvider {
    /// Stripe API secret key. NOTE(review): not read by any method visible in this
    /// file section yet; the API calls are still TODOs.
    pub secret_key: String,
    /// Price identifier selected for `Plan::Pro` checkouts.
    pub price_pro: String,
    /// Price identifier selected for `Plan::Enterprise` checkouts.
    pub price_enterprise: String,
}
|
||||
|
||||
#[async_trait]
|
||||
impl BillingProvider for StripeProvider {
|
||||
async fn create_checkout_session(
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
plan: Plan,
|
||||
_return_url: String,
|
||||
) -> Result<String, BillingError> {
|
||||
let _price = match plan {
|
||||
Plan::Pro => &self.price_pro,
|
||||
Plan::Enterprise => &self.price_enterprise,
|
||||
Plan::Free => {
|
||||
return Err(BillingError::Config(
|
||||
"Free plan has no checkout".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: Actually call Stripe API
|
||||
// For now, returning a simulated Stripe checkout URL
|
||||
Ok(format!(
|
||||
"https://checkout.stripe.com/pay/cs_test_{}?tenant_id={}",
|
||||
Uuid::new_v4(),
|
||||
tenant_id
|
||||
))
|
||||
}
|
||||
|
||||
async fn create_portal_session(
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
_return_url: String,
|
||||
) -> Result<String, BillingError> {
|
||||
// TODO: Actually call Stripe API
|
||||
Ok(format!(
|
||||
"https://billing.stripe.com/p/session/ps_test_{}?tenant_id={}",
|
||||
Uuid::new_v4(),
|
||||
tenant_id
|
||||
))
|
||||
}
|
||||
|
||||
async fn verify_webhook(
|
||||
&self,
|
||||
_payload: &[u8],
|
||||
_headers: &HeaderMap,
|
||||
) -> Result<BillingEvent, BillingError> {
|
||||
// TODO: Implement real Stripe signature verification
|
||||
Err(BillingError::Provider("Not implemented".to_string()))
|
||||
}
|
||||
|
||||
async fn fetch_subscription(
|
||||
&self,
|
||||
_tenant_id: Uuid,
|
||||
_subscription_id: &str,
|
||||
) -> Result<BillingEvent, BillingError> {
|
||||
// TODO: Actually call Stripe API with timeout
|
||||
// let client = reqwest::Client::builder().timeout(Duration::from_secs(10)).build()...
|
||||
Err(BillingError::Provider("Not implemented".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Stateless in-process [`BillingProvider`] that returns canned responses.
pub struct MockProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl BillingProvider for MockProvider {
|
||||
async fn create_checkout_session(
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
_plan: Plan,
|
||||
_return_url: String,
|
||||
) -> Result<String, BillingError> {
|
||||
Ok(format!("https://mock.stripe.com/checkout/{}", tenant_id))
|
||||
}
|
||||
|
||||
async fn create_portal_session(
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
_return_url: String,
|
||||
) -> Result<String, BillingError> {
|
||||
Ok(format!("https://mock.stripe.com/portal/{}", tenant_id))
|
||||
}
|
||||
|
||||
async fn verify_webhook(
|
||||
&self,
|
||||
payload: &[u8],
|
||||
_headers: &HeaderMap,
|
||||
) -> Result<BillingEvent, BillingError> {
|
||||
// Mock implementation: just parse the payload as a BillingEvent
|
||||
serde_json::from_slice(payload).map_err(|e| BillingError::Provider(e.to_string()))
|
||||
}
|
||||
|
||||
async fn fetch_subscription(
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
_subscription_id: &str,
|
||||
) -> Result<BillingEvent, BillingError> {
|
||||
// Mock implementation: return a SubscriptionUpdated event with current state
|
||||
// In a real mock we might want to store expectations, but for now we just return something plausible.
|
||||
Ok(BillingEvent::SubscriptionUpdated {
|
||||
tenant_id,
|
||||
event_id: format!("reconcile-{}", Uuid::new_v4()),
|
||||
status: SubscriptionStatus::Active,
|
||||
plan: Plan::Pro,
|
||||
current_period_end: "2099-12-31T23:59:59Z".to_string(),
|
||||
cancel_at_period_end: false,
|
||||
ts_ms: SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as u64,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl MockProvider {
|
||||
pub fn get_checkout_url(tenant: Uuid) -> String {
|
||||
format!("https://mock.stripe.com/checkout/{}", tenant)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;

    /// Unique scratch file under the system temp dir, named `<stem>-<uuid>.json`.
    fn scratch_file(stem: &str) -> std::path::PathBuf {
        temp_dir().join(format!("{}-{}.json", stem, Uuid::new_v4()))
    }

    /// Baseline tenant state: mock provider, active Pro plan, no provider ids.
    fn mock_tenant_state() -> TenantBillingState {
        TenantBillingState {
            provider: "mock".to_string(),
            provider_customer_id: None,
            provider_subscription_id: None,
            provider_checkout_session_id: None,
            status: Some(SubscriptionStatus::Active),
            plan: Some(Plan::Pro),
            current_period_end: None,
            cancel_at_period_end: Some(false),
            processed_webhook_event_ids: vec![],
            updated_at: 0,
        }
    }

    /// Entitlements follow the plan for each of the plan/status pairs checked here.
    #[test]
    fn test_entitlement_derivation() {
        let free = Entitlements::derive(Some(&Plan::Free), Some(&SubscriptionStatus::PastDue));
        assert_eq!(free.max_deployments, 1);

        let pro = Entitlements::derive(Some(&Plan::Pro), Some(&SubscriptionStatus::Active));
        assert_eq!(pro.max_deployments, 10);
        assert!(pro.s3_docs_enabled);

        let enterprise =
            Entitlements::derive(Some(&Plan::Enterprise), Some(&SubscriptionStatus::Trialing));
        assert_eq!(enterprise.max_deployments, 1000);
    }

    /// State written through the store is visible on the next read.
    #[test]
    fn test_billing_state_roundtrip() {
        let path = scratch_file("billing");
        let store = BillingStore::new(path.clone());
        let tenant_id = Uuid::new_v4();

        // An unknown tenant is unconfigured and gets the single-deployment entitlement.
        let before = store.get_for_tenant(tenant_id);
        assert!(!before.configured);
        assert_eq!(before.entitlements.max_deployments, 1);

        store
            .update_tenant_state(tenant_id, mock_tenant_state())
            .unwrap();

        let after = store.get_for_tenant(tenant_id);
        assert!(after.configured);
        assert_eq!(after.provider.as_deref(), Some("mock"));
        assert_eq!(after.plan, Some(Plan::Pro));
        assert_eq!(after.entitlements.max_deployments, 10);

        let _ = fs::remove_file(path);
    }

    /// Reconciliation replaces a stale local status with the provider's status.
    #[tokio::test]
    async fn test_reconciliation_corrects_state() {
        let path = scratch_file("billing-reconcile");
        let store = BillingStore::new(path.clone());
        let tenant_id = Uuid::new_v4();

        // 1. Initial state: PastDue, with a subscription id so the tenant is not skipped.
        let past_due = TenantBillingState {
            provider_customer_id: Some("cus_1".to_string()),
            provider_subscription_id: Some("sub_1".to_string()),
            status: Some(SubscriptionStatus::PastDue),
            updated_at: 100,
            ..mock_tenant_state()
        };
        store.update_tenant_state(tenant_id, past_due).unwrap();

        let state = AppState {
            prometheus: crate::get_test_prometheus_handle(),
            auth: crate::AuthConfig { hs256_secret: None },
            jobs: crate::jobs::JobStore::default(),
            audit: crate::AuditStore::default(),
            tenant_locks: crate::job_engine::TenantLocks::default(),
            config_locks: crate::job_engine::ConfigLocks::default(),
            http: reqwest::Client::new(),
            placement: crate::placement::PlacementStore::new(temp_dir().join("placement.json")),
            billing: store.clone(),
            billing_provider: Arc::new(MockProvider),
            billing_enforcement_enabled: true,
            config: crate::config_registry::ConfigRegistry::new(None, None),
            fleet_services: vec![],
            swarm: crate::swarm::SwarmStore::new(temp_dir().join("swarm.json")),
            docs: None,
        };

        // 2. Run reconciliation. MockProvider returns Active status.
        reconcile_once(&state).await;

        // 3. Verify state is now Active.
        let reconciled = store.get_for_tenant(tenant_id);
        assert_eq!(reconciled.status, Some(SubscriptionStatus::Active));

        let _ = fs::remove_file(path);
    }
}
|
||||
323
control/api/src/config_registry.rs
Normal file
323
control/api/src/config_registry.rs
Normal file
@@ -0,0 +1,323 @@
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{path::PathBuf, sync::Arc, time::Duration};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ConfigDomain {
|
||||
Routing,
|
||||
Placement,
|
||||
}
|
||||
|
||||
impl ConfigDomain {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
ConfigDomain::Routing => "routing",
|
||||
ConfigDomain::Placement => "placement",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConfigRegistryError {
|
||||
#[error("source error: {0}")]
|
||||
Source(String),
|
||||
#[error("decode error: {0}")]
|
||||
Decode(String),
|
||||
#[error("domain not configured")]
|
||||
NotConfigured,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ConfigSnapshot<T> {
|
||||
pub domain: String,
|
||||
pub revision: u64,
|
||||
pub value: T,
|
||||
pub source: ConfigSourceInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "kind", rename_all = "snake_case")]
|
||||
pub enum ConfigSourceInfo {
|
||||
File { path: String },
|
||||
NatsKv { bucket: String, key: String },
|
||||
Fixed,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ConfigSource: Send + Sync {
|
||||
async fn load_bytes(&self) -> Result<(Option<Vec<u8>>, u64), ConfigRegistryError>;
|
||||
async fn put_bytes(
|
||||
&self,
|
||||
expected_revision: Option<u64>,
|
||||
value: Vec<u8>,
|
||||
) -> Result<u64, ConfigRegistryError>;
|
||||
async fn history_bytes(&self, limit: usize)
|
||||
-> Result<Vec<(u64, Vec<u8>)>, ConfigRegistryError>;
|
||||
async fn watch(
|
||||
&self,
|
||||
) -> Result<
|
||||
std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), ConfigRegistryError>> + Send>>,
|
||||
ConfigRegistryError,
|
||||
>;
|
||||
fn info(&self) -> ConfigSourceInfo;
|
||||
}
|
||||
|
||||
/// Read-only [`ConfigSource`] serving a fixed byte buffer.
#[derive(Clone)]
pub struct FixedSource {
    /// The configuration bytes, shared between clones of the source.
    bytes: Arc<Vec<u8>>,
}
|
||||
|
||||
impl FixedSource {
|
||||
pub fn new(bytes: Vec<u8>) -> Self {
|
||||
Self {
|
||||
bytes: Arc::new(bytes),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigSource for FixedSource {
|
||||
async fn load_bytes(&self) -> Result<(Option<Vec<u8>>, u64), ConfigRegistryError> {
|
||||
Ok((Some(self.bytes.as_ref().clone()), 1))
|
||||
}
|
||||
|
||||
async fn put_bytes(
|
||||
&self,
|
||||
_expected_revision: Option<u64>,
|
||||
_value: Vec<u8>,
|
||||
) -> Result<u64, ConfigRegistryError> {
|
||||
Err(ConfigRegistryError::Source(
|
||||
"fixed source is read-only".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn history_bytes(
|
||||
&self,
|
||||
_limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, ConfigRegistryError> {
|
||||
Err(ConfigRegistryError::Source(
|
||||
"fixed source has no history".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn watch(
|
||||
&self,
|
||||
) -> Result<
|
||||
std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), ConfigRegistryError>> + Send>>,
|
||||
ConfigRegistryError,
|
||||
> {
|
||||
Ok(Box::pin(futures::stream::empty()))
|
||||
}
|
||||
|
||||
fn info(&self) -> ConfigSourceInfo {
|
||||
ConfigSourceInfo::Fixed
|
||||
}
|
||||
}
|
||||
|
||||
/// [`ConfigSource`] backed by a single file on disk.
#[derive(Clone)]
pub struct FileSource {
    /// Location of the configuration file.
    path: PathBuf,
}
|
||||
|
||||
impl FileSource {
|
||||
pub fn new(path: PathBuf) -> Self {
|
||||
Self { path }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigSource for FileSource {
|
||||
async fn load_bytes(&self) -> Result<(Option<Vec<u8>>, u64), ConfigRegistryError> {
|
||||
let raw = tokio::fs::read(&self.path)
|
||||
.await
|
||||
.map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
|
||||
Ok((Some(raw), 0))
|
||||
}
|
||||
|
||||
async fn put_bytes(
|
||||
&self,
|
||||
_expected_revision: Option<u64>,
|
||||
value: Vec<u8>,
|
||||
) -> Result<u64, ConfigRegistryError> {
|
||||
let tmp = self.path.with_extension("tmp");
|
||||
tokio::fs::write(&tmp, &value)
|
||||
.await
|
||||
.map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
|
||||
tokio::fs::rename(&tmp, &self.path)
|
||||
.await
|
||||
.map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
async fn history_bytes(
|
||||
&self,
|
||||
_limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, ConfigRegistryError> {
|
||||
Err(ConfigRegistryError::Source(
|
||||
"file source has no history".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn watch(
|
||||
&self,
|
||||
) -> Result<
|
||||
std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), ConfigRegistryError>> + Send>>,
|
||||
ConfigRegistryError,
|
||||
> {
|
||||
Ok(Box::pin(futures::stream::empty()))
|
||||
}
|
||||
|
||||
fn info(&self) -> ConfigSourceInfo {
|
||||
ConfigSourceInfo::File {
|
||||
path: self.path.to_string_lossy().to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NatsKvSource {
|
||||
kv: async_nats::jetstream::kv::Store,
|
||||
bucket: String,
|
||||
key: String,
|
||||
}
|
||||
|
||||
impl NatsKvSource {
|
||||
pub async fn connect(
|
||||
nats_url: impl Into<String>,
|
||||
bucket: impl Into<String>,
|
||||
key: impl Into<String>,
|
||||
) -> Result<Self, ConfigRegistryError> {
|
||||
let nats_url = nats_url.into();
|
||||
let bucket = bucket.into();
|
||||
let key = key.into();
|
||||
|
||||
let client = tokio::time::timeout(Duration::from_secs(2), async_nats::connect(nats_url))
|
||||
.await
|
||||
.map_err(|_| ConfigRegistryError::Source("connect timeout".to_string()))?
|
||||
.map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
|
||||
let jetstream = async_nats::jetstream::new(client);
|
||||
let kv = match jetstream.get_key_value(&bucket).await {
|
||||
Ok(kv) => kv,
|
||||
Err(_) => jetstream
|
||||
.create_key_value(async_nats::jetstream::kv::Config {
|
||||
bucket: bucket.clone(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.map_err(|e| ConfigRegistryError::Source(e.to_string()))?,
|
||||
};
|
||||
|
||||
Ok(Self { kv, bucket, key })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigSource for NatsKvSource {
|
||||
async fn load_bytes(&self) -> Result<(Option<Vec<u8>>, u64), ConfigRegistryError> {
|
||||
let entry = self
|
||||
.kv
|
||||
.entry(&self.key)
|
||||
.await
|
||||
.map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
|
||||
Ok(match entry {
|
||||
Some(e) => (Some(e.value.to_vec()), e.revision),
|
||||
None => (None, 0),
|
||||
})
|
||||
}
|
||||
|
||||
async fn put_bytes(
|
||||
&self,
|
||||
expected_revision: Option<u64>,
|
||||
value: Vec<u8>,
|
||||
) -> Result<u64, ConfigRegistryError> {
|
||||
let rev = match expected_revision {
|
||||
Some(expected) if expected > 0 => self
|
||||
.kv
|
||||
.update(&self.key, value.into(), expected)
|
||||
.await
|
||||
.map_err(|e| ConfigRegistryError::Source(e.to_string()))?,
|
||||
_ => self
|
||||
.kv
|
||||
.put(&self.key, value.into())
|
||||
.await
|
||||
.map_err(|e| ConfigRegistryError::Source(e.to_string()))?,
|
||||
};
|
||||
Ok(rev)
|
||||
}
|
||||
|
||||
async fn history_bytes(
|
||||
&self,
|
||||
limit: usize,
|
||||
) -> Result<Vec<(u64, Vec<u8>)>, ConfigRegistryError> {
|
||||
let mut stream = self
|
||||
.kv
|
||||
.history(&self.key)
|
||||
.await
|
||||
.map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
|
||||
let mut out = Vec::new();
|
||||
while let Some(item) = stream.next().await {
|
||||
let entry = item.map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
|
||||
out.push((entry.revision, entry.value.to_vec()));
|
||||
if out.len() >= limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
async fn watch(
|
||||
&self,
|
||||
) -> Result<
|
||||
std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), ConfigRegistryError>> + Send>>,
|
||||
ConfigRegistryError,
|
||||
> {
|
||||
let key = self.key.clone();
|
||||
let watch = self
|
||||
.kv
|
||||
.watch(&key)
|
||||
.await
|
||||
.map_err(|e| ConfigRegistryError::Source(e.to_string()))?;
|
||||
Ok(Box::pin(watch.filter_map(|entry| async move {
|
||||
match entry {
|
||||
Ok(entry) => match entry.operation {
|
||||
async_nats::jetstream::kv::Operation::Put => Some(Ok(())),
|
||||
async_nats::jetstream::kv::Operation::Delete
|
||||
| async_nats::jetstream::kv::Operation::Purge => None,
|
||||
},
|
||||
Err(e) => Some(Err(ConfigRegistryError::Source(e.to_string()))),
|
||||
}
|
||||
})))
|
||||
}
|
||||
|
||||
fn info(&self) -> ConfigSourceInfo {
|
||||
ConfigSourceInfo::NatsKv {
|
||||
bucket: self.bucket.clone(),
|
||||
key: self.key.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ConfigRegistry {
|
||||
routing: Option<Arc<dyn ConfigSource>>,
|
||||
placement: Option<Arc<dyn ConfigSource>>,
|
||||
}
|
||||
|
||||
impl ConfigRegistry {
|
||||
pub fn new(
|
||||
routing: Option<Arc<dyn ConfigSource>>,
|
||||
placement: Option<Arc<dyn ConfigSource>>,
|
||||
) -> Self {
|
||||
Self { routing, placement }
|
||||
}
|
||||
|
||||
pub fn source(&self, domain: ConfigDomain) -> Option<Arc<dyn ConfigSource>> {
|
||||
match domain {
|
||||
ConfigDomain::Routing => self.routing.clone(),
|
||||
ConfigDomain::Placement => self.placement.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
15
control/api/src/config_schemas.rs
Normal file
15
control/api/src/config_schemas.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct RoutingConfig {
|
||||
pub revision: u64,
|
||||
|
||||
pub aggregate_placement: HashMap<String, String>,
|
||||
pub projection_placement: HashMap<String, String>,
|
||||
pub runner_placement: HashMap<String, String>,
|
||||
|
||||
pub aggregate_shards: HashMap<String, Vec<String>>,
|
||||
pub projection_shards: HashMap<String, Vec<String>>,
|
||||
pub runner_shards: HashMap<String, Vec<String>>,
|
||||
}
|
||||
353
control/api/src/documents.rs
Normal file
353
control/api/src/documents.rs
Normal file
@@ -0,0 +1,353 @@
|
||||
use crate::auth::{Principal, has_permission};
|
||||
use crate::{AppState, RequestIds};
|
||||
use axum::{
|
||||
Router,
|
||||
body::Bytes,
|
||||
extract::{Extension, Path, Query, State},
|
||||
http::{HeaderMap, StatusCode, header},
|
||||
response::IntoResponse,
|
||||
routing::{get, post, put},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Name of the request header that carries the caller's tenant id.
const HEADER_TENANT_ID: &str = shared::HEADER_X_TENANT_ID;
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/tenants/{tenant_id}/docs", get(list_docs))
|
||||
.route(
|
||||
"/tenants/{tenant_id}/docs/{doc_type}/{doc_id}/{filename}",
|
||||
put(upload_doc),
|
||||
)
|
||||
.route(
|
||||
"/tenants/{tenant_id}/docs/object/{*key}",
|
||||
get(get_doc).delete(delete_doc),
|
||||
)
|
||||
.route(
|
||||
"/tenants/{tenant_id}/docs/presign/upload",
|
||||
post(presign_upload),
|
||||
)
|
||||
.route(
|
||||
"/tenants/{tenant_id}/docs/presign/download",
|
||||
post(presign_download),
|
||||
)
|
||||
}
|
||||
|
||||
fn ensure_tenant_header(headers: &HeaderMap, tenant_id: Uuid) -> Result<(), StatusCode> {
|
||||
let header_tid = headers
|
||||
.get(HEADER_TENANT_ID)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or(StatusCode::BAD_REQUEST)?;
|
||||
let header_tid = Uuid::parse_str(header_tid).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
if header_tid != tenant_id {
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_docs_enabled(state: &AppState, tenant_id: Uuid) -> Result<(), StatusCode> {
|
||||
if !state.billing_enforcement_enabled {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let entitlements = state.billing.get_for_tenant(tenant_id).entitlements;
|
||||
if !entitlements.s3_docs_enabled {
|
||||
return Err(StatusCode::PAYMENT_REQUIRED);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ListQuery {
|
||||
prefix: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ListResponse {
|
||||
objects: Vec<crate::s3_docs::DocObject>,
|
||||
}
|
||||
|
||||
async fn list_docs(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(tenant_id): Path<Uuid>,
|
||||
Query(q): Query<ListQuery>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:read") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
if let Err(s) = ensure_tenant_header(&headers, tenant_id) {
|
||||
return s.into_response();
|
||||
}
|
||||
if let Err(s) = ensure_docs_enabled(&state, tenant_id) {
|
||||
return s.into_response();
|
||||
}
|
||||
let store = match state.docs.as_ref() {
|
||||
Some(s) => s,
|
||||
None => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
|
||||
};
|
||||
let prefix = q.prefix.unwrap_or_default();
|
||||
let prefix = prefix.trim();
|
||||
if prefix.contains("..") {
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
}
|
||||
let base = format!("{}{}", store_prefix(store), tenant_id);
|
||||
let prefix = if prefix.is_empty() {
|
||||
format!("{base}/")
|
||||
} else {
|
||||
format!("{base}/{prefix}")
|
||||
};
|
||||
match store.list_for_tenant(&tenant_id.to_string(), &prefix).await {
|
||||
Ok(objects) => (StatusCode::OK, axum::Json(ListResponse { objects })).into_response(),
|
||||
Err(_) => StatusCode::BAD_GATEWAY.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn upload_doc(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path((tenant_id, doc_type, doc_id, filename)): Path<(Uuid, String, String, String)>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
Extension(request_ids): Extension<RequestIds>,
|
||||
body: Bytes,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:write") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
if let Err(s) = ensure_tenant_header(&headers, tenant_id) {
|
||||
return s.into_response();
|
||||
}
|
||||
if let Err(s) = ensure_docs_enabled(&state, tenant_id) {
|
||||
return s.into_response();
|
||||
}
|
||||
let store = match state.docs.as_ref() {
|
||||
Some(s) => s,
|
||||
None => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
|
||||
};
|
||||
|
||||
let ct = headers
|
||||
.get(header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let key = match store.key_for(&tenant_id.to_string(), &doc_type, &doc_id, &filename) {
|
||||
Ok(k) => k,
|
||||
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
||||
};
|
||||
|
||||
let bytes = body.to_vec();
|
||||
let hash = crate::s3_docs::DocsStore::content_hash_sha256_hex(&bytes);
|
||||
if let Err(e) = store
|
||||
.put_for_tenant(&tenant_id.to_string(), &key, bytes, ct)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
request_id = %request_ids.request_id,
|
||||
correlation_id = ?request_ids.correlation_id,
|
||||
error = %e,
|
||||
"docs upload failed"
|
||||
);
|
||||
return StatusCode::BAD_GATEWAY.into_response();
|
||||
}
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
axum::Json(serde_json::json!({
|
||||
"key": key,
|
||||
"sha256": hash,
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn get_doc(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path((tenant_id, key)): Path<(Uuid, String)>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:read") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
if let Err(s) = ensure_tenant_header(&headers, tenant_id) {
|
||||
return s.into_response();
|
||||
}
|
||||
if let Err(s) = ensure_docs_enabled(&state, tenant_id) {
|
||||
return s.into_response();
|
||||
}
|
||||
let store = match state.docs.as_ref() {
|
||||
Some(s) => s,
|
||||
None => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
|
||||
};
|
||||
|
||||
let base = format!("{}{}", store_prefix(store), tenant_id);
|
||||
if !key.starts_with(&base) {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
match store
|
||||
.get_bytes_for_tenant(&tenant_id.to_string(), &key)
|
||||
.await
|
||||
{
|
||||
Ok((bytes, ct)) => {
|
||||
let mut res = axum::response::Response::new(axum::body::Body::from(bytes));
|
||||
*res.status_mut() = StatusCode::OK;
|
||||
if let Some(ct) = ct
|
||||
&& let Ok(v) = axum::http::HeaderValue::from_str(&ct)
|
||||
{
|
||||
res.headers_mut().insert(header::CONTENT_TYPE, v);
|
||||
}
|
||||
res
|
||||
}
|
||||
Err(_) => StatusCode::NOT_FOUND.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_doc(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path((tenant_id, key)): Path<(Uuid, String)>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:write") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
if let Err(s) = ensure_tenant_header(&headers, tenant_id) {
|
||||
return s.into_response();
|
||||
}
|
||||
if let Err(s) = ensure_docs_enabled(&state, tenant_id) {
|
||||
return s.into_response();
|
||||
}
|
||||
let store = match state.docs.as_ref() {
|
||||
Some(s) => s,
|
||||
None => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
|
||||
};
|
||||
|
||||
let base = format!("{}{}", store_prefix(store), tenant_id);
|
||||
if !key.starts_with(&base) {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
|
||||
match store.delete_for_tenant(&tenant_id.to_string(), &key).await {
|
||||
Ok(_) => StatusCode::NO_CONTENT.into_response(),
|
||||
Err(_) => StatusCode::BAD_GATEWAY.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PresignUploadRequest {
|
||||
doc_type: String,
|
||||
doc_id: Option<String>,
|
||||
filename: String,
|
||||
content_type: Option<String>,
|
||||
}
|
||||
|
||||
async fn presign_upload(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(tenant_id): Path<Uuid>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
axum::Json(body): axum::Json<PresignUploadRequest>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:write") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
if let Err(s) = ensure_tenant_header(&headers, tenant_id) {
|
||||
return s.into_response();
|
||||
}
|
||||
if let Err(s) = ensure_docs_enabled(&state, tenant_id) {
|
||||
return s.into_response();
|
||||
}
|
||||
let store = match state.docs.as_ref() {
|
||||
Some(s) => s,
|
||||
None => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
|
||||
};
|
||||
|
||||
let doc_id = body.doc_id.unwrap_or_else(|| Uuid::new_v4().to_string());
|
||||
let key = match store.key_for(
|
||||
&tenant_id.to_string(),
|
||||
&body.doc_type,
|
||||
&doc_id,
|
||||
&body.filename,
|
||||
) {
|
||||
Ok(k) => k,
|
||||
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
||||
};
|
||||
|
||||
match store
|
||||
.presign_put_for_tenant(
|
||||
&tenant_id.to_string(),
|
||||
&key,
|
||||
body.content_type,
|
||||
std::time::Duration::from_secs(300),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(url) => (
|
||||
StatusCode::OK,
|
||||
axum::Json(serde_json::json!({
|
||||
"method": "PUT",
|
||||
"url": url,
|
||||
"key": key,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(_) => StatusCode::BAD_GATEWAY.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PresignDownloadRequest {
|
||||
key: String,
|
||||
}
|
||||
|
||||
async fn presign_download(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(tenant_id): Path<Uuid>,
|
||||
Extension(principal): Extension<Principal>,
|
||||
axum::Json(body): axum::Json<PresignDownloadRequest>,
|
||||
) -> impl IntoResponse {
|
||||
if !has_permission(&principal, "control:read") {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
if let Err(s) = ensure_tenant_header(&headers, tenant_id) {
|
||||
return s.into_response();
|
||||
}
|
||||
if let Err(s) = ensure_docs_enabled(&state, tenant_id) {
|
||||
return s.into_response();
|
||||
}
|
||||
let store = match state.docs.as_ref() {
|
||||
Some(s) => s,
|
||||
None => return StatusCode::SERVICE_UNAVAILABLE.into_response(),
|
||||
};
|
||||
let base = format!("{}{}", store_prefix(store), tenant_id);
|
||||
if !body.key.starts_with(&base) {
|
||||
return StatusCode::FORBIDDEN.into_response();
|
||||
}
|
||||
match store
|
||||
.presign_get_for_tenant(
|
||||
&tenant_id.to_string(),
|
||||
&body.key,
|
||||
std::time::Duration::from_secs(300),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(url) => (
|
||||
StatusCode::OK,
|
||||
axum::Json(serde_json::json!({
|
||||
"method": "GET",
|
||||
"url": url,
|
||||
"key": body.key,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
Err(_) => StatusCode::BAD_GATEWAY.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
fn store_prefix(store: &crate::s3_docs::DocsStore) -> &str {
|
||||
store.prefix()
|
||||
}
|
||||
127
control/api/src/drift.rs
Normal file
127
control/api/src/drift.rs
Normal file
@@ -0,0 +1,127 @@
|
||||
use crate::{AppState, build_info::extract_build_info, fleet, swarm::SwarmService};
|
||||
use serde::Serialize;
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum DriftKind {
|
||||
Missing,
|
||||
Extra,
|
||||
Unhealthy,
|
||||
VersionMismatch,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct DriftItem {
|
||||
pub kind: DriftKind,
|
||||
pub service: String,
|
||||
pub details: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct DriftResponse {
|
||||
pub summary: BTreeMap<String, u64>,
|
||||
pub items: Vec<DriftItem>,
|
||||
}
|
||||
|
||||
pub async fn compute(state: &AppState) -> DriftResponse {
|
||||
let mut items: Vec<DriftItem> = Vec::new();
|
||||
|
||||
// Desired service set: what the Control API was configured to observe.
|
||||
// (In production, this should evolve into "desired stacks + required services".)
|
||||
let desired: BTreeSet<String> = state
|
||||
.fleet_services
|
||||
.iter()
|
||||
.map(|s| s.name.clone())
|
||||
.collect();
|
||||
|
||||
// Observed service set: what Swarm reports (dev: from file snapshot).
|
||||
let observed_services: Vec<SwarmService> = state.swarm.list_services();
|
||||
let observed: BTreeSet<String> = observed_services.iter().map(|s| s.name.clone()).collect();
|
||||
|
||||
for missing in desired.difference(&observed) {
|
||||
items.push(DriftItem {
|
||||
kind: DriftKind::Missing,
|
||||
service: missing.clone(),
|
||||
details: serde_json::json!({ "expected": true }),
|
||||
});
|
||||
}
|
||||
|
||||
for extra in observed.difference(&desired) {
|
||||
items.push(DriftItem {
|
||||
kind: DriftKind::Extra,
|
||||
service: extra.clone(),
|
||||
details: serde_json::json!({ "observed": true }),
|
||||
});
|
||||
}
|
||||
|
||||
// Health drift: based on fleet snapshot.
|
||||
let snapshots = fleet::snapshot(&state.http, &state.fleet_services).await;
|
||||
for s in snapshots {
|
||||
if !s.health_ok || !s.ready_ok {
|
||||
items.push(DriftItem {
|
||||
kind: DriftKind::Unhealthy,
|
||||
service: s.name.clone(),
|
||||
details: serde_json::json!({
|
||||
"health_ok": s.health_ok,
|
||||
"ready_ok": s.ready_ok,
|
||||
"metrics_ok": s.metrics_ok,
|
||||
"base_url": s.base_url,
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Version drift: compare build_info between services when present.
|
||||
// Desired is not yet explicit; for now we flag when multiple versions exist for same service.
|
||||
let mut versions_by_service: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
|
||||
let snapshots = fleet::snapshot(&state.http, &state.fleet_services).await;
|
||||
for s in snapshots {
|
||||
if let Ok(metrics) = state
|
||||
.http
|
||||
.get(format!("{}/metrics", s.base_url))
|
||||
.send()
|
||||
.await
|
||||
&& let Ok(body) = metrics.text().await
|
||||
{
|
||||
for bi in extract_build_info(&body) {
|
||||
versions_by_service
|
||||
.entry(bi.service.clone())
|
||||
.or_default()
|
||||
.insert(format!("{}@{}", bi.version, bi.git_sha));
|
||||
}
|
||||
}
|
||||
}
|
||||
for (svc, vs) in versions_by_service {
|
||||
if vs.len() > 1 {
|
||||
items.push(DriftItem {
|
||||
kind: DriftKind::VersionMismatch,
|
||||
service: svc,
|
||||
details: serde_json::json!({ "seen": vs.into_iter().collect::<Vec<_>>() }),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn ord(k: &DriftKind) -> u8 {
|
||||
match k {
|
||||
DriftKind::Missing => 0,
|
||||
DriftKind::Extra => 1,
|
||||
DriftKind::Unhealthy => 2,
|
||||
DriftKind::VersionMismatch => 3,
|
||||
}
|
||||
}
|
||||
items.sort_by(|a, b| (ord(&a.kind), &a.service).cmp(&(ord(&b.kind), &b.service)));
|
||||
|
||||
let mut summary: BTreeMap<String, u64> = BTreeMap::new();
|
||||
for item in &items {
|
||||
let k = match item.kind {
|
||||
DriftKind::Missing => "missing",
|
||||
DriftKind::Extra => "extra",
|
||||
DriftKind::Unhealthy => "unhealthy",
|
||||
DriftKind::VersionMismatch => "version_mismatch",
|
||||
};
|
||||
*summary.entry(k.to_string()).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
DriftResponse { summary, items }
|
||||
}
|
||||
@@ -1,14 +1,19 @@
|
||||
use crate::{
|
||||
AppState, Principal,
|
||||
audit::{AuditEvent, AuditStore},
|
||||
config_registry::{ConfigDomain, ConfigRegistryError},
|
||||
config_schemas::RoutingConfig,
|
||||
fleet,
|
||||
jobs::{Job, JobStatus, JobStep, JobStore},
|
||||
placement::PlacementFile,
|
||||
};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
path::PathBuf,
|
||||
sync::{Arc, Mutex},
|
||||
time::{Duration, SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
@@ -34,20 +39,52 @@ impl TenantLocks {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ConfigLocks {
|
||||
inner: Arc<Mutex<HashMap<String, Uuid>>>,
|
||||
}
|
||||
|
||||
impl ConfigLocks {
|
||||
pub fn try_lock(&self, domain: ConfigDomain, job_id: Uuid) -> bool {
|
||||
let mut map = self.inner.lock().expect("config locks poisoned");
|
||||
let k = domain.as_str().to_string();
|
||||
if map.contains_key(&k) {
|
||||
return false;
|
||||
}
|
||||
map.insert(k, job_id);
|
||||
true
|
||||
}
|
||||
|
||||
pub fn unlock(&self, domain: ConfigDomain, job_id: Uuid) {
|
||||
let mut map = self.inner.lock().expect("config locks poisoned");
|
||||
let k = domain.as_str().to_string();
|
||||
if map.get(&k).copied() == Some(job_id) {
|
||||
map.remove(&k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct JobEngine {
|
||||
pub jobs: JobStore,
|
||||
pub audit: AuditStore,
|
||||
pub tenant_locks: TenantLocks,
|
||||
pub config_locks: ConfigLocks,
|
||||
pub step_timeout: Duration,
|
||||
}
|
||||
|
||||
impl JobEngine {
|
||||
pub fn new(jobs: JobStore, audit: AuditStore, tenant_locks: TenantLocks) -> Self {
|
||||
pub fn new(
|
||||
jobs: JobStore,
|
||||
audit: AuditStore,
|
||||
tenant_locks: TenantLocks,
|
||||
config_locks: ConfigLocks,
|
||||
) -> Self {
|
||||
Self {
|
||||
jobs,
|
||||
audit,
|
||||
tenant_locks,
|
||||
config_locks,
|
||||
step_timeout: Duration::from_millis(500),
|
||||
}
|
||||
}
|
||||
@@ -93,7 +130,7 @@ impl JobEngine {
|
||||
let engine = self.clone();
|
||||
tokio::spawn(async move {
|
||||
engine
|
||||
.run_job(state, inserted, Some(tenant_id), RunSpec::Drain)
|
||||
.run_job(state, inserted, Some(tenant_id), None, RunSpec::Drain)
|
||||
.await;
|
||||
});
|
||||
|
||||
@@ -152,6 +189,7 @@ impl JobEngine {
|
||||
state,
|
||||
inserted,
|
||||
Some(tenant_id),
|
||||
None,
|
||||
RunSpec::Migrate { runner_target },
|
||||
)
|
||||
.await;
|
||||
@@ -160,7 +198,238 @@ impl JobEngine {
|
||||
Ok(inserted)
|
||||
}
|
||||
|
||||
async fn run_job(&self, state: AppState, job_id: Uuid, tenant_id: Option<Uuid>, spec: RunSpec) {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn start_config_apply(
|
||||
&self,
|
||||
state: AppState,
|
||||
principal: &Principal,
|
||||
domain: ConfigDomain,
|
||||
reason: String,
|
||||
expected_revision: Option<u64>,
|
||||
value: serde_json::Value,
|
||||
idempotency_key: &str,
|
||||
) -> Result<Uuid, StartJobError> {
|
||||
if let Some(existing) = self.jobs.get_idempotent(idempotency_key) {
|
||||
return Ok(existing);
|
||||
}
|
||||
|
||||
let job_id = Uuid::new_v4();
|
||||
if !self.config_locks.try_lock(domain, job_id) {
|
||||
return Err(StartJobError::TenantLocked);
|
||||
}
|
||||
|
||||
let now = now_ms();
|
||||
let job = Job {
|
||||
job_id,
|
||||
status: JobStatus::Pending,
|
||||
steps: vec![
|
||||
step("preflight"),
|
||||
step("validate_config"),
|
||||
step("backup_config"),
|
||||
step("apply_config"),
|
||||
step("reload_config"),
|
||||
step("verify_config"),
|
||||
],
|
||||
error: None,
|
||||
created_at_ms: now,
|
||||
started_at_ms: None,
|
||||
finished_at_ms: None,
|
||||
};
|
||||
|
||||
let inserted = self.jobs.insert_idempotent(idempotency_key, job);
|
||||
self.audit.record(AuditEvent {
|
||||
ts_ms: now,
|
||||
principal_sub: principal.sub.clone(),
|
||||
action: format!("config.{}.apply", domain.as_str()),
|
||||
tenant_id: None,
|
||||
reason,
|
||||
job_id: Some(inserted),
|
||||
});
|
||||
|
||||
let engine = self.clone();
|
||||
tokio::spawn(async move {
|
||||
engine
|
||||
.run_job(
|
||||
state,
|
||||
inserted,
|
||||
None,
|
||||
Some(domain),
|
||||
RunSpec::ConfigApply {
|
||||
domain,
|
||||
expected_revision,
|
||||
value,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(inserted)
|
||||
}
|
||||
|
||||
pub fn start_config_validate(
|
||||
&self,
|
||||
state: AppState,
|
||||
principal: &Principal,
|
||||
domain: ConfigDomain,
|
||||
reason: String,
|
||||
value: serde_json::Value,
|
||||
idempotency_key: &str,
|
||||
) -> Result<Uuid, StartJobError> {
|
||||
if let Some(existing) = self.jobs.get_idempotent(idempotency_key) {
|
||||
return Ok(existing);
|
||||
}
|
||||
|
||||
let job_id = Uuid::new_v4();
|
||||
if !self.config_locks.try_lock(domain, job_id) {
|
||||
return Err(StartJobError::TenantLocked);
|
||||
}
|
||||
|
||||
let now = now_ms();
|
||||
let job = Job {
|
||||
job_id,
|
||||
status: JobStatus::Pending,
|
||||
steps: vec![step("validate_config")],
|
||||
error: None,
|
||||
created_at_ms: now,
|
||||
started_at_ms: None,
|
||||
finished_at_ms: None,
|
||||
};
|
||||
|
||||
let inserted = self.jobs.insert_idempotent(idempotency_key, job);
|
||||
self.audit.record(AuditEvent {
|
||||
ts_ms: now,
|
||||
principal_sub: principal.sub.clone(),
|
||||
action: format!("config.{}.validate", domain.as_str()),
|
||||
tenant_id: None,
|
||||
reason,
|
||||
job_id: Some(inserted),
|
||||
});
|
||||
|
||||
let engine = self.clone();
|
||||
tokio::spawn(async move {
|
||||
engine
|
||||
.run_job(
|
||||
state,
|
||||
inserted,
|
||||
None,
|
||||
Some(domain),
|
||||
RunSpec::ConfigValidate { domain, value },
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(inserted)
|
||||
}
|
||||
|
||||
pub fn start_config_rollback(
|
||||
&self,
|
||||
state: AppState,
|
||||
principal: &Principal,
|
||||
domain: ConfigDomain,
|
||||
reason: String,
|
||||
idempotency_key: &str,
|
||||
) -> Result<Uuid, StartJobError> {
|
||||
if let Some(existing) = self.jobs.get_idempotent(idempotency_key) {
|
||||
return Ok(existing);
|
||||
}
|
||||
|
||||
let job_id = Uuid::new_v4();
|
||||
if !self.config_locks.try_lock(domain, job_id) {
|
||||
return Err(StartJobError::TenantLocked);
|
||||
}
|
||||
|
||||
let now = now_ms();
|
||||
let job = Job {
|
||||
job_id,
|
||||
status: JobStatus::Pending,
|
||||
steps: vec![
|
||||
step("rollback_config"),
|
||||
step("reload_config"),
|
||||
step("verify_config"),
|
||||
],
|
||||
error: None,
|
||||
created_at_ms: now,
|
||||
started_at_ms: None,
|
||||
finished_at_ms: None,
|
||||
};
|
||||
|
||||
let inserted = self.jobs.insert_idempotent(idempotency_key, job);
|
||||
self.audit.record(AuditEvent {
|
||||
ts_ms: now,
|
||||
principal_sub: principal.sub.clone(),
|
||||
action: format!("config.{}.rollback", domain.as_str()),
|
||||
tenant_id: None,
|
||||
reason,
|
||||
job_id: Some(inserted),
|
||||
});
|
||||
|
||||
let engine = self.clone();
|
||||
tokio::spawn(async move {
|
||||
engine
|
||||
.run_job(
|
||||
state,
|
||||
inserted,
|
||||
None,
|
||||
Some(domain),
|
||||
RunSpec::ConfigRollback { domain },
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(inserted)
|
||||
}
|
||||
|
||||
pub fn start_platform_verify(
|
||||
&self,
|
||||
state: AppState,
|
||||
principal: &Principal,
|
||||
reason: String,
|
||||
idempotency_key: &str,
|
||||
) -> Result<Uuid, StartJobError> {
|
||||
if let Some(existing) = self.jobs.get_idempotent(idempotency_key) {
|
||||
return Ok(existing);
|
||||
}
|
||||
|
||||
let job_id = Uuid::new_v4();
|
||||
let now = now_ms();
|
||||
let job = Job {
|
||||
job_id,
|
||||
status: JobStatus::Pending,
|
||||
steps: vec![step("preflight"), step("platform_verify")],
|
||||
error: None,
|
||||
created_at_ms: now,
|
||||
started_at_ms: None,
|
||||
finished_at_ms: None,
|
||||
};
|
||||
|
||||
let inserted = self.jobs.insert_idempotent(idempotency_key, job);
|
||||
self.audit.record(AuditEvent {
|
||||
ts_ms: now,
|
||||
principal_sub: principal.sub.clone(),
|
||||
action: "platform.verify".to_string(),
|
||||
tenant_id: None,
|
||||
reason,
|
||||
job_id: Some(inserted),
|
||||
});
|
||||
|
||||
let engine = self.clone();
|
||||
tokio::spawn(async move {
|
||||
engine
|
||||
.run_job(state, inserted, None, None, RunSpec::PlatformVerify)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(inserted)
|
||||
}
|
||||
|
||||
async fn run_job(
|
||||
&self,
|
||||
state: AppState,
|
||||
job_id: Uuid,
|
||||
tenant_id: Option<Uuid>,
|
||||
config_domain: Option<ConfigDomain>,
|
||||
spec: RunSpec,
|
||||
) {
|
||||
self.jobs.update(job_id, |j| {
|
||||
j.status = JobStatus::Running;
|
||||
j.started_at_ms = Some(now_ms());
|
||||
@@ -265,6 +534,9 @@ impl JobEngine {
|
||||
if let Some(tid) = tenant_id {
|
||||
self.tenant_locks.unlock(tid, job_id);
|
||||
}
|
||||
if let Some(domain) = config_domain {
|
||||
self.config_locks.unlock(domain, job_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -276,7 +548,22 @@ pub enum StartJobError {
|
||||
#[derive(Clone)]
|
||||
enum RunSpec {
|
||||
Drain,
|
||||
Migrate { runner_target: String },
|
||||
Migrate {
|
||||
runner_target: String,
|
||||
},
|
||||
ConfigValidate {
|
||||
domain: ConfigDomain,
|
||||
value: serde_json::Value,
|
||||
},
|
||||
ConfigApply {
|
||||
domain: ConfigDomain,
|
||||
expected_revision: Option<u64>,
|
||||
value: serde_json::Value,
|
||||
},
|
||||
ConfigRollback {
|
||||
domain: ConfigDomain,
|
||||
},
|
||||
PlatformVerify,
|
||||
}
|
||||
|
||||
fn step(name: &str) -> JobStep {
|
||||
@@ -316,9 +603,14 @@ async fn run_step(
|
||||
"update_placement" => match spec {
|
||||
RunSpec::Migrate { runner_target } => {
|
||||
let tenant_id = tenant_id.ok_or_else(|| "missing tenant_id".to_string())?;
|
||||
let entitlements = state.billing.get_for_tenant(tenant_id).entitlements;
|
||||
state
|
||||
.placement
|
||||
.update_runner_target(tenant_id, runner_target.clone())
|
||||
.update_runner_target(
|
||||
tenant_id,
|
||||
runner_target.clone(),
|
||||
entitlements.max_runners as usize,
|
||||
)
|
||||
.map(|_| ())
|
||||
}
|
||||
_ => Ok(()),
|
||||
@@ -343,6 +635,400 @@ async fn run_step(
|
||||
}
|
||||
_ => Ok(()),
|
||||
},
|
||||
"validate_config" => match spec {
|
||||
RunSpec::ConfigValidate { domain, value }
|
||||
| RunSpec::ConfigApply { domain, value, .. } => match domain {
|
||||
ConfigDomain::Routing => {
|
||||
let cfg = serde_json::from_value::<RoutingConfig>(value.clone())
|
||||
.map_err(|e| format!("invalid routing config: {e}"))?;
|
||||
validate_routing_semantic(&cfg)?;
|
||||
Ok(())
|
||||
}
|
||||
ConfigDomain::Placement => {
|
||||
let cfg = serde_json::from_value::<PlacementFile>(value.clone())
|
||||
.map_err(|e| format!("invalid placement config: {e}"))?;
|
||||
validate_placement_semantic(state, &cfg)?;
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
_ => Ok(()),
|
||||
},
|
||||
"backup_config" => match spec {
|
||||
RunSpec::ConfigApply { domain, .. } => {
|
||||
let Some(source) = state.config.source(*domain) else {
|
||||
return Err("config domain not configured".to_string());
|
||||
};
|
||||
let (cur, _) = source
|
||||
.load_bytes()
|
||||
.await
|
||||
.map_err(|e| format!("failed to load config: {e}"))?;
|
||||
let cur = cur.unwrap_or_else(|| b"null".to_vec());
|
||||
let backup_key_value = serde_json::json!({ "backup": serde_json::from_slice::<serde_json::Value>(&cur).unwrap_or(serde_json::Value::Null) });
|
||||
let bytes =
|
||||
serde_json::to_vec_pretty(&backup_key_value).map_err(|e| e.to_string())?;
|
||||
let backup_source = backup_source_for(&source.info(), *domain)
|
||||
.await
|
||||
.map_err(|e| format!("failed to build backup source: {e}"))?;
|
||||
let _ = backup_source
|
||||
.put_bytes(None, bytes)
|
||||
.await
|
||||
.map_err(|e| format!("failed to write backup: {e}"))?;
|
||||
Ok(())
|
||||
}
|
||||
_ => Ok(()),
|
||||
},
|
||||
"apply_config" => match spec {
|
||||
RunSpec::ConfigApply {
|
||||
domain,
|
||||
expected_revision,
|
||||
value,
|
||||
} => {
|
||||
let Some(source) = state.config.source(*domain) else {
|
||||
return Err("config domain not configured".to_string());
|
||||
};
|
||||
let bytes =
|
||||
serde_json::to_vec_pretty(value).map_err(|e| format!("encode error: {e}"))?;
|
||||
let _ = source
|
||||
.put_bytes(*expected_revision, bytes)
|
||||
.await
|
||||
.map_err(|e| format!("apply failed: {e}"))?;
|
||||
Ok(())
|
||||
}
|
||||
_ => Ok(()),
|
||||
},
|
||||
"rollback_config" => match spec {
|
||||
RunSpec::ConfigRollback { domain } => {
|
||||
let Some(source) = state.config.source(*domain) else {
|
||||
return Err("config domain not configured".to_string());
|
||||
};
|
||||
let backup_source = backup_source_for(&source.info(), *domain)
|
||||
.await
|
||||
.map_err(|e| format!("failed to build backup source: {e}"))?;
|
||||
let (bytes, _) = backup_source
|
||||
.load_bytes()
|
||||
.await
|
||||
.map_err(|e| format!("failed to load backup: {e}"))?;
|
||||
let Some(bytes) = bytes else {
|
||||
return Err("no backup available".to_string());
|
||||
};
|
||||
let v: serde_json::Value = serde_json::from_slice(&bytes)
|
||||
.map_err(|e| format!("invalid backup json: {e}"))?;
|
||||
let backup = v.get("backup").cloned().unwrap_or(serde_json::Value::Null);
|
||||
let next =
|
||||
serde_json::to_vec_pretty(&backup).map_err(|e| format!("encode error: {e}"))?;
|
||||
let _ = source
|
||||
.put_bytes(None, next)
|
||||
.await
|
||||
.map_err(|e| format!("rollback failed: {e}"))?;
|
||||
Ok(())
|
||||
}
|
||||
_ => Ok(()),
|
||||
},
|
||||
"reload_config" => Ok(()),
|
||||
"verify_config" => match spec {
|
||||
RunSpec::ConfigValidate { domain, .. }
|
||||
| RunSpec::ConfigApply { domain, .. }
|
||||
| RunSpec::ConfigRollback { domain } => {
|
||||
let Some(source) = state.config.source(*domain) else {
|
||||
return Err("config domain not configured".to_string());
|
||||
};
|
||||
let (bytes, _) = source
|
||||
.load_bytes()
|
||||
.await
|
||||
.map_err(|e| format!("failed to load config: {e}"))?;
|
||||
let bytes = bytes.unwrap_or_else(|| b"null".to_vec());
|
||||
let v: serde_json::Value = serde_json::from_slice(&bytes)
|
||||
.map_err(|e| format!("invalid stored json: {e}"))?;
|
||||
match domain {
|
||||
ConfigDomain::Routing => {
|
||||
let cfg = serde_json::from_value::<RoutingConfig>(v)
|
||||
.map_err(|e| format!("invalid routing config: {e}"))?;
|
||||
validate_routing_semantic(&cfg)?;
|
||||
Ok(())
|
||||
}
|
||||
ConfigDomain::Placement => {
|
||||
let cfg = serde_json::from_value::<PlacementFile>(v)
|
||||
.map_err(|e| format!("invalid placement config: {e}"))?;
|
||||
validate_placement_semantic(state, &cfg)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => Ok(()),
|
||||
},
|
||||
"platform_verify" => match spec {
|
||||
RunSpec::PlatformVerify => {
|
||||
let snapshots = fleet::snapshot(&state.http, &state.fleet_services).await;
|
||||
let bad: Vec<_> = snapshots
|
||||
.into_iter()
|
||||
.filter(|s| !(s.health_ok && s.ready_ok))
|
||||
.map(|s| {
|
||||
format!(
|
||||
"{} health_ok={} ready_ok={}",
|
||||
s.name, s.health_ok, s.ready_ok
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
if !bad.is_empty() {
|
||||
return Err(format!("platform verify failed: {}", bad.join("; ")));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
_ => Ok(()),
|
||||
},
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn backup_source_for(
|
||||
info: &crate::config_registry::ConfigSourceInfo,
|
||||
domain: ConfigDomain,
|
||||
) -> Result<Arc<dyn crate::config_registry::ConfigSource>, ConfigRegistryError> {
|
||||
use crate::config_registry::{ConfigSource, FileSource, NatsKvSource};
|
||||
match info {
|
||||
crate::config_registry::ConfigSourceInfo::File { path } => Ok(Arc::new(FileSource::new(
|
||||
PathBuf::from(path).with_extension(format!("{}.bak.json", domain.as_str())),
|
||||
))
|
||||
as Arc<dyn ConfigSource>),
|
||||
crate::config_registry::ConfigSourceInfo::NatsKv { bucket, key } => {
|
||||
let nats_url = std::env::var("CONTROL_CONFIG_NATS_URL").map_err(|_| {
|
||||
ConfigRegistryError::Source("missing CONTROL_CONFIG_NATS_URL".to_string())
|
||||
})?;
|
||||
Ok(Arc::new(
|
||||
NatsKvSource::connect(nats_url, bucket.clone(), format!("{key}.bak"))
|
||||
.await
|
||||
.map_err(|e| ConfigRegistryError::Source(e.to_string()))?,
|
||||
) as Arc<dyn ConfigSource>)
|
||||
}
|
||||
crate::config_registry::ConfigSourceInfo::Fixed => Err(ConfigRegistryError::Source(
|
||||
"no backups for fixed source".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_routing_semantic(cfg: &RoutingConfig) -> Result<(), String> {
|
||||
let shard_maps = [
|
||||
("aggregate_shards", &cfg.aggregate_shards),
|
||||
("projection_shards", &cfg.projection_shards),
|
||||
("runner_shards", &cfg.runner_shards),
|
||||
];
|
||||
for (name, map) in shard_maps {
|
||||
for (shard_id, endpoints) in map {
|
||||
if endpoints.is_empty() {
|
||||
return Err(format!("{name}[{shard_id}] has no endpoints"));
|
||||
}
|
||||
for ep in endpoints {
|
||||
let u = Url::parse(ep)
|
||||
.map_err(|e| format!("{name}[{shard_id}] invalid endpoint {ep:?}: {e}"))?;
|
||||
if u.scheme() != "http" && u.scheme() != "https" {
|
||||
return Err(format!(
|
||||
"{name}[{shard_id}] endpoint {ep:?} must be http(s)"
|
||||
));
|
||||
}
|
||||
if u.host_str().is_none() {
|
||||
return Err(format!(
|
||||
"{name}[{shard_id}] endpoint {ep:?} must include host"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure placement references known shard ids.
|
||||
let placements = [
|
||||
(
|
||||
"aggregate_placement",
|
||||
&cfg.aggregate_placement,
|
||||
&cfg.aggregate_shards,
|
||||
),
|
||||
(
|
||||
"projection_placement",
|
||||
&cfg.projection_placement,
|
||||
&cfg.projection_shards,
|
||||
),
|
||||
(
|
||||
"runner_placement",
|
||||
&cfg.runner_placement,
|
||||
&cfg.runner_shards,
|
||||
),
|
||||
];
|
||||
for (pname, pmap, shards) in placements {
|
||||
for (tenant, shard_id) in pmap {
|
||||
if shard_id.trim().is_empty() {
|
||||
return Err(format!("{pname}[{tenant}] shard_id is empty"));
|
||||
}
|
||||
if !shards.contains_key(shard_id) {
|
||||
return Err(format!(
|
||||
"{pname}[{tenant}] references missing shard_id {shard_id:?}"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Checks a placement file against each tenant's billing entitlements.
///
/// Does nothing when `state.billing_enforcement_enabled` is false. Otherwise every
/// placement must have at least one target and no blank targets. Targets are then
/// counted per tenant: runner placements count as runners, aggregate and projection
/// placements count as deployments. Each total must not exceed the tenant's
/// `max_runners` / `max_deployments` entitlement.
///
/// # Errors
/// Returns a message describing the first violation found.
fn validate_placement_semantic(state: &AppState, cfg: &PlacementFile) -> Result<(), String> {
    if !state.billing_enforcement_enabled {
        return Ok(());
    }

    // tenant id -> (deployment targets, runner targets)
    let mut usage = std::collections::HashMap::new();

    let sections = [
        ("aggregate_placement", cfg.aggregate_placement.as_ref()),
        ("projection_placement", cfg.projection_placement.as_ref()),
        ("runner_placement", cfg.runner_placement.as_ref()),
    ];
    for (kind_name, section) in sections {
        let Some(section) = section else { continue };
        let counts_as_runner = kind_name == "runner_placement";
        for p in &section.placements {
            if p.targets.is_empty() {
                return Err(format!("{kind_name} tenant {} has no targets", p.tenant_id));
            }
            let has_blank_target = p.targets.iter().any(|t| t.trim().is_empty());
            if has_blank_target {
                return Err(format!(
                    "{kind_name} tenant {} has empty target",
                    p.tenant_id
                ));
            }

            let (deployments, runners) = usage.entry(p.tenant_id).or_insert((0, 0));
            let slot = if counts_as_runner { runners } else { deployments };
            *slot += p.targets.len();
        }
    }

    for (tenant_id, (deployments, runners)) in usage {
        let entitlements = state.billing.get_for_tenant(tenant_id).entitlements;
        if deployments > entitlements.max_deployments as usize {
            return Err(format!(
                "tenant {} exceeds max_deployments limit ({} > {})",
                tenant_id, deployments, entitlements.max_deployments
            ));
        }
        if runners > entitlements.max_runners as usize {
            return Err(format!(
                "tenant {} exceeds max_runners limit ({} > {})",
                tenant_id, runners, entitlements.max_runners
            ));
        }
    }

    Ok(())
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::billing::{BillingStore, Plan, SubscriptionStatus, TenantBillingState};
    use crate::placement::{PlacementFile, PlacementKind, TenantPlacement};

    /// Assembles an `AppState` for unit tests around the supplied billing store.
    ///
    /// Billing enforcement is switched on, the mock billing provider is used,
    /// no config sources or fleet services are registered and document storage
    /// is absent.
    fn mock_state(billing: BillingStore) -> AppState {
        let manifest_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let placement_file = std::env::temp_dir().join("placement.json");

        AppState {
            prometheus: crate::get_test_prometheus_handle(),
            auth: crate::AuthConfig {
                hs256_secret: Some(b"secret".to_vec()),
            },
            jobs: JobStore::default(),
            audit: AuditStore::default(),
            tenant_locks: TenantLocks::default(),
            config_locks: ConfigLocks::default(),
            http: reqwest::Client::new(),
            placement: crate::placement::PlacementStore::new(placement_file),
            billing,
            billing_provider: Arc::new(crate::billing::MockProvider),
            billing_enforcement_enabled: true,
            config: crate::config_registry::ConfigRegistry::new(None, None),
            fleet_services: vec![],
            swarm: crate::swarm::SwarmStore::new(manifest_dir.join("swarm/dev.json")),
            docs: None,
        }
    }

    /// Wraps a single `target` for `tenant_id` in a one-entry placement section.
    fn one_target(tenant_id: Uuid, target: &str) -> Option<PlacementKind> {
        Some(PlacementKind {
            placements: vec![TenantPlacement {
                tenant_id,
                targets: vec![target.to_string()],
            }],
        })
    }

    /// Placement validation must honour the tenant's plan limits: the same
    /// file is rejected before a Pro subscription is recorded and accepted after.
    #[test]
    fn test_validate_placement_limits() {
        let tenant_id = Uuid::new_v4();
        let store_path =
            std::env::temp_dir().join(format!("billing-unit-{}.json", Uuid::new_v4()));
        let billing = BillingStore::new(store_path.clone());
        let state = mock_state(billing.clone());

        // Step 1 - no billing record yet (free plan: max_deployments=1, max_runners=1).
        // One aggregate plus one projection target makes two deployments, so this is rejected.
        let two_deployments = PlacementFile {
            revision: Some("v1".to_string()),
            aggregate_placement: one_target(tenant_id, "a1"),
            projection_placement: one_target(tenant_id, "p1"),
            runner_placement: one_target(tenant_id, "r1"),
        };
        let rejection = validate_placement_semantic(&state, &two_deployments).unwrap_err();
        assert!(rejection.contains("exceeds max_deployments limit"));

        // Step 2 - dropping the projection leaves a single deployment, which fits.
        let one_deployment = PlacementFile {
            revision: Some("v2".to_string()),
            aggregate_placement: one_target(tenant_id, "a1"),
            projection_placement: None,
            runner_placement: one_target(tenant_id, "r1"),
        };
        validate_placement_semantic(&state, &one_deployment).unwrap();

        // Step 3 - record an active Pro subscription (max_deployments=10, max_runners=10).
        let pro_subscription = TenantBillingState {
            provider: "mock".to_string(),
            provider_customer_id: None,
            provider_subscription_id: None,
            provider_checkout_session_id: None,
            status: Some(SubscriptionStatus::Active),
            plan: Some(Plan::Pro),
            current_period_end: None,
            cancel_at_period_end: None,
            processed_webhook_event_ids: vec![],
            updated_at: 100,
        };
        billing
            .update_tenant_state(tenant_id, pro_subscription)
            .unwrap();

        // The file that was rejected in step 1 is now within limits.
        validate_placement_semantic(&state, &two_deployments).unwrap();

        // Best-effort cleanup of the temporary billing file.
        let _ = std::fs::remove_file(store_path);
    }
}
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
mod admin;
|
||||
mod audit;
|
||||
mod auth;
|
||||
pub mod billing;
|
||||
mod build_info;
|
||||
pub mod config_registry;
|
||||
mod config_schemas;
|
||||
mod deployments;
|
||||
mod documents;
|
||||
mod drift;
|
||||
mod fleet;
|
||||
mod job_engine;
|
||||
mod jobs;
|
||||
mod placement;
|
||||
pub mod s3_docs;
|
||||
mod swarm;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use audit::AuditStore;
|
||||
pub use auth::{AuthConfig, Principal};
|
||||
use axum::{
|
||||
@@ -20,8 +28,10 @@ use axum::{
|
||||
routing::get,
|
||||
};
|
||||
pub use build_info::{BuildInfo, extract_build_info};
|
||||
pub use config_registry::{ConfigDomain, ConfigRegistry};
|
||||
pub use deployments::{DeployAnnotationArgs, GrafanaAnnotation, build_grafana_deploy_annotation};
|
||||
pub use fleet::FleetService;
|
||||
pub use job_engine::ConfigLocks;
|
||||
pub use job_engine::TenantLocks;
|
||||
pub use jobs::JobStore;
|
||||
use metrics_exporter_prometheus::PrometheusHandle;
|
||||
@@ -40,10 +50,16 @@ pub struct AppState {
|
||||
pub jobs: JobStore,
|
||||
pub audit: AuditStore,
|
||||
pub tenant_locks: TenantLocks,
|
||||
pub config_locks: ConfigLocks,
|
||||
pub http: reqwest::Client,
|
||||
pub placement: PlacementStore,
|
||||
pub billing: billing::BillingStore,
|
||||
pub billing_provider: Arc<dyn billing::BillingProvider>,
|
||||
pub billing_enforcement_enabled: bool,
|
||||
pub config: ConfigRegistry,
|
||||
pub fleet_services: Vec<FleetService>,
|
||||
pub swarm: SwarmStore,
|
||||
pub docs: Option<s3_docs::DocsStore>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -93,13 +109,18 @@ pub fn build_app(state: AppState) -> Router {
|
||||
},
|
||||
);
|
||||
|
||||
let admin =
|
||||
admin::admin_router().layer(from_fn_with_state(state.clone(), auth::auth_middleware));
|
||||
let admin = admin::admin_router()
|
||||
.merge(documents::router())
|
||||
.layer(from_fn_with_state(state.clone(), auth::auth_middleware));
|
||||
|
||||
Router::new()
|
||||
.route("/health", get(health))
|
||||
.route("/ready", get(ready))
|
||||
.route("/metrics", get(metrics))
|
||||
.route(
|
||||
"/admin/v1/billing/webhooks/{provider}",
|
||||
axum::routing::post(billing::webhook),
|
||||
)
|
||||
.nest("/admin/v1", admin)
|
||||
.with_state(state)
|
||||
.layer(trace)
|
||||
@@ -167,25 +188,46 @@ async fn request_id_middleware(mut req: Request<axum::body::Body>, next: Next) -
|
||||
res
|
||||
}
|
||||
|
||||
/// Process-wide cache for the Prometheus handle handed out to unit tests.
#[cfg(test)]
static TEST_PROMETHEUS_HANDLE: std::sync::OnceLock<PrometheusHandle> = std::sync::OnceLock::new();

/// Returns a Prometheus handle that every unit test in this crate can share.
///
/// The first call tries to install a global recorder and caches its handle;
/// later calls return a clone of the cached handle.
///
/// If installation fails (for example because another test already installed
/// a recorder), a stand-alone recorder is built instead. Its handle is not
/// connected to the global recorder, which is acceptable for tests that do
/// not assert on metric values.
#[cfg(test)]
pub(crate) fn get_test_prometheus_handle() -> PrometheusHandle {
    TEST_PROMETHEUS_HANDLE
        .get_or_init(|| {
            metrics_exporter_prometheus::PrometheusBuilder::new()
                .install_recorder()
                .unwrap_or_else(|_| {
                    // `build_recorder` creates only the recorder. The previous
                    // `build()` call also constructed the exporter, which is
                    // documented to require a Tokio runtime, while this helper
                    // is also reached from plain `#[test]` functions.
                    metrics_exporter_prometheus::PrometheusBuilder::new()
                        .build_recorder()
                        .handle()
                })
        })
        .clone()
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config_registry::{FileSource, FixedSource};
|
||||
use crate::jobs::JobStatus;
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{Request, StatusCode, header},
|
||||
};
|
||||
use jsonwebtoken::{EncodingKey, Header, encode};
|
||||
use metrics_exporter_prometheus::PrometheusBuilder;
|
||||
use serde::Serialize;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::Arc;
|
||||
use tower::ServiceExt;
|
||||
use uuid::Uuid;
|
||||
|
||||
static HANDLE: OnceLock<PrometheusHandle> = OnceLock::new();
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct TestClaims {
|
||||
sub: String,
|
||||
@@ -199,15 +241,10 @@ mod tests {
|
||||
}
|
||||
|
||||
fn test_app_with_fleet(fleet_services: Vec<FleetService>) -> Router {
|
||||
let handle = HANDLE
|
||||
.get_or_init(|| {
|
||||
PrometheusBuilder::new()
|
||||
.install_recorder()
|
||||
.expect("failed to install prometheus recorder")
|
||||
})
|
||||
.clone();
|
||||
let handle = get_test_prometheus_handle();
|
||||
|
||||
let placement_path = temp_placement_file();
|
||||
let root = repo_root();
|
||||
|
||||
build_app(AppState {
|
||||
prometheus: handle,
|
||||
@@ -217,10 +254,23 @@ mod tests {
|
||||
jobs: JobStore::default(),
|
||||
audit: AuditStore::default(),
|
||||
tenant_locks: TenantLocks::default(),
|
||||
config_locks: ConfigLocks::default(),
|
||||
http: reqwest::Client::new(),
|
||||
placement: PlacementStore::new(placement_path),
|
||||
billing: crate::billing::BillingStore::new(
|
||||
std::env::temp_dir().join(format!("billing-test-{}.json", Uuid::new_v4())),
|
||||
),
|
||||
billing_provider: Arc::new(crate::billing::MockProvider),
|
||||
billing_enforcement_enabled: true,
|
||||
config: ConfigRegistry::new(
|
||||
Some(Arc::new(FileSource::new(
|
||||
root.join("config/routing/dev.json"),
|
||||
))),
|
||||
Some(Arc::new(FixedSource::new(b"{}".to_vec()))),
|
||||
),
|
||||
fleet_services,
|
||||
swarm: SwarmStore::new(repo_root().join("swarm/dev.json")),
|
||||
docs: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -234,14 +284,14 @@ mod tests {
|
||||
|
||||
fn temp_placement_file() -> PathBuf {
|
||||
let root = repo_root();
|
||||
let src = root.join("placement/dev.json");
|
||||
let src = root.join("config/placement/dev.json");
|
||||
let mut dst = std::env::temp_dir();
|
||||
dst.push(format!(
|
||||
"cloudlysis-control-placement-{}-{}.json",
|
||||
std::process::id(),
|
||||
Uuid::new_v4()
|
||||
));
|
||||
let raw = fs::read_to_string(src).expect("missing placement/dev.json");
|
||||
let raw = fs::read_to_string(src).expect("missing config/placement/dev.json");
|
||||
fs::write(&dst, raw).expect("failed to write temp placement file");
|
||||
dst
|
||||
}
|
||||
@@ -689,4 +739,467 @@ mod tests {
|
||||
&serde_json::json!(["preflight", "drain", "update_placement", "reload", "verify"])
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn billing_returns_not_configured_by_default() {
|
||||
let token = make_token(&["control:read"]);
|
||||
let tenant_id = Uuid::new_v4();
|
||||
let res = test_app()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
|
||||
.header(header::AUTHORIZATION, format!("Bearer {token}"))
|
||||
.header("x-tenant-id", tenant_id.to_string())
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
|
||||
.await
|
||||
.unwrap();
|
||||
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
assert_eq!(v.get("configured").unwrap(), &serde_json::json!(false));
|
||||
assert_eq!(
|
||||
v.get("entitlements")
|
||||
.unwrap()
|
||||
.get("max_deployments")
|
||||
.unwrap(),
|
||||
&serde_json::json!(1)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn billing_returns_configured_state() {
|
||||
let token = make_token(&["control:read"]);
|
||||
let tenant_id = Uuid::new_v4();
|
||||
|
||||
let handle = get_test_prometheus_handle();
|
||||
|
||||
let billing_path =
|
||||
std::env::temp_dir().join(format!("billing-test-cfg-{}.json", Uuid::new_v4()));
|
||||
let billing = crate::billing::BillingStore::new(billing_path.clone());
|
||||
|
||||
billing
|
||||
.update_tenant_state(
|
||||
tenant_id,
|
||||
crate::billing::TenantBillingState {
|
||||
provider: "stripe".to_string(),
|
||||
provider_customer_id: Some("cus_123".to_string()),
|
||||
provider_subscription_id: Some("sub_123".to_string()),
|
||||
provider_checkout_session_id: None,
|
||||
status: Some(crate::billing::SubscriptionStatus::Active),
|
||||
plan: Some(crate::billing::Plan::Pro),
|
||||
current_period_end: Some("2026-04-30T00:00:00Z".to_string()),
|
||||
cancel_at_period_end: Some(false),
|
||||
processed_webhook_event_ids: vec![],
|
||||
updated_at: 1234567890,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let root = repo_root();
|
||||
let app = build_app(AppState {
|
||||
prometheus: handle,
|
||||
auth: AuthConfig {
|
||||
hs256_secret: Some(b"test_secret".to_vec()),
|
||||
},
|
||||
jobs: JobStore::default(),
|
||||
audit: AuditStore::default(),
|
||||
tenant_locks: TenantLocks::default(),
|
||||
config_locks: ConfigLocks::default(),
|
||||
http: reqwest::Client::new(),
|
||||
placement: PlacementStore::new(temp_placement_file()),
|
||||
billing,
|
||||
billing_provider: Arc::new(crate::billing::MockProvider),
|
||||
billing_enforcement_enabled: true,
|
||||
config: ConfigRegistry::new(
|
||||
Some(Arc::new(FileSource::new(
|
||||
root.join("config/routing/dev.json"),
|
||||
))),
|
||||
Some(Arc::new(FixedSource::new(b"{}".to_vec()))),
|
||||
),
|
||||
fleet_services: vec![],
|
||||
swarm: SwarmStore::new(repo_root().join("swarm/dev.json")),
|
||||
docs: None,
|
||||
});
|
||||
|
||||
let res = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
|
||||
.header(header::AUTHORIZATION, format!("Bearer {token}"))
|
||||
.header("x-tenant-id", tenant_id.to_string())
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
|
||||
.await
|
||||
.unwrap();
|
||||
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
assert_eq!(v.get("configured").unwrap(), &serde_json::json!(true));
|
||||
assert_eq!(v.get("plan").unwrap(), &serde_json::json!("pro"));
|
||||
assert_eq!(
|
||||
v.get("entitlements")
|
||||
.unwrap()
|
||||
.get("max_deployments")
|
||||
.unwrap(),
|
||||
&serde_json::json!(10)
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_file(billing_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn checkout_returns_mock_url() {
|
||||
let token = make_token(&["control:write"]);
|
||||
let tenant_id = Uuid::new_v4();
|
||||
let body = serde_json::json!({
|
||||
"plan": "pro",
|
||||
"return_path": "/custom-return"
|
||||
});
|
||||
|
||||
let res = test_app()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri(format!("/admin/v1/tenants/{tenant_id}/billing/checkout"))
|
||||
.method("POST")
|
||||
.header(header::AUTHORIZATION, format!("Bearer {token}"))
|
||||
.header("x-tenant-id", tenant_id.to_string())
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(body.to_string()))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
|
||||
.await
|
||||
.unwrap();
|
||||
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
assert_eq!(
|
||||
v.get("url").unwrap(),
|
||||
&serde_json::json!(format!("https://mock.stripe.com/checkout/{}", tenant_id))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn checkout_fails_if_already_active() {
|
||||
let token = make_token(&["control:write"]);
|
||||
let tenant_id = Uuid::new_v4();
|
||||
|
||||
// Setup app with active subscription
|
||||
let billing_path =
|
||||
std::env::temp_dir().join(format!("billing-test-active-{}.json", Uuid::new_v4()));
|
||||
let billing = crate::billing::BillingStore::new(billing_path.clone());
|
||||
billing
|
||||
.update_tenant_state(
|
||||
tenant_id,
|
||||
crate::billing::TenantBillingState {
|
||||
provider: "mock".to_string(),
|
||||
provider_customer_id: None,
|
||||
provider_subscription_id: None,
|
||||
provider_checkout_session_id: None,
|
||||
status: Some(crate::billing::SubscriptionStatus::Active),
|
||||
plan: Some(crate::billing::Plan::Pro),
|
||||
current_period_end: None,
|
||||
cancel_at_period_end: None,
|
||||
processed_webhook_event_ids: vec![],
|
||||
updated_at: 0,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let handle = get_test_prometheus_handle();
|
||||
let root = repo_root();
|
||||
let app = build_app(AppState {
|
||||
prometheus: handle,
|
||||
auth: AuthConfig {
|
||||
hs256_secret: Some(b"test_secret".to_vec()),
|
||||
},
|
||||
jobs: JobStore::default(),
|
||||
audit: AuditStore::default(),
|
||||
tenant_locks: TenantLocks::default(),
|
||||
config_locks: ConfigLocks::default(),
|
||||
http: reqwest::Client::new(),
|
||||
placement: PlacementStore::new(temp_placement_file()),
|
||||
billing,
|
||||
billing_provider: Arc::new(crate::billing::MockProvider),
|
||||
billing_enforcement_enabled: true,
|
||||
config: ConfigRegistry::new(
|
||||
Some(Arc::new(FileSource::new(
|
||||
root.join("config/routing/dev.json"),
|
||||
))),
|
||||
Some(Arc::new(FixedSource::new(b"{}".to_vec()))),
|
||||
),
|
||||
fleet_services: vec![],
|
||||
swarm: SwarmStore::new(repo_root().join("swarm/dev.json")),
|
||||
docs: None,
|
||||
});
|
||||
|
||||
let body = serde_json::json!({ "plan": "pro" });
|
||||
let res = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri(format!("/admin/v1/tenants/{tenant_id}/billing/checkout"))
|
||||
.method("POST")
|
||||
.header(header::AUTHORIZATION, format!("Bearer {token}"))
|
||||
.header("x-tenant-id", tenant_id.to_string())
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(body.to_string()))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::CONFLICT);
|
||||
let _ = std::fs::remove_file(billing_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn portal_returns_mock_url() {
|
||||
let token = make_token(&["control:write"]);
|
||||
let tenant_id = Uuid::new_v4();
|
||||
|
||||
let res = test_app()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri(format!("/admin/v1/tenants/{tenant_id}/billing/portal"))
|
||||
.method("POST")
|
||||
.header(header::AUTHORIZATION, format!("Bearer {token}"))
|
||||
.header("x-tenant-id", tenant_id.to_string())
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = axum::body::to_bytes(res.into_body(), 1024 * 1024)
|
||||
.await
|
||||
.unwrap();
|
||||
let v: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
assert_eq!(
|
||||
v.get("url").unwrap(),
|
||||
&serde_json::json!(format!("https://mock.stripe.com/portal/{}", tenant_id))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_updates_state_idempotently() {
|
||||
let tenant_id = Uuid::new_v4();
|
||||
let event_id = "evt_123".to_string();
|
||||
|
||||
let app = test_app();
|
||||
|
||||
let event = crate::billing::BillingEvent::SubscriptionCreated {
|
||||
tenant_id,
|
||||
event_id: event_id.clone(),
|
||||
provider_customer_id: "cus_123".to_string(),
|
||||
provider_subscription_id: "sub_123".to_string(),
|
||||
status: crate::billing::SubscriptionStatus::Active,
|
||||
plan: crate::billing::Plan::Pro,
|
||||
current_period_end: "2026-04-30T00:00:00Z".to_string(),
|
||||
ts_ms: 1000,
|
||||
};
|
||||
|
||||
let body = serde_json::to_string(&event).unwrap();
|
||||
|
||||
// 1. Send webhook
|
||||
let res = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/admin/v1/billing/webhooks/mock")
|
||||
.method("POST")
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(body.clone()))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
// 2. Verify state
|
||||
let token = make_token(&["control:read"]);
|
||||
let res = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
|
||||
.header(header::AUTHORIZATION, format!("Bearer {token}"))
|
||||
.header("x-tenant-id", tenant_id.to_string())
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let body_bytes = axum::body::to_bytes(res.into_body(), 1024 * 1024)
|
||||
.await
|
||||
.unwrap();
|
||||
let v: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
assert_eq!(v.get("configured").unwrap(), &serde_json::json!(true));
|
||||
assert_eq!(v.get("plan").unwrap(), &serde_json::json!("pro"));
|
||||
|
||||
// 3. Send same webhook again (idempotency)
|
||||
let res = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/admin/v1/billing/webhooks/mock")
|
||||
.method("POST")
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(body))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_ignores_stale_events() {
|
||||
let tenant_id = Uuid::new_v4();
|
||||
let app = test_app();
|
||||
|
||||
// 1. Send recent event (ts=2000)
|
||||
let event1 = crate::billing::BillingEvent::SubscriptionUpdated {
|
||||
tenant_id,
|
||||
event_id: "evt_new".to_string(),
|
||||
status: crate::billing::SubscriptionStatus::Active,
|
||||
plan: crate::billing::Plan::Enterprise,
|
||||
current_period_end: "2026-05-30T00:00:00Z".to_string(),
|
||||
cancel_at_period_end: false,
|
||||
ts_ms: 2000,
|
||||
};
|
||||
|
||||
app.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/admin/v1/billing/webhooks/mock")
|
||||
.method("POST")
|
||||
.body(Body::from(serde_json::to_string(&event1).unwrap()))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// 2. Send stale event (ts=1000)
|
||||
let event2 = crate::billing::BillingEvent::SubscriptionUpdated {
|
||||
tenant_id,
|
||||
event_id: "evt_old".to_string(),
|
||||
status: crate::billing::SubscriptionStatus::PastDue,
|
||||
plan: crate::billing::Plan::Pro,
|
||||
current_period_end: "2026-04-30T00:00:00Z".to_string(),
|
||||
cancel_at_period_end: false,
|
||||
ts_ms: 1000,
|
||||
};
|
||||
|
||||
app.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/admin/v1/billing/webhooks/mock")
|
||||
.method("POST")
|
||||
.body(Body::from(serde_json::to_string(&event2).unwrap()))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// 3. Verify state is still Enterprise
|
||||
let token = make_token(&["control:read"]);
|
||||
let res = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri(format!("/admin/v1/tenants/{tenant_id}/billing"))
|
||||
.header(header::AUTHORIZATION, format!("Bearer {token}"))
|
||||
.header("x-tenant-id", tenant_id.to_string())
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let body_bytes = axum::body::to_bytes(res.into_body(), 1024 * 1024)
|
||||
.await
|
||||
.unwrap();
|
||||
let v: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
|
||||
assert_eq!(v.get("plan").unwrap(), &serde_json::json!("enterprise"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn s3_docs_requires_pro_plan() {
|
||||
let token = make_token(&["control:read", "control:write"]);
|
||||
let tenant_id = Uuid::new_v4();
|
||||
let app = test_app();
|
||||
|
||||
// 1. Try to list docs (Free plan by default)
|
||||
let res = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri(format!("/admin/v1/tenants/{tenant_id}/docs"))
|
||||
.header(header::AUTHORIZATION, format!("Bearer {token}"))
|
||||
.header("x-tenant-id", tenant_id.to_string())
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(res.status(), StatusCode::PAYMENT_REQUIRED);
|
||||
|
||||
// 2. Update to Pro plan via webhook
|
||||
let event = crate::billing::BillingEvent::SubscriptionCreated {
|
||||
tenant_id,
|
||||
event_id: "evt_pro".to_string(),
|
||||
provider_customer_id: "cus_pro".to_string(),
|
||||
provider_subscription_id: "sub_pro".to_string(),
|
||||
status: crate::billing::SubscriptionStatus::Active,
|
||||
plan: crate::billing::Plan::Pro,
|
||||
current_period_end: "2099-01-01T00:00:00Z".to_string(),
|
||||
ts_ms: 2000,
|
||||
};
|
||||
app.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri("/admin/v1/billing/webhooks/mock")
|
||||
.method("POST")
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&event).unwrap()))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// 3. Try to list docs again (Should fail with 503 if S3 not configured in tests, or 200/502 if it is)
|
||||
// In test_app(), docs is None by default.
|
||||
let res = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.uri(format!("/admin/v1/tenants/{tenant_id}/docs"))
|
||||
.header(header::AUTHORIZATION, format!("Bearer {token}"))
|
||||
.header("x-tenant-id", tenant_id.to_string())
|
||||
.body(Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Since docs is None in test_app(), it returns SERVICE_UNAVAILABLE (503) AFTER passing the entitlement check.
|
||||
// If it was still PAYMENT_REQUIRED, it would return 402.
|
||||
assert_eq!(res.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use clap::Parser;
|
||||
use metrics_exporter_prometheus::PrometheusBuilder;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@@ -33,16 +35,32 @@ async fn main() {
|
||||
.build()
|
||||
.expect("failed to build http client");
|
||||
|
||||
let placement_path = std::env::var("CONTROL_PLACEMENT_PATH")
|
||||
let placement_path: PathBuf = std::env::var("CONTROL_PLACEMENT_PATH")
|
||||
.ok()
|
||||
.unwrap_or_else(|| "placement/dev.json".to_string())
|
||||
.unwrap_or_else(|| "config/placement/dev.json".to_string())
|
||||
.into();
|
||||
|
||||
let swarm_path = std::env::var("CONTROL_SWARM_STATE_PATH")
|
||||
let billing_path: PathBuf = std::env::var("CONTROL_BILLING_STATE_PATH")
|
||||
.ok()
|
||||
.unwrap_or_else(|| "swarm/dev.json".to_string())
|
||||
.unwrap_or_else(|| "billing/dev.json".to_string())
|
||||
.into();
|
||||
|
||||
let routing_path: PathBuf = std::env::var("CONTROL_ROUTING_PATH")
|
||||
.ok()
|
||||
.unwrap_or_else(|| "config/routing/dev.json".to_string())
|
||||
.into();
|
||||
|
||||
let swarm_mode = std::env::var("CONTROL_SWARM_MODE").ok();
|
||||
let swarm = if swarm_mode.as_deref() == Some("docker") {
|
||||
api::SwarmStore::new_docker_cli()
|
||||
} else {
|
||||
let swarm_path: PathBuf = std::env::var("CONTROL_SWARM_STATE_PATH")
|
||||
.ok()
|
||||
.unwrap_or_else(|| "swarm/dev.json".to_string())
|
||||
.into();
|
||||
api::SwarmStore::new(swarm_path)
|
||||
};
|
||||
|
||||
let self_url = std::env::var("CONTROL_SELF_URL")
|
||||
.ok()
|
||||
.unwrap_or_else(|| "http://127.0.0.1:8080".to_string());
|
||||
@@ -55,7 +73,70 @@ async fn main() {
|
||||
fleet_services.extend(parse_fleet_services(&spec));
|
||||
}
|
||||
|
||||
let app = api::build_app(api::AppState {
|
||||
let docs_cfg =
|
||||
api::s3_docs::DocsConfig::from_env().expect("missing S3 document storage configuration");
|
||||
let docs = api::s3_docs::DocsStore::new(docs_cfg)
|
||||
.await
|
||||
.expect("failed to initialize S3 document storage client");
|
||||
|
||||
let config = {
|
||||
let routing = if let (Ok(nats_url), Ok(bucket), Ok(key)) = (
|
||||
std::env::var("CONTROL_ROUTING_NATS_URL"),
|
||||
std::env::var("CONTROL_ROUTING_NATS_BUCKET"),
|
||||
std::env::var("CONTROL_ROUTING_NATS_KEY"),
|
||||
) {
|
||||
Some(Arc::new(
|
||||
api::config_registry::NatsKvSource::connect(nats_url, bucket, key)
|
||||
.await
|
||||
.expect("failed to connect to routing config nats kv"),
|
||||
) as Arc<dyn api::config_registry::ConfigSource>)
|
||||
} else {
|
||||
Some(
|
||||
Arc::new(api::config_registry::FileSource::new(routing_path))
|
||||
as Arc<dyn api::config_registry::ConfigSource>,
|
||||
)
|
||||
};
|
||||
|
||||
let placement = if let (Ok(nats_url), Ok(bucket), Ok(key)) = (
|
||||
std::env::var("CONTROL_PLACEMENT_NATS_URL"),
|
||||
std::env::var("CONTROL_PLACEMENT_NATS_BUCKET"),
|
||||
std::env::var("CONTROL_PLACEMENT_NATS_KEY"),
|
||||
) {
|
||||
Some(Arc::new(
|
||||
api::config_registry::NatsKvSource::connect(nats_url, bucket, key)
|
||||
.await
|
||||
.expect("failed to connect to placement config nats kv"),
|
||||
) as Arc<dyn api::config_registry::ConfigSource>)
|
||||
} else {
|
||||
Some(Arc::new(api::config_registry::FileSource::new(
|
||||
placement_path.clone(),
|
||||
))
|
||||
as Arc<dyn api::config_registry::ConfigSource>)
|
||||
};
|
||||
|
||||
api::ConfigRegistry::new(routing, placement)
|
||||
};
|
||||
|
||||
let billing_provider: Arc<dyn api::billing::BillingProvider> =
|
||||
match std::env::var("CONTROL_BILLING_PROVIDER").as_deref() {
|
||||
Ok("stripe") => {
|
||||
let secret_key = std::env::var("CONTROL_STRIPE_SECRET_KEY")
|
||||
.expect("CONTROL_STRIPE_SECRET_KEY required for stripe provider");
|
||||
let price_pro = std::env::var("CONTROL_STRIPE_PRICE_ID_PRO")
|
||||
.expect("CONTROL_STRIPE_PRICE_ID_PRO required for stripe provider");
|
||||
let price_enterprise = std::env::var("CONTROL_STRIPE_PRICE_ID_ENTERPRISE")
|
||||
.expect("CONTROL_STRIPE_PRICE_ID_ENTERPRISE required for stripe provider");
|
||||
|
||||
Arc::new(api::billing::StripeProvider {
|
||||
secret_key,
|
||||
price_pro,
|
||||
price_enterprise,
|
||||
})
|
||||
}
|
||||
_ => Arc::new(api::billing::MockProvider),
|
||||
};
|
||||
|
||||
let state = api::AppState {
|
||||
prometheus: recorder,
|
||||
auth: api::AuthConfig {
|
||||
hs256_secret: std::env::var("CONTROL_GATEWAY_JWT_HS256_SECRET")
|
||||
@@ -65,11 +146,25 @@ async fn main() {
|
||||
jobs: api::JobStore::default(),
|
||||
audit: api::AuditStore::default(),
|
||||
tenant_locks: api::TenantLocks::default(),
|
||||
config_locks: api::ConfigLocks::default(),
|
||||
http,
|
||||
placement: api::PlacementStore::new(placement_path),
|
||||
billing: api::billing::BillingStore::new(billing_path),
|
||||
billing_provider,
|
||||
billing_enforcement_enabled: std::env::var("CONTROL_BILLING_ENFORCEMENT_ENABLED")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(false),
|
||||
config,
|
||||
fleet_services,
|
||||
swarm: api::SwarmStore::new(swarm_path),
|
||||
});
|
||||
swarm,
|
||||
docs: Some(docs),
|
||||
};
|
||||
|
||||
// Spawn reconciliation loop
|
||||
tokio::spawn(api::billing::run_reconciliation_loop(state.clone()));
|
||||
|
||||
let app = api::build_app(state);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(args.addr)
|
||||
.await
|
||||
|
||||
@@ -157,6 +157,7 @@ impl PlacementStore {
|
||||
&self,
|
||||
tenant_id: Uuid,
|
||||
runner_target: String,
|
||||
max_runners: usize,
|
||||
) -> Result<String, String> {
|
||||
let mut inner = self.inner.write().expect("placement lock poisoned");
|
||||
inner.reload_if_changed();
|
||||
@@ -178,8 +179,17 @@ impl PlacementStore {
|
||||
.iter_mut()
|
||||
.find(|p| p.tenant_id == tenant_id)
|
||||
{
|
||||
// If already at or above limit, and we are adding a NEW target (not replacing), it would fail.
|
||||
// But here update_runner_target REPLACES the target list with a single target for now.
|
||||
// If in the future we want to append, we check targets.len().
|
||||
if 1 > max_runners {
|
||||
return Err(format!("exceeds max_runners limit of {}", max_runners));
|
||||
}
|
||||
existing.targets = vec![runner_target];
|
||||
} else {
|
||||
if 1 > max_runners {
|
||||
return Err(format!("exceeds max_runners limit of {}", max_runners));
|
||||
}
|
||||
runner.placements.push(TenantPlacement {
|
||||
tenant_id,
|
||||
targets: vec![runner_target],
|
||||
|
||||
508
control/api/src/s3_docs.rs
Normal file
508
control/api/src/s3_docs.rs
Normal file
@@ -0,0 +1,508 @@
|
||||
use aws_config::Region;
|
||||
use aws_credential_types::Credentials;
|
||||
use aws_sdk_s3::presigning::PresigningConfig;
|
||||
use aws_sdk_s3::types::BucketCannedAcl;
|
||||
use aws_sdk_s3::{Client, config::Builder as S3ConfigBuilder};
|
||||
use sha2::Digest;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Connection settings for the S3-compatible document store.
///
/// `Debug` is implemented by hand instead of derived so that
/// `secret_access_key` never ends up in logs or panic messages.
#[derive(Clone)]
pub struct DocsConfig {
    /// Endpoint URL of the S3-compatible service.
    pub endpoint: String,
    /// Optional separate public endpoint.
    pub public_endpoint: Option<String>,
    /// Region name passed to the S3 client.
    pub region: String,
    /// Access key id used for authentication.
    pub access_key_id: String,
    /// Secret access key used for authentication; redacted in `Debug` output.
    pub secret_access_key: String,
    /// Whether path-style bucket addressing is requested.
    pub force_path_style: bool,
    /// Marks a plain-HTTP setup (e.g. local MinIO).
    pub insecure: bool,
    /// Bucket names used for documents.
    pub buckets: Vec<String>,
    /// Key prefix under which documents are stored.
    pub prefix: String,
}

impl std::fmt::Debug for DocsConfig {
    /// Formats every field like the derived implementation would, except that
    /// the secret access key is replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DocsConfig")
            .field("endpoint", &self.endpoint)
            .field("public_endpoint", &self.public_endpoint)
            .field("region", &self.region)
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"** redacted **")
            .field("force_path_style", &self.force_path_style)
            .field("insecure", &self.insecure)
            .field("buckets", &self.buckets)
            .field("prefix", &self.prefix)
            .finish()
    }
}
|
||||
|
||||
impl DocsConfig {
|
||||
pub fn from_env() -> Result<Self, String> {
|
||||
fn get(name: &str) -> Option<String> {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
}
|
||||
|
||||
fn get_secret(name: &str, file_name: &str) -> Result<Option<String>, String> {
|
||||
if let Some(path) = get(file_name) {
|
||||
let raw = std::fs::read_to_string(path).map_err(|e| e.to_string())?;
|
||||
let v = raw.trim().to_string();
|
||||
if v.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
return Ok(Some(v));
|
||||
}
|
||||
Ok(get(name))
|
||||
}
|
||||
|
||||
let endpoint = get("CONTROL_S3_ENDPOINT")
|
||||
.or_else(|| get("S3_ENDPOINT"))
|
||||
.ok_or_else(|| "Missing CONTROL_S3_ENDPOINT".to_string())?;
|
||||
let public_endpoint =
|
||||
get("CONTROL_S3_PUBLIC_ENDPOINT").or_else(|| get("S3_PUBLIC_ENDPOINT"));
|
||||
let region = get("CONTROL_S3_REGION")
|
||||
.or_else(|| get("S3_REGION"))
|
||||
.unwrap_or_else(|| "us-east-1".to_string());
|
||||
let access_key_id =
|
||||
get_secret("CONTROL_S3_ACCESS_KEY_ID", "CONTROL_S3_ACCESS_KEY_ID_FILE")?
|
||||
.or_else(|| {
|
||||
get_secret("S3_ACCESS_KEY_ID", "S3_ACCESS_KEY_ID_FILE")
|
||||
.ok()
|
||||
.flatten()
|
||||
})
|
||||
.ok_or_else(|| "Missing CONTROL_S3_ACCESS_KEY_ID".to_string())?;
|
||||
let secret_access_key = get_secret(
|
||||
"CONTROL_S3_SECRET_ACCESS_KEY",
|
||||
"CONTROL_S3_SECRET_ACCESS_KEY_FILE",
|
||||
)?
|
||||
.or_else(|| {
|
||||
get_secret("S3_SECRET_ACCESS_KEY", "S3_SECRET_ACCESS_KEY_FILE")
|
||||
.ok()
|
||||
.flatten()
|
||||
})
|
||||
.ok_or_else(|| "Missing CONTROL_S3_SECRET_ACCESS_KEY".to_string())?;
|
||||
let force_path_style = get("CONTROL_S3_FORCE_PATH_STYLE")
|
||||
.or_else(|| get("S3_FORCE_PATH_STYLE"))
|
||||
.as_deref()
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(true);
|
||||
let insecure = get("CONTROL_S3_INSECURE")
|
||||
.or_else(|| get("S3_INSECURE"))
|
||||
.as_deref()
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
|
||||
let bucket_raw = get("CONTROL_S3_BUCKET_DOCS")
|
||||
.or_else(|| get("S3_BUCKET_DOCS"))
|
||||
.ok_or_else(|| "Missing CONTROL_S3_BUCKET_DOCS".to_string())?;
|
||||
let buckets: Vec<String> = bucket_raw
|
||||
.split(',')
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect();
|
||||
if buckets.is_empty() {
|
||||
return Err("Missing CONTROL_S3_BUCKET_DOCS".to_string());
|
||||
}
|
||||
let prefix = get("CONTROL_S3_PREFIX_DOCS")
|
||||
.or_else(|| get("S3_PREFIX_DOCS"))
|
||||
.unwrap_or_else(|| "docs/".to_string());
|
||||
let prefix = if prefix.ends_with('/') {
|
||||
prefix
|
||||
} else {
|
||||
format!("{prefix}/")
|
||||
};
|
||||
|
||||
// SECURITY: `*_INSECURE=true` is intended for local MinIO setups that use plain HTTP.
|
||||
// We currently do not disable TLS certificate verification for HTTPS endpoints.
|
||||
if insecure && endpoint.trim_start().starts_with("https://") {
|
||||
return Err(
|
||||
"CONTROL_S3_INSECURE=true is not supported with https:// endpoints (TLS verification is not disabled). Use http:// for local MinIO, or set CONTROL_S3_INSECURE=false for production."
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
endpoint,
|
||||
public_endpoint,
|
||||
region,
|
||||
access_key_id,
|
||||
secret_access_key,
|
||||
force_path_style,
|
||||
insecure,
|
||||
buckets,
|
||||
prefix,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DocsStore {
|
||||
cfg: DocsConfig,
|
||||
client: Client,
|
||||
presign_client: Client,
|
||||
}
|
||||
|
||||
impl DocsStore {
|
||||
pub async fn new(cfg: DocsConfig) -> Result<Self, String> {
|
||||
let creds = Credentials::new(
|
||||
cfg.access_key_id.clone(),
|
||||
cfg.secret_access_key.clone(),
|
||||
None,
|
||||
None,
|
||||
"static",
|
||||
);
|
||||
let shared = aws_config::from_env()
|
||||
.region(Region::new(cfg.region.clone()))
|
||||
.credentials_provider(creds.clone())
|
||||
.endpoint_url(cfg.endpoint.clone())
|
||||
.load()
|
||||
.await;
|
||||
|
||||
let s3_conf = S3ConfigBuilder::from(&shared)
|
||||
.force_path_style(cfg.force_path_style)
|
||||
.build();
|
||||
let client = Client::from_conf(s3_conf);
|
||||
|
||||
let presign_endpoint = cfg
|
||||
.public_endpoint
|
||||
.clone()
|
||||
.unwrap_or_else(|| cfg.endpoint.clone());
|
||||
let presign_shared = aws_config::from_env()
|
||||
.region(Region::new(cfg.region.clone()))
|
||||
.credentials_provider(creds)
|
||||
.endpoint_url(presign_endpoint)
|
||||
.load()
|
||||
.await;
|
||||
let presign_conf = S3ConfigBuilder::from(&presign_shared)
|
||||
.force_path_style(cfg.force_path_style)
|
||||
.build();
|
||||
let presign_client = Client::from_conf(presign_conf);
|
||||
|
||||
Ok(Self {
|
||||
cfg,
|
||||
client,
|
||||
presign_client,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn key_for(
|
||||
&self,
|
||||
tenant_id: &str,
|
||||
doc_type: &str,
|
||||
doc_id: &str,
|
||||
filename: &str,
|
||||
) -> Result<String, String> {
|
||||
validate_segment("tenant_id", tenant_id)?;
|
||||
validate_segment("doc_type", doc_type)?;
|
||||
validate_segment("doc_id", doc_id)?;
|
||||
validate_filename(filename)?;
|
||||
Ok(format!(
|
||||
"{}{}/{}/{}/{}",
|
||||
self.cfg.prefix, tenant_id, doc_type, doc_id, filename
|
||||
))
|
||||
}
|
||||
|
||||
pub fn prefix(&self) -> &str {
|
||||
self.cfg.prefix.as_str()
|
||||
}
|
||||
|
||||
pub fn buckets(&self) -> &[String] {
|
||||
self.cfg.buckets.as_slice()
|
||||
}
|
||||
|
||||
fn bucket_for_tenant(&self, tenant_id: &str) -> &str {
|
||||
// Deterministic sharding across buckets. Note: if the bucket list changes, the mapping changes.
|
||||
// For production, set the full planned bucket set up-front (e.g. `-0,-1,-2`) to keep mapping stable.
|
||||
let n = self.cfg.buckets.len();
|
||||
if n == 1 {
|
||||
return self.cfg.buckets[0].as_str();
|
||||
}
|
||||
let mut hasher = sha2::Sha256::new();
|
||||
hasher.update(tenant_id.as_bytes());
|
||||
let digest = hasher.finalize();
|
||||
let mut b = [0u8; 8];
|
||||
b.copy_from_slice(&digest[..8]);
|
||||
let v = u64::from_be_bytes(b);
|
||||
let idx = (v as usize) % n;
|
||||
self.cfg.buckets[idx].as_str()
|
||||
}
|
||||
|
||||
pub fn content_hash_sha256_hex(bytes: &[u8]) -> String {
|
||||
let mut hasher = sha2::Sha256::new();
|
||||
hasher.update(bytes);
|
||||
let digest = hasher.finalize();
|
||||
let mut out = String::with_capacity(digest.len() * 2);
|
||||
for b in digest {
|
||||
use std::fmt::Write;
|
||||
let _ = write!(&mut out, "{:02x}", b);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub async fn put_for_tenant(
|
||||
&self,
|
||||
tenant_id: &str,
|
||||
key: &str,
|
||||
bytes: Vec<u8>,
|
||||
content_type: Option<String>,
|
||||
) -> Result<(), String> {
|
||||
let mut req = self
|
||||
.client
|
||||
.put_object()
|
||||
.bucket(self.bucket_for_tenant(tenant_id))
|
||||
.key(key)
|
||||
.body(aws_sdk_s3::primitives::ByteStream::from(bytes));
|
||||
if let Some(ct) = content_type {
|
||||
req = req.content_type(ct);
|
||||
}
|
||||
req.send().await.map_err(|e| e.to_string())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_bytes_for_tenant(
|
||||
&self,
|
||||
tenant_id: &str,
|
||||
key: &str,
|
||||
) -> Result<(Vec<u8>, Option<String>), String> {
|
||||
let out = self
|
||||
.client
|
||||
.get_object()
|
||||
.bucket(self.bucket_for_tenant(tenant_id))
|
||||
.key(key)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
let ct = out.content_type().map(|s| s.to_string());
|
||||
let bytes = out
|
||||
.body
|
||||
.collect()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?
|
||||
.into_bytes()
|
||||
.to_vec();
|
||||
Ok((bytes, ct))
|
||||
}
|
||||
|
||||
pub async fn delete_for_tenant(&self, tenant_id: &str, key: &str) -> Result<(), String> {
|
||||
self.client
|
||||
.delete_object()
|
||||
.bucket(self.bucket_for_tenant(tenant_id))
|
||||
.key(key)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn list_for_tenant(
|
||||
&self,
|
||||
tenant_id: &str,
|
||||
prefix: &str,
|
||||
) -> Result<Vec<DocObject>, String> {
|
||||
let out = self
|
||||
.client
|
||||
.list_objects_v2()
|
||||
.bucket(self.bucket_for_tenant(tenant_id))
|
||||
.prefix(prefix)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
let mut items = Vec::new();
|
||||
for o in out.contents() {
|
||||
if let Some(key) = o.key() {
|
||||
items.push(DocObject {
|
||||
key: key.to_string(),
|
||||
size: o.size().unwrap_or(0),
|
||||
last_modified: o.last_modified().map(|d| d.to_string()),
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(items)
|
||||
}
|
||||
|
||||
pub async fn ensure_buckets_exist(&self) -> Result<(), String> {
|
||||
for bucket in &self.cfg.buckets {
|
||||
let head = self.client.head_bucket().bucket(bucket).send().await;
|
||||
if head.is_ok() {
|
||||
continue;
|
||||
}
|
||||
self.client
|
||||
.create_bucket()
|
||||
.bucket(bucket)
|
||||
.acl(BucketCannedAcl::Private)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn presign_put_for_tenant(
|
||||
&self,
|
||||
tenant_id: &str,
|
||||
key: &str,
|
||||
content_type: Option<String>,
|
||||
expires: Duration,
|
||||
) -> Result<String, String> {
|
||||
let mut req = self
|
||||
.presign_client
|
||||
.put_object()
|
||||
.bucket(self.bucket_for_tenant(tenant_id))
|
||||
.key(key);
|
||||
if let Some(ct) = content_type {
|
||||
req = req.content_type(ct);
|
||||
}
|
||||
let presigned = req
|
||||
.presigned(PresigningConfig::expires_in(expires).map_err(|e| e.to_string())?)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
Ok(presigned.uri().to_string())
|
||||
}
|
||||
|
||||
pub async fn presign_get_for_tenant(
|
||||
&self,
|
||||
tenant_id: &str,
|
||||
key: &str,
|
||||
expires: Duration,
|
||||
) -> Result<String, String> {
|
||||
let req = self
|
||||
.presign_client
|
||||
.get_object()
|
||||
.bucket(self.bucket_for_tenant(tenant_id))
|
||||
.key(key);
|
||||
let presigned = req
|
||||
.presigned(PresigningConfig::expires_in(expires).map_err(|e| e.to_string())?)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
Ok(presigned.uri().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize)]
|
||||
pub struct DocObject {
|
||||
pub key: String,
|
||||
pub size: i64,
|
||||
pub last_modified: Option<String>,
|
||||
}
|
||||
|
||||
/// Checks that `value` is usable as a single component of an object key.
///
/// A component must be non-empty, at most 128 bytes long, and must not contain
/// `/`, `\` or the sequence `..`. `name` is only used to label the error.
///
/// # Errors
///
/// Returns `"<name> is required"`, `"<name> too long"` or
/// `"<name> contains invalid characters"`, checked in that order.
fn validate_segment(name: &str, value: &str) -> Result<(), String> {
    match value {
        "" => Err(format!("{name} is required")),
        v if v.len() > 128 => Err(format!("{name} too long")),
        v if v.chars().any(|c| matches!(c, '/' | '\\')) => {
            Err(format!("{name} contains invalid characters"))
        }
        v if v.contains("..") => Err(format!("{name} contains invalid characters")),
        _ => Ok(()),
    }
}
|
||||
|
||||
fn validate_filename(value: &str) -> Result<(), String> {
|
||||
validate_segment("filename", value)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the tests in this module that mutate the process environment.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
        LOCK.get_or_init(|| std::sync::Mutex::new(()))
            .lock()
            // A failed assertion in one test poisons the mutex. The guarded value
            // is `()`, so recover the guard instead of failing every later test.
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Environment variables set for the duration of one test.
    ///
    /// The variables are removed on drop, so they are cleaned up even when the
    /// test panics, and the environment lock is held until then.
    struct ScopedEnv {
        names: Vec<&'static str>,
        _guard: std::sync::MutexGuard<'static, ()>,
    }

    impl ScopedEnv {
        fn set(vars: &[(&'static str, &str)]) -> Self {
            let guard = env_lock();
            for &(name, value) in vars {
                // SAFETY: `env_lock` serializes the env-mutating tests in this
                // module. Other threads reading the environment concurrently are
                // not covered by this lock.
                unsafe { std::env::set_var(name, value) };
            }
            Self {
                names: vars.iter().map(|&(name, _)| name).collect(),
                _guard: guard,
            }
        }
    }

    impl Drop for ScopedEnv {
        fn drop(&mut self) {
            for name in &self.names {
                // SAFETY: runs while `_guard` is still held; see `ScopedEnv::set`.
                unsafe { std::env::remove_var(name) };
            }
        }
    }

    /// Configuration for a local MinIO-style endpoint with the given buckets.
    fn sample_config(buckets: &[&str]) -> DocsConfig {
        DocsConfig {
            endpoint: "http://minio:9000".to_string(),
            public_endpoint: None,
            region: "us-east-1".to_string(),
            access_key_id: "x".to_string(),
            secret_access_key: "y".to_string(),
            force_path_style: true,
            insecure: true,
            buckets: buckets.iter().map(|b| b.to_string()).collect(),
            prefix: "docs/".to_string(),
        }
    }

    #[test]
    fn config_from_env_parses_expected_fields() {
        let _env = ScopedEnv::set(&[
            ("CONTROL_S3_ENDPOINT", "http://minio:9000"),
            ("CONTROL_S3_REGION", "us-east-1"),
            ("CONTROL_S3_ACCESS_KEY_ID", "minioadmin"),
            ("CONTROL_S3_SECRET_ACCESS_KEY", "minioadmin"),
            ("CONTROL_S3_BUCKET_DOCS", "cloudlysis-docs"),
            ("CONTROL_S3_PREFIX_DOCS", "docs/"),
            ("CONTROL_S3_FORCE_PATH_STYLE", "true"),
            ("CONTROL_S3_INSECURE", "true"),
        ]);

        let cfg = DocsConfig::from_env().unwrap();
        assert_eq!(cfg.endpoint, "http://minio:9000");
        assert_eq!(cfg.buckets, vec!["cloudlysis-docs".to_string()]);
        assert_eq!(cfg.prefix, "docs/");
        assert!(cfg.force_path_style);
        assert!(cfg.insecure);
    }

    #[test]
    fn config_rejects_insecure_with_https_endpoint() {
        let _env = ScopedEnv::set(&[
            ("CONTROL_S3_ENDPOINT", "https://s3.example.com"),
            ("CONTROL_S3_ACCESS_KEY_ID", "a"),
            ("CONTROL_S3_SECRET_ACCESS_KEY", "b"),
            (
                "CONTROL_S3_BUCKET_DOCS",
                "cloudlysis-docs-0,cloudlysis-docs-1",
            ),
            ("CONTROL_S3_INSECURE", "true"),
        ]);

        let err = DocsConfig::from_env().unwrap_err();
        assert!(
            err.contains("CONTROL_S3_INSECURE=true") && err.contains("https://"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn key_scheme_is_stable() {
        let cfg = sample_config(&["cloudlysis-docs-0", "cloudlysis-docs-1"]);
        let store = DocsStore::new(cfg).await.unwrap();

        let key = store
            .key_for("tenant-a", "deployments", "v1", "bundle.tar.gz")
            .unwrap();
        assert_eq!(key, "docs/tenant-a/deployments/v1/bundle.tar.gz");
    }

    #[tokio::test]
    async fn key_scheme_rejects_invalid_segments() {
        let cfg = sample_config(&["cloudlysis-docs"]);
        let store = DocsStore::new(cfg).await.unwrap();

        assert!(store.key_for("t/a", "x", "y", "z").is_err());
        assert!(store.key_for("t", "x", "../y", "z").is_err());
        assert!(store.key_for("t", "x", "y", "a/b").is_err());
    }
}
|
||||
@@ -28,31 +28,49 @@ pub struct SwarmStateFile {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SwarmStore {
|
||||
path: std::path::PathBuf,
|
||||
inner: SwarmStoreInner,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum SwarmStoreInner {
|
||||
File { path: std::path::PathBuf },
|
||||
DockerCli,
|
||||
}
|
||||
|
||||
impl SwarmStore {
|
||||
pub fn new(path: std::path::PathBuf) -> Self {
|
||||
Self { path }
|
||||
Self {
|
||||
inner: SwarmStoreInner::File { path },
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_docker_cli() -> Self {
|
||||
Self {
|
||||
inner: SwarmStoreInner::DockerCli,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list_services(&self) -> Vec<SwarmService> {
|
||||
self.load().map(|s| s.services).unwrap_or_default()
|
||||
match &self.inner {
|
||||
SwarmStoreInner::File { path } => {
|
||||
load_state(path).map(|s| s.services).unwrap_or_default()
|
||||
}
|
||||
SwarmStoreInner::DockerCli => list_services_docker_cli().unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list_tasks(&self, service_name: &str) -> Vec<SwarmTask> {
|
||||
self.load()
|
||||
.map(|s| {
|
||||
s.tasks
|
||||
.into_iter()
|
||||
.filter(|t| t.service == service_name)
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
fn load(&self) -> Option<SwarmStateFile> {
|
||||
load_state(&self.path)
|
||||
match &self.inner {
|
||||
SwarmStoreInner::File { path } => load_state(path)
|
||||
.map(|s| {
|
||||
s.tasks
|
||||
.into_iter()
|
||||
.filter(|t| t.service == service_name)
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
SwarmStoreInner::DockerCli => list_tasks_docker_cli(service_name).unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,3 +78,120 @@ fn load_state(path: &Path) -> Option<SwarmStateFile> {
|
||||
let raw = fs::read_to_string(path).ok()?;
|
||||
serde_json::from_str(&raw).ok()
|
||||
}
|
||||
|
||||
fn list_services_docker_cli() -> Result<Vec<SwarmService>, String> {
|
||||
let out = std::process::Command::new("docker")
|
||||
.args(["service", "ls", "--format", "{{json .}}"])
|
||||
.output()
|
||||
.map_err(|e| format!("docker exec failed: {e}"))?;
|
||||
if !out.status.success() {
|
||||
return Err(format!(
|
||||
"docker service ls failed: {}",
|
||||
String::from_utf8_lossy(&out.stderr)
|
||||
));
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ServiceRow {
|
||||
#[serde(rename = "Name")]
|
||||
name: String,
|
||||
#[serde(rename = "Image")]
|
||||
image: Option<String>,
|
||||
#[serde(rename = "Mode")]
|
||||
mode: Option<String>,
|
||||
#[serde(rename = "Replicas")]
|
||||
replicas: Option<String>,
|
||||
#[serde(rename = "UpdatedAt")]
|
||||
updated_at: Option<String>,
|
||||
}
|
||||
|
||||
let mut services = Vec::new();
|
||||
for line in String::from_utf8_lossy(&out.stdout).lines() {
|
||||
let line = line.trim();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let row: ServiceRow =
|
||||
serde_json::from_str(line).map_err(|e| format!("invalid json row: {e}"))?;
|
||||
services.push(SwarmService {
|
||||
name: row.name,
|
||||
image: row.image,
|
||||
mode: row.mode,
|
||||
replicas: row.replicas,
|
||||
updated_at: row.updated_at,
|
||||
});
|
||||
}
|
||||
Ok(services)
|
||||
}
|
||||
|
||||
fn list_tasks_docker_cli(service_name: &str) -> Result<Vec<SwarmTask>, String> {
|
||||
let out = std::process::Command::new("docker")
|
||||
.args([
|
||||
"service",
|
||||
"ps",
|
||||
service_name,
|
||||
"--no-trunc",
|
||||
"--format",
|
||||
"{{json .}}",
|
||||
])
|
||||
.output()
|
||||
.map_err(|e| format!("docker exec failed: {e}"))?;
|
||||
if !out.status.success() {
|
||||
return Err(format!(
|
||||
"docker service ps failed: {}",
|
||||
String::from_utf8_lossy(&out.stderr)
|
||||
));
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct TaskRow {
|
||||
#[serde(rename = "ID")]
|
||||
id: String,
|
||||
#[serde(rename = "Name")]
|
||||
name: Option<String>,
|
||||
#[serde(rename = "Node")]
|
||||
node: Option<String>,
|
||||
#[serde(rename = "DesiredState")]
|
||||
desired_state: Option<String>,
|
||||
#[serde(rename = "CurrentState")]
|
||||
current_state: Option<String>,
|
||||
#[serde(rename = "Error")]
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for line in String::from_utf8_lossy(&out.stdout).lines() {
|
||||
let line = line.trim();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let row: TaskRow =
|
||||
serde_json::from_str(line).map_err(|e| format!("invalid json row: {e}"))?;
|
||||
let service = row
|
||||
.name
|
||||
.as_deref()
|
||||
.and_then(|n| n.split_once('.').map(|(svc, _)| svc.to_string()))
|
||||
.unwrap_or_else(|| service_name.to_string());
|
||||
tasks.push(SwarmTask {
|
||||
id: row.id,
|
||||
service,
|
||||
node: row.node,
|
||||
desired_state: row.desired_state,
|
||||
current_state: row.current_state,
|
||||
error: row.error,
|
||||
});
|
||||
}
|
||||
Ok(tasks)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal state document with one service and no tasks deserializes.
    #[test]
    fn state_file_parses() {
        let raw = r#"{"services":[{"name":"a","image":null,"mode":null,"replicas":null,"updated_at":null}],"tasks":[]}"#;
        let parsed = serde_json::from_str::<SwarmStateFile>(raw).unwrap();
        let service_count = parsed.services.len();
        assert_eq!(service_count, 1);
    }
}
|
||||
|
||||
Reference in New Issue
Block a user