use crate::{ AppState, auth::{Principal, has_permission}, }; use async_trait::async_trait; use axum::{ Json, extract::{Extension, Path, State}, http::{HeaderMap, StatusCode}, response::IntoResponse, }; use serde::{Deserialize, Serialize}; use std::time::Duration; use std::{ collections::BTreeMap, fs, path::PathBuf, sync::{Arc, RwLock}, time::SystemTime, }; use thiserror::Error; use uuid::Uuid; const HEADER_TENANT_ID: &str = shared::HEADER_X_TENANT_ID; fn verify_tenant_isolation(headers: &HeaderMap, path_tenant_id: Uuid) -> Result<(), StatusCode> { let header_tenant_id = headers .get(HEADER_TENANT_ID) .and_then(|v| v.to_str().ok()) .ok_or(StatusCode::BAD_REQUEST) .and_then(|s| Uuid::parse_str(s).map_err(|_| StatusCode::BAD_REQUEST))?; if header_tenant_id != path_tenant_id { return Err(StatusCode::FORBIDDEN); } Ok(()) } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum Plan { Free, Pro, Enterprise, } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum SubscriptionStatus { Trialing, Active, PastDue, Paused, Canceled, Incomplete, } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Entitlements { pub max_deployments: u32, pub max_runners: u32, pub s3_docs_enabled: bool, pub support_tier: String, } #[derive(Clone, Debug, Serialize, Deserialize)] pub enum BillingEvent { SubscriptionCreated { tenant_id: Uuid, event_id: String, provider_customer_id: String, provider_subscription_id: String, status: SubscriptionStatus, plan: Plan, current_period_end: String, ts_ms: u64, }, SubscriptionUpdated { tenant_id: Uuid, event_id: String, status: SubscriptionStatus, plan: Plan, current_period_end: String, cancel_at_period_end: bool, ts_ms: u64, }, SubscriptionDeleted { tenant_id: Uuid, event_id: String, ts_ms: u64, }, } impl BillingEvent { pub fn tenant_id(&self) -> Uuid { match self { Self::SubscriptionCreated { tenant_id, .. } => *tenant_id, Self::SubscriptionUpdated { tenant_id, .. } => *tenant_id, Self::SubscriptionDeleted { tenant_id, .. } => *tenant_id, } } pub fn event_id(&self) -> &str { match self { Self::SubscriptionCreated { event_id, .. } => event_id, Self::SubscriptionUpdated { event_id, .. } => event_id, Self::SubscriptionDeleted { event_id, .. } => event_id, } } pub fn ts_ms(&self) -> u64 { match self { Self::SubscriptionCreated { ts_ms, .. } => *ts_ms, Self::SubscriptionUpdated { ts_ms, .. } => *ts_ms, Self::SubscriptionDeleted { ts_ms, .. } => *ts_ms, } } } impl Entitlements { pub fn derive(plan: Option<&Plan>, status: Option<&SubscriptionStatus>) -> Self { let is_active = matches!( status, Some(SubscriptionStatus::Trialing | SubscriptionStatus::Active) ); if !is_active { return Self { max_deployments: 1, max_runners: 1, s3_docs_enabled: false, support_tier: "community".to_string(), }; } match plan.unwrap_or(&Plan::Free) { Plan::Free => Self { max_deployments: 3, max_runners: 1, s3_docs_enabled: false, support_tier: "community".to_string(), }, Plan::Pro => Self { max_deployments: 10, max_runners: 5, s3_docs_enabled: true, support_tier: "standard".to_string(), }, Plan::Enterprise => Self { max_deployments: 1000, max_runners: 50, s3_docs_enabled: true, support_tier: "priority".to_string(), }, } } } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct TenantBillingState { pub provider: String, pub provider_customer_id: Option, pub provider_subscription_id: Option, pub provider_checkout_session_id: Option, pub status: Option, pub plan: Option, pub current_period_end: Option, pub cancel_at_period_end: Option, pub processed_webhook_event_ids: Vec, pub updated_at: u64, } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct BillingStateFile { pub revision: Option, pub tenants: BTreeMap, } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct BillingResponse { pub configured: bool, pub provider: Option, pub plan: Option, pub status: Option, pub current_period_end: Option, pub cancel_at_period_end: Option, pub entitlements: Entitlements, } #[derive(Clone)] pub struct BillingStore { inner: Arc>, } struct Inner { path: PathBuf, last_modified: Option, cached: Option, } impl BillingStore { pub fn new(path: PathBuf) -> Self { Self { inner: Arc::new(RwLock::new(Inner { path, last_modified: None, cached: None, })), } } pub fn get_for_tenant(&self, tenant_id: Uuid) -> BillingResponse { let mut inner = self.inner.write().expect("billing lock poisoned"); inner.reload_if_changed(); if let Some(state) = inner .cached .as_ref() .and_then(|file| file.tenants.get(&tenant_id)) { return BillingResponse { configured: true, provider: Some(state.provider.clone()), plan: state.plan.clone(), status: state.status.clone(), current_period_end: state.current_period_end.clone(), cancel_at_period_end: state.cancel_at_period_end, entitlements: Entitlements::derive(state.plan.as_ref(), state.status.as_ref()), }; } BillingResponse { configured: false, provider: None, plan: None, status: None, current_period_end: None, cancel_at_period_end: None, entitlements: Entitlements::derive(None, None), } } pub fn get_all_tenant_ids(&self) -> Vec { let mut inner = self.inner.write().expect("billing lock poisoned"); inner.reload_if_changed(); inner .cached .as_ref() .map(|f| f.tenants.keys().cloned().collect()) .unwrap_or_default() } pub fn get_subscription_id(&self, tenant_id: Uuid) -> Option { let mut inner = self.inner.write().expect("billing lock poisoned"); inner.reload_if_changed(); inner .cached .as_ref() .and_then(|f| f.tenants.get(&tenant_id)) .and_then(|s| s.provider_subscription_id.clone()) } pub fn apply_event(&self, event: BillingEvent) -> Result<(), String> { let mut inner = self.inner.write().expect("billing lock poisoned"); inner.reload_if_changed(); let mut file = inner.cached.clone().unwrap_or(BillingStateFile { revision: Some("dev".to_string()), tenants: BTreeMap::new(), }); let tenant_id = event.tenant_id(); let event_id = event.event_id().to_string(); let ts_ms = event.ts_ms(); let state = file.tenants.entry(tenant_id).or_insert(TenantBillingState { provider: "unknown".to_string(), // Will be updated by Created event provider_customer_id: None, provider_subscription_id: None, provider_checkout_session_id: None, status: None, plan: None, current_period_end: None, cancel_at_period_end: None, processed_webhook_event_ids: vec![], updated_at: 0, }); // Deduplication if state.processed_webhook_event_ids.contains(&event_id) { return Ok(()); } // Monotonicity check if state.updated_at > ts_ms { state.processed_webhook_event_ids.push(event_id); state.processed_webhook_event_ids.truncate(50); inner.save(file)?; return Ok(()); } match event { BillingEvent::SubscriptionCreated { provider_customer_id, provider_subscription_id, status, plan, current_period_end, .. } => { state.provider_customer_id = Some(provider_customer_id); state.provider_subscription_id = Some(provider_subscription_id); state.status = Some(status); state.plan = Some(plan); state.current_period_end = Some(current_period_end); } BillingEvent::SubscriptionUpdated { status, plan, current_period_end, cancel_at_period_end, .. } => { state.status = Some(status); state.plan = Some(plan); state.current_period_end = Some(current_period_end); state.cancel_at_period_end = Some(cancel_at_period_end); } BillingEvent::SubscriptionDeleted { .. } => { state.status = Some(SubscriptionStatus::Canceled); } } state.updated_at = ts_ms; state.processed_webhook_event_ids.push(event_id); state.processed_webhook_event_ids.truncate(50); inner.save(file)?; Ok(()) } #[cfg(test)] pub fn update_tenant_state( &self, tenant_id: Uuid, state: TenantBillingState, ) -> Result { let mut inner = self.inner.write().expect("billing lock poisoned"); inner.reload_if_changed(); let mut file = inner.cached.clone().unwrap_or(BillingStateFile { revision: Some("dev".to_string()), tenants: BTreeMap::new(), }); file.tenants.insert(tenant_id, state); inner.save(file) } } impl Inner { fn save(&mut self, mut file: BillingStateFile) -> Result { let revision = format!("rev-{}", Uuid::new_v4()); file.revision = Some(revision.clone()); let raw = serde_json::to_string_pretty(&file).map_err(|e| e.to_string())?; let tmp = self.path.with_extension("json.tmp"); fs::write(&tmp, raw).map_err(|e| e.to_string())?; fs::rename(&tmp, &self.path).map_err(|e| e.to_string())?; self.last_modified = None; self.cached = Some(file); Ok(revision) } fn reload_if_changed(&mut self) { let meta = fs::metadata(&self.path).ok(); let modified = meta.and_then(|m| m.modified().ok()); if self.cached.is_some() && modified.is_some() && modified == self.last_modified { return; } self.last_modified = modified; let p = &self.path; self.cached = fs::read_to_string(p) .ok() .and_then(|raw| serde_json::from_str(&raw).ok()); } } pub async fn get_billing( State(state): State, Path(tenant_id): Path, headers: HeaderMap, Extension(principal): Extension, ) -> impl IntoResponse { if !has_permission(&principal, "control:read") { return StatusCode::FORBIDDEN.into_response(); } if let Err(status) = verify_tenant_isolation(&headers, tenant_id) { return status.into_response(); } let resp = state.billing.get_for_tenant(tenant_id); (StatusCode::OK, Json(resp)).into_response() } #[derive(Debug, Deserialize)] pub struct CheckoutRequest { pub plan: Plan, pub return_path: Option, } pub async fn checkout( State(state): State, Path(tenant_id): Path, headers: HeaderMap, Extension(principal): Extension, Json(body): Json, ) -> impl IntoResponse { if !has_permission(&principal, "control:write") { return StatusCode::FORBIDDEN.into_response(); } if let Err(status) = verify_tenant_isolation(&headers, tenant_id) { return status.into_response(); } // Check if subscription already exists and is active/trialing let current = state.billing.get_for_tenant(tenant_id); if current.configured && matches!( current.status, Some(SubscriptionStatus::Active | SubscriptionStatus::Trialing) ) { return ( StatusCode::CONFLICT, Json(serde_json::json!({ "error": "tenant already has an active subscription" })), ) .into_response(); } // Construct full return URL // TODO: Validate return_path against ALLOWED_RETURN_ORIGINS if provided let return_url = body.return_path.unwrap_or_else(|| "/billing".to_string()); match state .billing_provider .create_checkout_session(tenant_id, body.plan, return_url) .await { Ok(url) => (StatusCode::OK, Json(serde_json::json!({ "url": url }))).into_response(), Err(e) => { let err_msg = e.to_string(); ( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": err_msg })), ) .into_response() } } } pub async fn portal( State(state): State, Path(tenant_id): Path, headers: HeaderMap, Extension(principal): Extension, ) -> impl IntoResponse { if !has_permission(&principal, "control:write") { return StatusCode::FORBIDDEN.into_response(); } if let Err(status) = verify_tenant_isolation(&headers, tenant_id) { return status.into_response(); } let return_url = "/billing".to_string(); match state .billing_provider .create_portal_session(tenant_id, return_url) .await { Ok(url) => (StatusCode::OK, Json(serde_json::json!({ "url": url }))).into_response(), Err(e) => { let err_msg = e.to_string(); ( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": err_msg })), ) .into_response() } } } pub async fn webhook( State(state): State, Path(_provider): Path, headers: HeaderMap, body: axum::body::Bytes, ) -> impl IntoResponse { // Note: We don't require auth here as this is a public endpoint called by the provider. // Security is handled via signature verification in the provider trait. match state.billing_provider.verify_webhook(&body, &headers).await { Ok(event) => { metrics::counter!("billing_webhook_requests_total", "status" => "success").increment(1); if let Err(e) = state.billing.apply_event(event) { tracing::error!(error = %e, "failed to apply billing event from webhook"); return ( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e })), ) .into_response(); } StatusCode::OK.into_response() } Err(e) => { metrics::counter!("billing_webhook_requests_total", "status" => "error").increment(1); ( StatusCode::BAD_REQUEST, Json(serde_json::json!({ "error": e.to_string() })), ) .into_response() } } } pub async fn run_reconciliation_loop(state: AppState) { let interval_secs = std::env::var("CONTROL_BILLING_RECONCILE_INTERVAL_SECS") .ok() .and_then(|s| s.parse().ok()) .unwrap_or(3600); tracing::info!(interval_secs, "starting billing reconciliation loop"); loop { tokio::time::sleep(Duration::from_secs(interval_secs)).await; tracing::info!("starting billing reconciliation run"); reconcile_once(&state).await; // Update tenant status gauges // Note: This is an expensive operation if there are many tenants, // but for reconciliation it's fine once per hour. update_billing_gauges(&state); } } pub async fn reconcile_once(state: &AppState) { let start = std::time::Instant::now(); let tenant_ids = state.billing.get_all_tenant_ids(); let mut success = 0; let mut error = 0; let mut skipped = 0; for tenant_id in tenant_ids { let sub_id = state.billing.get_subscription_id(tenant_id); if let Some(subscription_id) = sub_id { match state .billing_provider .fetch_subscription(tenant_id, &subscription_id) .await { Ok(event) => { if let Err(e) = state.billing.apply_event(event) { tracing::error!(?tenant_id, error = %e, "failed to apply reconciled billing event"); error += 1; } else { success += 1; } } Err(e) => { tracing::error!(?tenant_id, error = %e, "failed to fetch subscription for reconciliation"); error += 1; } } } else { skipped += 1; } } let elapsed = start.elapsed(); metrics::counter!("billing_reconciliation_runs_total", "result" => "done").increment(1); metrics::histogram!("billing_reconciliation_duration_ms").record(elapsed.as_millis() as f64); tracing::info!( success, error, skipped, duration_ms = elapsed.as_millis(), "billing reconciliation run complete" ); } fn update_billing_gauges(state: &AppState) { let tenant_ids = state.billing.get_all_tenant_ids(); let mut counts: BTreeMap<(String, String), u64> = BTreeMap::new(); for tenant_id in tenant_ids { let resp = state.billing.get_for_tenant(tenant_id); let plan = match resp.plan { Some(Plan::Free) => "free", Some(Plan::Pro) => "pro", Some(Plan::Enterprise) => "enterprise", None => "none", } .to_string(); let status = match resp.status { Some(SubscriptionStatus::Active) => "active", Some(SubscriptionStatus::Trialing) => "trialing", Some(SubscriptionStatus::PastDue) => "past_due", Some(SubscriptionStatus::Paused) => "paused", Some(SubscriptionStatus::Canceled) => "canceled", Some(SubscriptionStatus::Incomplete) => "incomplete", None => "none", } .to_string(); *counts.entry((plan, status)).or_insert(0) += 1; } for ((plan, status), count) in counts { metrics::gauge!("billing_tenant_status_count", "plan" => plan, "status" => status) .set(count as f64); } } #[derive(Debug, Error)] pub enum BillingError { #[error("provider error: {0}")] Provider(String), #[error("invalid configuration: {0}")] Config(String), } #[async_trait] pub trait BillingProvider: Send + Sync { async fn create_checkout_session( &self, tenant_id: Uuid, plan: Plan, return_url: String, ) -> Result; async fn create_portal_session( &self, tenant_id: Uuid, return_url: String, ) -> Result; async fn verify_webhook( &self, payload: &[u8], headers: &HeaderMap, ) -> Result; async fn fetch_subscription( &self, tenant_id: Uuid, subscription_id: &str, ) -> Result; } pub struct StripeProvider { pub secret_key: String, pub price_pro: String, pub price_enterprise: String, } #[async_trait] impl BillingProvider for StripeProvider { async fn create_checkout_session( &self, tenant_id: Uuid, plan: Plan, _return_url: String, ) -> Result { let _price = match plan { Plan::Pro => &self.price_pro, Plan::Enterprise => &self.price_enterprise, Plan::Free => { return Err(BillingError::Config( "Free plan has no checkout".to_string(), )); } }; // TODO: Actually call Stripe API // For now, returning a simulated Stripe checkout URL Ok(format!( "https://checkout.stripe.com/pay/cs_test_{}?tenant_id={}", Uuid::new_v4(), tenant_id )) } async fn create_portal_session( &self, tenant_id: Uuid, _return_url: String, ) -> Result { // TODO: Actually call Stripe API Ok(format!( "https://billing.stripe.com/p/session/ps_test_{}?tenant_id={}", Uuid::new_v4(), tenant_id )) } async fn verify_webhook( &self, _payload: &[u8], _headers: &HeaderMap, ) -> Result { // TODO: Implement real Stripe signature verification Err(BillingError::Provider("Not implemented".to_string())) } async fn fetch_subscription( &self, _tenant_id: Uuid, _subscription_id: &str, ) -> Result { // TODO: Actually call Stripe API with timeout // let client = reqwest::Client::builder().timeout(Duration::from_secs(10)).build()... Err(BillingError::Provider("Not implemented".to_string())) } } pub struct MockProvider; #[async_trait] impl BillingProvider for MockProvider { async fn create_checkout_session( &self, tenant_id: Uuid, _plan: Plan, _return_url: String, ) -> Result { Ok(format!("https://mock.stripe.com/checkout/{}", tenant_id)) } async fn create_portal_session( &self, tenant_id: Uuid, _return_url: String, ) -> Result { Ok(format!("https://mock.stripe.com/portal/{}", tenant_id)) } async fn verify_webhook( &self, payload: &[u8], _headers: &HeaderMap, ) -> Result { // Mock implementation: just parse the payload as a BillingEvent serde_json::from_slice(payload).map_err(|e| BillingError::Provider(e.to_string())) } async fn fetch_subscription( &self, tenant_id: Uuid, _subscription_id: &str, ) -> Result { // Mock implementation: return a SubscriptionUpdated event with current state // In a real mock we might want to store expectations, but for now we just return something plausible. Ok(BillingEvent::SubscriptionUpdated { tenant_id, event_id: format!("reconcile-{}", Uuid::new_v4()), status: SubscriptionStatus::Active, plan: Plan::Pro, current_period_end: "2099-12-31T23:59:59Z".to_string(), cancel_at_period_end: false, ts_ms: SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap() .as_millis() as u64, }) } } impl MockProvider { pub fn get_checkout_url(tenant: Uuid) -> String { format!("https://mock.stripe.com/checkout/{}", tenant) } } #[cfg(test)] mod tests { use super::*; use std::env::temp_dir; #[test] fn test_entitlement_derivation() { let e = Entitlements::derive(Some(&Plan::Free), Some(&SubscriptionStatus::PastDue)); assert_eq!(e.max_deployments, 1); let e = Entitlements::derive(Some(&Plan::Pro), Some(&SubscriptionStatus::Active)); assert_eq!(e.max_deployments, 10); assert!(e.s3_docs_enabled); let e = Entitlements::derive(Some(&Plan::Enterprise), Some(&SubscriptionStatus::Trialing)); assert_eq!(e.max_deployments, 1000); } #[test] fn test_billing_state_roundtrip() { let mut path = temp_dir(); path.push(format!("billing-{}.json", Uuid::new_v4())); let store = BillingStore::new(path.clone()); let tenant_id = Uuid::new_v4(); let resp = store.get_for_tenant(tenant_id); assert!(!resp.configured); assert_eq!(resp.entitlements.max_deployments, 1); let state = TenantBillingState { provider: "mock".to_string(), provider_customer_id: None, provider_subscription_id: None, provider_checkout_session_id: None, status: Some(SubscriptionStatus::Active), plan: Some(Plan::Pro), current_period_end: None, cancel_at_period_end: Some(false), processed_webhook_event_ids: vec![], updated_at: 0, }; store.update_tenant_state(tenant_id, state).unwrap(); let resp2 = store.get_for_tenant(tenant_id); assert!(resp2.configured); assert_eq!(resp2.provider.as_deref(), Some("mock")); assert_eq!(resp2.plan, Some(Plan::Pro)); assert_eq!(resp2.entitlements.max_deployments, 10); let _ = fs::remove_file(path); } #[tokio::test] async fn test_reconciliation_corrects_state() { let mut path = temp_dir(); path.push(format!("billing-reconcile-{}.json", Uuid::new_v4())); let store = BillingStore::new(path.clone()); let tenant_id = Uuid::new_v4(); // 1. Initial state: PastDue store .update_tenant_state( tenant_id, TenantBillingState { provider: "mock".to_string(), provider_customer_id: Some("cus_1".to_string()), provider_subscription_id: Some("sub_1".to_string()), provider_checkout_session_id: None, status: Some(SubscriptionStatus::PastDue), plan: Some(Plan::Pro), current_period_end: None, cancel_at_period_end: Some(false), processed_webhook_event_ids: vec![], updated_at: 100, }, ) .unwrap(); let state = AppState { prometheus: crate::get_test_prometheus_handle(), auth: crate::AuthConfig { hs256_secret: None }, jobs: crate::jobs::JobStore::default(), audit: crate::AuditStore::default(), tenant_locks: crate::job_engine::TenantLocks::default(), config_locks: crate::job_engine::ConfigLocks::default(), http: reqwest::Client::new(), placement: crate::placement::PlacementStore::new(temp_dir().join("placement.json")), billing: store.clone(), billing_provider: Arc::new(MockProvider), billing_enforcement_enabled: true, config: crate::config_registry::ConfigRegistry::new(None, None), fleet_services: vec![], swarm: crate::swarm::SwarmStore::new(temp_dir().join("swarm.json")), docs: None, }; // 2. Run reconciliation. MockProvider returns Active status. reconcile_once(&state).await; // 3. Verify state is now Active let resp = store.get_for_tenant(tenant_id); assert_eq!(resp.status, Some(SubscriptionStatus::Active)); let _ = fs::remove_file(path); } }