905 lines
28 KiB
Rust
905 lines
28 KiB
Rust
use crate::{
|
|
AppState,
|
|
auth::{Principal, has_permission},
|
|
};
|
|
use async_trait::async_trait;
|
|
use axum::{
|
|
Json,
|
|
extract::{Extension, Path, State},
|
|
http::{HeaderMap, StatusCode},
|
|
response::IntoResponse,
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::time::Duration;
|
|
use std::{
|
|
collections::BTreeMap,
|
|
fs,
|
|
path::PathBuf,
|
|
sync::{Arc, RwLock},
|
|
time::SystemTime,
|
|
};
|
|
use thiserror::Error;
|
|
use uuid::Uuid;
|
|
|
|
/// Request header carrying the caller's tenant id; `verify_tenant_isolation` compares it
/// with the tenant id taken from the URL path.
const HEADER_TENANT_ID: &str = shared::HEADER_X_TENANT_ID;
|
|
|
|
fn verify_tenant_isolation(headers: &HeaderMap, path_tenant_id: Uuid) -> Result<(), StatusCode> {
|
|
let header_tenant_id = headers
|
|
.get(HEADER_TENANT_ID)
|
|
.and_then(|v| v.to_str().ok())
|
|
.ok_or(StatusCode::BAD_REQUEST)
|
|
.and_then(|s| Uuid::parse_str(s).map_err(|_| StatusCode::BAD_REQUEST))?;
|
|
|
|
if header_tenant_id != path_tenant_id {
|
|
return Err(StatusCode::FORBIDDEN);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Subscription plan tier.
///
/// Serialized in snake_case: `"free"`, `"pro"`, `"enterprise"`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Plan {
    /// No paid plan; `StripeProvider` refuses to create a checkout for it.
    Free,
    Pro,
    Enterprise,
}
|
|
|
|
/// Lifecycle status of a tenant's subscription as reported by the billing provider.
///
/// Serialized in snake_case (e.g. `PastDue` <-> `"past_due"`). `Entitlements::derive`
/// grants plan entitlements only for `Trialing` and `Active`; every other status gets
/// the restricted fallback entitlements.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SubscriptionStatus {
    Trialing,
    Active,
    PastDue,
    Paused,
    Canceled,
    Incomplete,
}
|
|
|
|
/// Resource limits and features a tenant is allowed to use, computed by
/// `Entitlements::derive` from the tenant's plan and subscription status.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Entitlements {
    /// Maximum number of deployments the tenant may have.
    pub max_deployments: u32,
    /// Maximum number of runners the tenant may have.
    pub max_runners: u32,
    /// Whether the S3-backed docs feature is enabled.
    pub s3_docs_enabled: bool,
    /// Support level label; `derive` produces "community", "standard" or "priority".
    pub support_tier: String,
}
|
|
|
|
/// Provider-agnostic subscription event.
///
/// Produced by `BillingProvider::verify_webhook` (webhooks) and
/// `BillingProvider::fetch_subscription` (reconciliation), and applied to stored state by
/// `BillingStore::apply_event`. Uses serde's default externally tagged representation, e.g.
/// `{"SubscriptionDeleted": {"tenant_id": "...", "event_id": "...", "ts_ms": 0}}`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum BillingEvent {
    /// A subscription was created for the tenant.
    SubscriptionCreated {
        tenant_id: Uuid,
        /// Provider event id; used for per-tenant de-duplication.
        event_id: String,
        provider_customer_id: String,
        provider_subscription_id: String,
        status: SubscriptionStatus,
        plan: Plan,
        /// Stored verbatim; presumably an RFC 3339 timestamp (the mock uses one) — confirm per provider.
        current_period_end: String,
        /// Event time in milliseconds; events older than the tenant's last applied event are skipped.
        ts_ms: u64,
    },
    /// The tenant's existing subscription changed.
    SubscriptionUpdated {
        tenant_id: Uuid,
        /// Provider event id; used for per-tenant de-duplication.
        event_id: String,
        status: SubscriptionStatus,
        plan: Plan,
        /// Stored verbatim; presumably an RFC 3339 timestamp — confirm per provider.
        current_period_end: String,
        cancel_at_period_end: bool,
        /// Event time in milliseconds.
        ts_ms: u64,
    },
    /// The tenant's subscription was removed; applying it marks the status `Canceled`.
    SubscriptionDeleted {
        tenant_id: Uuid,
        /// Provider event id; used for per-tenant de-duplication.
        event_id: String,
        /// Event time in milliseconds.
        ts_ms: u64,
    },
}
|
|
|
|
impl BillingEvent {
|
|
pub fn tenant_id(&self) -> Uuid {
|
|
match self {
|
|
Self::SubscriptionCreated { tenant_id, .. } => *tenant_id,
|
|
Self::SubscriptionUpdated { tenant_id, .. } => *tenant_id,
|
|
Self::SubscriptionDeleted { tenant_id, .. } => *tenant_id,
|
|
}
|
|
}
|
|
|
|
pub fn event_id(&self) -> &str {
|
|
match self {
|
|
Self::SubscriptionCreated { event_id, .. } => event_id,
|
|
Self::SubscriptionUpdated { event_id, .. } => event_id,
|
|
Self::SubscriptionDeleted { event_id, .. } => event_id,
|
|
}
|
|
}
|
|
|
|
pub fn ts_ms(&self) -> u64 {
|
|
match self {
|
|
Self::SubscriptionCreated { ts_ms, .. } => *ts_ms,
|
|
Self::SubscriptionUpdated { ts_ms, .. } => *ts_ms,
|
|
Self::SubscriptionDeleted { ts_ms, .. } => *ts_ms,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Entitlements {
|
|
pub fn derive(plan: Option<&Plan>, status: Option<&SubscriptionStatus>) -> Self {
|
|
let is_active = matches!(
|
|
status,
|
|
Some(SubscriptionStatus::Trialing | SubscriptionStatus::Active)
|
|
);
|
|
|
|
if !is_active {
|
|
return Self {
|
|
max_deployments: 1,
|
|
max_runners: 1,
|
|
s3_docs_enabled: false,
|
|
support_tier: "community".to_string(),
|
|
};
|
|
}
|
|
|
|
match plan.unwrap_or(&Plan::Free) {
|
|
Plan::Free => Self {
|
|
max_deployments: 3,
|
|
max_runners: 1,
|
|
s3_docs_enabled: false,
|
|
support_tier: "community".to_string(),
|
|
},
|
|
Plan::Pro => Self {
|
|
max_deployments: 10,
|
|
max_runners: 5,
|
|
s3_docs_enabled: true,
|
|
support_tier: "standard".to_string(),
|
|
},
|
|
Plan::Enterprise => Self {
|
|
max_deployments: 1000,
|
|
max_runners: 50,
|
|
s3_docs_enabled: true,
|
|
support_tier: "priority".to_string(),
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Persisted billing record for one tenant inside `BillingStateFile`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TenantBillingState {
    /// Billing provider name; records first created by `BillingStore::apply_event` hold the
    /// placeholder "unknown" because events do not carry a provider name.
    pub provider: String,
    pub provider_customer_id: Option<String>,
    pub provider_subscription_id: Option<String>,
    /// Not written by any code visible in this module.
    pub provider_checkout_session_id: Option<String>,
    pub status: Option<SubscriptionStatus>,
    pub plan: Option<Plan>,
    pub current_period_end: Option<String>,
    pub cancel_at_period_end: Option<bool>,
    /// Bounded list of event ids already handled, used by `apply_event` to ignore redeliveries.
    pub processed_webhook_event_ids: Vec<String>,
    /// `ts_ms` of the last applied event (0 for a fresh record); older events are skipped.
    pub updated_at: u64,
}
|
|
|
|
/// On-disk JSON document holding billing state for all tenants.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BillingStateFile {
    /// Set to a fresh "rev-<uuid>" on every save.
    pub revision: Option<String>,
    /// Billing record per tenant id.
    pub tenants: BTreeMap<Uuid, TenantBillingState>,
}
|
|
|
|
/// API view of a tenant's billing state, returned by `get_billing`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BillingResponse {
    /// `false` when the store has no record for the tenant; the optional fields are then `None`.
    pub configured: bool,
    pub provider: Option<String>,
    pub plan: Option<Plan>,
    pub status: Option<SubscriptionStatus>,
    pub current_period_end: Option<String>,
    pub cancel_at_period_end: Option<bool>,
    /// Always present; derived from `plan` and `status`.
    pub entitlements: Entitlements,
}
|
|
|
|
/// File-backed store of per-tenant billing state.
///
/// Cloning is cheap: clones share the same `Arc<RwLock<Inner>>` and therefore the same cache.
#[derive(Clone)]
pub struct BillingStore {
    inner: Arc<RwLock<Inner>>,
}
|
|
|
|
/// Lock-protected state behind `BillingStore`.
struct Inner {
    /// Location of the JSON state file.
    path: PathBuf,
    /// File mtime observed at the last load; `None` forces a reload on next access.
    last_modified: Option<SystemTime>,
    /// Parsed file contents, or `None` when nothing has been loaded.
    cached: Option<BillingStateFile>,
}
|
|
|
|
impl BillingStore {
|
|
pub fn new(path: PathBuf) -> Self {
|
|
Self {
|
|
inner: Arc::new(RwLock::new(Inner {
|
|
path,
|
|
last_modified: None,
|
|
cached: None,
|
|
})),
|
|
}
|
|
}
|
|
|
|
pub fn get_for_tenant(&self, tenant_id: Uuid) -> BillingResponse {
|
|
let mut inner = self.inner.write().expect("billing lock poisoned");
|
|
inner.reload_if_changed();
|
|
|
|
if let Some(state) = inner
|
|
.cached
|
|
.as_ref()
|
|
.and_then(|file| file.tenants.get(&tenant_id))
|
|
{
|
|
return BillingResponse {
|
|
configured: true,
|
|
provider: Some(state.provider.clone()),
|
|
plan: state.plan.clone(),
|
|
status: state.status.clone(),
|
|
current_period_end: state.current_period_end.clone(),
|
|
cancel_at_period_end: state.cancel_at_period_end,
|
|
entitlements: Entitlements::derive(state.plan.as_ref(), state.status.as_ref()),
|
|
};
|
|
}
|
|
|
|
BillingResponse {
|
|
configured: false,
|
|
provider: None,
|
|
plan: None,
|
|
status: None,
|
|
current_period_end: None,
|
|
cancel_at_period_end: None,
|
|
entitlements: Entitlements::derive(None, None),
|
|
}
|
|
}
|
|
|
|
pub fn get_all_tenant_ids(&self) -> Vec<Uuid> {
|
|
let mut inner = self.inner.write().expect("billing lock poisoned");
|
|
inner.reload_if_changed();
|
|
|
|
inner
|
|
.cached
|
|
.as_ref()
|
|
.map(|f| f.tenants.keys().cloned().collect())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
pub fn get_subscription_id(&self, tenant_id: Uuid) -> Option<String> {
|
|
let mut inner = self.inner.write().expect("billing lock poisoned");
|
|
inner.reload_if_changed();
|
|
|
|
inner
|
|
.cached
|
|
.as_ref()
|
|
.and_then(|f| f.tenants.get(&tenant_id))
|
|
.and_then(|s| s.provider_subscription_id.clone())
|
|
}
|
|
|
|
pub fn apply_event(&self, event: BillingEvent) -> Result<(), String> {
|
|
let mut inner = self.inner.write().expect("billing lock poisoned");
|
|
inner.reload_if_changed();
|
|
|
|
let mut file = inner.cached.clone().unwrap_or(BillingStateFile {
|
|
revision: Some("dev".to_string()),
|
|
tenants: BTreeMap::new(),
|
|
});
|
|
|
|
let tenant_id = event.tenant_id();
|
|
let event_id = event.event_id().to_string();
|
|
let ts_ms = event.ts_ms();
|
|
|
|
let state = file.tenants.entry(tenant_id).or_insert(TenantBillingState {
|
|
provider: "unknown".to_string(), // Will be updated by Created event
|
|
provider_customer_id: None,
|
|
provider_subscription_id: None,
|
|
provider_checkout_session_id: None,
|
|
status: None,
|
|
plan: None,
|
|
current_period_end: None,
|
|
cancel_at_period_end: None,
|
|
processed_webhook_event_ids: vec![],
|
|
updated_at: 0,
|
|
});
|
|
|
|
// Deduplication
|
|
if state.processed_webhook_event_ids.contains(&event_id) {
|
|
return Ok(());
|
|
}
|
|
|
|
// Monotonicity check
|
|
if state.updated_at > ts_ms {
|
|
state.processed_webhook_event_ids.push(event_id);
|
|
state.processed_webhook_event_ids.truncate(50);
|
|
inner.save(file)?;
|
|
return Ok(());
|
|
}
|
|
|
|
match event {
|
|
BillingEvent::SubscriptionCreated {
|
|
provider_customer_id,
|
|
provider_subscription_id,
|
|
status,
|
|
plan,
|
|
current_period_end,
|
|
..
|
|
} => {
|
|
state.provider_customer_id = Some(provider_customer_id);
|
|
state.provider_subscription_id = Some(provider_subscription_id);
|
|
state.status = Some(status);
|
|
state.plan = Some(plan);
|
|
state.current_period_end = Some(current_period_end);
|
|
}
|
|
BillingEvent::SubscriptionUpdated {
|
|
status,
|
|
plan,
|
|
current_period_end,
|
|
cancel_at_period_end,
|
|
..
|
|
} => {
|
|
state.status = Some(status);
|
|
state.plan = Some(plan);
|
|
state.current_period_end = Some(current_period_end);
|
|
state.cancel_at_period_end = Some(cancel_at_period_end);
|
|
}
|
|
BillingEvent::SubscriptionDeleted { .. } => {
|
|
state.status = Some(SubscriptionStatus::Canceled);
|
|
}
|
|
}
|
|
|
|
state.updated_at = ts_ms;
|
|
state.processed_webhook_event_ids.push(event_id);
|
|
state.processed_webhook_event_ids.truncate(50);
|
|
|
|
inner.save(file)?;
|
|
Ok(())
|
|
}
|
|
|
|
#[cfg(test)]
|
|
pub fn update_tenant_state(
|
|
&self,
|
|
tenant_id: Uuid,
|
|
state: TenantBillingState,
|
|
) -> Result<String, String> {
|
|
let mut inner = self.inner.write().expect("billing lock poisoned");
|
|
inner.reload_if_changed();
|
|
|
|
let mut file = inner.cached.clone().unwrap_or(BillingStateFile {
|
|
revision: Some("dev".to_string()),
|
|
tenants: BTreeMap::new(),
|
|
});
|
|
|
|
file.tenants.insert(tenant_id, state);
|
|
inner.save(file)
|
|
}
|
|
}
|
|
|
|
impl Inner {
|
|
fn save(&mut self, mut file: BillingStateFile) -> Result<String, String> {
|
|
let revision = format!("rev-{}", Uuid::new_v4());
|
|
file.revision = Some(revision.clone());
|
|
|
|
let raw = serde_json::to_string_pretty(&file).map_err(|e| e.to_string())?;
|
|
let tmp = self.path.with_extension("json.tmp");
|
|
fs::write(&tmp, raw).map_err(|e| e.to_string())?;
|
|
fs::rename(&tmp, &self.path).map_err(|e| e.to_string())?;
|
|
|
|
self.last_modified = None;
|
|
self.cached = Some(file);
|
|
|
|
Ok(revision)
|
|
}
|
|
|
|
fn reload_if_changed(&mut self) {
|
|
let meta = fs::metadata(&self.path).ok();
|
|
let modified = meta.and_then(|m| m.modified().ok());
|
|
|
|
if self.cached.is_some() && modified.is_some() && modified == self.last_modified {
|
|
return;
|
|
}
|
|
|
|
self.last_modified = modified;
|
|
let p = &self.path;
|
|
self.cached = fs::read_to_string(p)
|
|
.ok()
|
|
.and_then(|raw| serde_json::from_str(&raw).ok());
|
|
}
|
|
}
|
|
|
|
pub async fn get_billing(
|
|
State(state): State<AppState>,
|
|
Path(tenant_id): Path<Uuid>,
|
|
headers: HeaderMap,
|
|
Extension(principal): Extension<Principal>,
|
|
) -> impl IntoResponse {
|
|
if !has_permission(&principal, "control:read") {
|
|
return StatusCode::FORBIDDEN.into_response();
|
|
}
|
|
|
|
if let Err(status) = verify_tenant_isolation(&headers, tenant_id) {
|
|
return status.into_response();
|
|
}
|
|
|
|
let resp = state.billing.get_for_tenant(tenant_id);
|
|
(StatusCode::OK, Json(resp)).into_response()
|
|
}
|
|
|
|
/// JSON body accepted by the `checkout` handler.
#[derive(Debug, Deserialize)]
pub struct CheckoutRequest {
    /// Plan the tenant wants to subscribe to.
    pub plan: Plan,
    /// Where to send the user after checkout; `checkout` defaults it to "/billing".
    pub return_path: Option<String>,
}
|
|
|
|
pub async fn checkout(
|
|
State(state): State<AppState>,
|
|
Path(tenant_id): Path<Uuid>,
|
|
headers: HeaderMap,
|
|
Extension(principal): Extension<Principal>,
|
|
Json(body): Json<CheckoutRequest>,
|
|
) -> impl IntoResponse {
|
|
if !has_permission(&principal, "control:write") {
|
|
return StatusCode::FORBIDDEN.into_response();
|
|
}
|
|
|
|
if let Err(status) = verify_tenant_isolation(&headers, tenant_id) {
|
|
return status.into_response();
|
|
}
|
|
|
|
// Check if subscription already exists and is active/trialing
|
|
let current = state.billing.get_for_tenant(tenant_id);
|
|
if current.configured
|
|
&& matches!(
|
|
current.status,
|
|
Some(SubscriptionStatus::Active | SubscriptionStatus::Trialing)
|
|
)
|
|
{
|
|
return (
|
|
StatusCode::CONFLICT,
|
|
Json(serde_json::json!({ "error": "tenant already has an active subscription" })),
|
|
)
|
|
.into_response();
|
|
}
|
|
|
|
// Construct full return URL
|
|
// TODO: Validate return_path against ALLOWED_RETURN_ORIGINS if provided
|
|
let return_url = body.return_path.unwrap_or_else(|| "/billing".to_string());
|
|
|
|
match state
|
|
.billing_provider
|
|
.create_checkout_session(tenant_id, body.plan, return_url)
|
|
.await
|
|
{
|
|
Ok(url) => (StatusCode::OK, Json(serde_json::json!({ "url": url }))).into_response(),
|
|
Err(e) => {
|
|
let err_msg = e.to_string();
|
|
(
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(serde_json::json!({ "error": err_msg })),
|
|
)
|
|
.into_response()
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn portal(
|
|
State(state): State<AppState>,
|
|
Path(tenant_id): Path<Uuid>,
|
|
headers: HeaderMap,
|
|
Extension(principal): Extension<Principal>,
|
|
) -> impl IntoResponse {
|
|
if !has_permission(&principal, "control:write") {
|
|
return StatusCode::FORBIDDEN.into_response();
|
|
}
|
|
|
|
if let Err(status) = verify_tenant_isolation(&headers, tenant_id) {
|
|
return status.into_response();
|
|
}
|
|
|
|
let return_url = "/billing".to_string();
|
|
match state
|
|
.billing_provider
|
|
.create_portal_session(tenant_id, return_url)
|
|
.await
|
|
{
|
|
Ok(url) => (StatusCode::OK, Json(serde_json::json!({ "url": url }))).into_response(),
|
|
Err(e) => {
|
|
let err_msg = e.to_string();
|
|
(
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(serde_json::json!({ "error": err_msg })),
|
|
)
|
|
.into_response()
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn webhook(
|
|
State(state): State<AppState>,
|
|
Path(_provider): Path<String>,
|
|
headers: HeaderMap,
|
|
body: axum::body::Bytes,
|
|
) -> impl IntoResponse {
|
|
// Note: We don't require auth here as this is a public endpoint called by the provider.
|
|
// Security is handled via signature verification in the provider trait.
|
|
|
|
match state.billing_provider.verify_webhook(&body, &headers).await {
|
|
Ok(event) => {
|
|
metrics::counter!("billing_webhook_requests_total", "status" => "success").increment(1);
|
|
if let Err(e) = state.billing.apply_event(event) {
|
|
tracing::error!(error = %e, "failed to apply billing event from webhook");
|
|
return (
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(serde_json::json!({ "error": e })),
|
|
)
|
|
.into_response();
|
|
}
|
|
StatusCode::OK.into_response()
|
|
}
|
|
Err(e) => {
|
|
metrics::counter!("billing_webhook_requests_total", "status" => "error").increment(1);
|
|
(
|
|
StatusCode::BAD_REQUEST,
|
|
Json(serde_json::json!({ "error": e.to_string() })),
|
|
)
|
|
.into_response()
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn run_reconciliation_loop(state: AppState) {
|
|
let interval_secs = std::env::var("CONTROL_BILLING_RECONCILE_INTERVAL_SECS")
|
|
.ok()
|
|
.and_then(|s| s.parse().ok())
|
|
.unwrap_or(3600);
|
|
|
|
tracing::info!(interval_secs, "starting billing reconciliation loop");
|
|
|
|
loop {
|
|
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
|
|
|
|
tracing::info!("starting billing reconciliation run");
|
|
reconcile_once(&state).await;
|
|
|
|
// Update tenant status gauges
|
|
// Note: This is an expensive operation if there are many tenants,
|
|
// but for reconciliation it's fine once per hour.
|
|
update_billing_gauges(&state);
|
|
}
|
|
}
|
|
|
|
pub async fn reconcile_once(state: &AppState) {
|
|
let start = std::time::Instant::now();
|
|
|
|
let tenant_ids = state.billing.get_all_tenant_ids();
|
|
let mut success = 0;
|
|
let mut error = 0;
|
|
let mut skipped = 0;
|
|
|
|
for tenant_id in tenant_ids {
|
|
let sub_id = state.billing.get_subscription_id(tenant_id);
|
|
if let Some(subscription_id) = sub_id {
|
|
match state
|
|
.billing_provider
|
|
.fetch_subscription(tenant_id, &subscription_id)
|
|
.await
|
|
{
|
|
Ok(event) => {
|
|
if let Err(e) = state.billing.apply_event(event) {
|
|
tracing::error!(?tenant_id, error = %e, "failed to apply reconciled billing event");
|
|
error += 1;
|
|
} else {
|
|
success += 1;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::error!(?tenant_id, error = %e, "failed to fetch subscription for reconciliation");
|
|
error += 1;
|
|
}
|
|
}
|
|
} else {
|
|
skipped += 1;
|
|
}
|
|
}
|
|
|
|
let elapsed = start.elapsed();
|
|
metrics::counter!("billing_reconciliation_runs_total", "result" => "done").increment(1);
|
|
metrics::histogram!("billing_reconciliation_duration_ms").record(elapsed.as_millis() as f64);
|
|
|
|
tracing::info!(
|
|
success,
|
|
error,
|
|
skipped,
|
|
duration_ms = elapsed.as_millis(),
|
|
"billing reconciliation run complete"
|
|
);
|
|
}
|
|
|
|
fn update_billing_gauges(state: &AppState) {
|
|
let tenant_ids = state.billing.get_all_tenant_ids();
|
|
let mut counts: BTreeMap<(String, String), u64> = BTreeMap::new();
|
|
|
|
for tenant_id in tenant_ids {
|
|
let resp = state.billing.get_for_tenant(tenant_id);
|
|
let plan = match resp.plan {
|
|
Some(Plan::Free) => "free",
|
|
Some(Plan::Pro) => "pro",
|
|
Some(Plan::Enterprise) => "enterprise",
|
|
None => "none",
|
|
}
|
|
.to_string();
|
|
|
|
let status = match resp.status {
|
|
Some(SubscriptionStatus::Active) => "active",
|
|
Some(SubscriptionStatus::Trialing) => "trialing",
|
|
Some(SubscriptionStatus::PastDue) => "past_due",
|
|
Some(SubscriptionStatus::Paused) => "paused",
|
|
Some(SubscriptionStatus::Canceled) => "canceled",
|
|
Some(SubscriptionStatus::Incomplete) => "incomplete",
|
|
None => "none",
|
|
}
|
|
.to_string();
|
|
|
|
*counts.entry((plan, status)).or_insert(0) += 1;
|
|
}
|
|
|
|
for ((plan, status), count) in counts {
|
|
metrics::gauge!("billing_tenant_status_count", "plan" => plan, "status" => status)
|
|
.set(count as f64);
|
|
}
|
|
}
|
|
|
|
/// Errors returned by `BillingProvider` implementations.
#[derive(Debug, Error)]
pub enum BillingError {
    /// The provider call failed or is not implemented; the message is provider-specific.
    #[error("provider error: {0}")]
    Provider(String),
    /// The request cannot be served with the current configuration (e.g. checkout for the
    /// Free plan in `StripeProvider`).
    #[error("invalid configuration: {0}")]
    Config(String),
}
|
|
|
|
/// Abstraction over an external billing provider (Stripe, or a mock in tests/dev).
#[async_trait]
pub trait BillingProvider: Send + Sync {
    /// Creates a checkout session for `tenant_id` on `plan` and returns the URL to redirect
    /// the user to. `return_url` is where the provider should send the user afterwards.
    async fn create_checkout_session(
        &self,
        tenant_id: Uuid,
        plan: Plan,
        return_url: String,
    ) -> Result<String, BillingError>;

    /// Creates a customer-portal session for `tenant_id` and returns its URL.
    async fn create_portal_session(
        &self,
        tenant_id: Uuid,
        return_url: String,
    ) -> Result<String, BillingError>;

    /// Authenticates a raw webhook request and converts it into a `BillingEvent`.
    ///
    /// The `webhook` handler performs no other authentication, so real implementations are
    /// expected to verify the provider signature here. (`MockProvider` only parses the JSON.)
    async fn verify_webhook(
        &self,
        payload: &[u8],
        headers: &HeaderMap,
    ) -> Result<BillingEvent, BillingError>;

    /// Fetches the provider's current view of a subscription as a `BillingEvent`, used by
    /// reconciliation to correct drift in stored state.
    async fn fetch_subscription(
        &self,
        tenant_id: Uuid,
        subscription_id: &str,
    ) -> Result<BillingEvent, BillingError>;
}
|
|
|
|
/// Stripe-backed `BillingProvider`.
///
/// NOTE(review): the visible implementation is a stub. It never calls the Stripe API, and
/// `secret_key` is not read anywhere in this module.
pub struct StripeProvider {
    /// Stripe API secret key.
    pub secret_key: String,
    /// Stripe price id for the Pro plan.
    pub price_pro: String,
    /// Stripe price id for the Enterprise plan.
    pub price_enterprise: String,
}
|
|
|
|
#[async_trait]
|
|
impl BillingProvider for StripeProvider {
|
|
async fn create_checkout_session(
|
|
&self,
|
|
tenant_id: Uuid,
|
|
plan: Plan,
|
|
_return_url: String,
|
|
) -> Result<String, BillingError> {
|
|
let _price = match plan {
|
|
Plan::Pro => &self.price_pro,
|
|
Plan::Enterprise => &self.price_enterprise,
|
|
Plan::Free => {
|
|
return Err(BillingError::Config(
|
|
"Free plan has no checkout".to_string(),
|
|
));
|
|
}
|
|
};
|
|
|
|
// TODO: Actually call Stripe API
|
|
// For now, returning a simulated Stripe checkout URL
|
|
Ok(format!(
|
|
"https://checkout.stripe.com/pay/cs_test_{}?tenant_id={}",
|
|
Uuid::new_v4(),
|
|
tenant_id
|
|
))
|
|
}
|
|
|
|
async fn create_portal_session(
|
|
&self,
|
|
tenant_id: Uuid,
|
|
_return_url: String,
|
|
) -> Result<String, BillingError> {
|
|
// TODO: Actually call Stripe API
|
|
Ok(format!(
|
|
"https://billing.stripe.com/p/session/ps_test_{}?tenant_id={}",
|
|
Uuid::new_v4(),
|
|
tenant_id
|
|
))
|
|
}
|
|
|
|
async fn verify_webhook(
|
|
&self,
|
|
_payload: &[u8],
|
|
_headers: &HeaderMap,
|
|
) -> Result<BillingEvent, BillingError> {
|
|
// TODO: Implement real Stripe signature verification
|
|
Err(BillingError::Provider("Not implemented".to_string()))
|
|
}
|
|
|
|
async fn fetch_subscription(
|
|
&self,
|
|
_tenant_id: Uuid,
|
|
_subscription_id: &str,
|
|
) -> Result<BillingEvent, BillingError> {
|
|
// TODO: Actually call Stripe API with timeout
|
|
// let client = reqwest::Client::builder().timeout(Duration::from_secs(10)).build()...
|
|
Err(BillingError::Provider("Not implemented".to_string()))
|
|
}
|
|
}
|
|
|
|
/// In-process `BillingProvider` for tests and development; makes no network calls and
/// performs no webhook authentication.
pub struct MockProvider;
|
|
|
|
#[async_trait]
|
|
impl BillingProvider for MockProvider {
|
|
async fn create_checkout_session(
|
|
&self,
|
|
tenant_id: Uuid,
|
|
_plan: Plan,
|
|
_return_url: String,
|
|
) -> Result<String, BillingError> {
|
|
Ok(format!("https://mock.stripe.com/checkout/{}", tenant_id))
|
|
}
|
|
|
|
async fn create_portal_session(
|
|
&self,
|
|
tenant_id: Uuid,
|
|
_return_url: String,
|
|
) -> Result<String, BillingError> {
|
|
Ok(format!("https://mock.stripe.com/portal/{}", tenant_id))
|
|
}
|
|
|
|
async fn verify_webhook(
|
|
&self,
|
|
payload: &[u8],
|
|
_headers: &HeaderMap,
|
|
) -> Result<BillingEvent, BillingError> {
|
|
// Mock implementation: just parse the payload as a BillingEvent
|
|
serde_json::from_slice(payload).map_err(|e| BillingError::Provider(e.to_string()))
|
|
}
|
|
|
|
async fn fetch_subscription(
|
|
&self,
|
|
tenant_id: Uuid,
|
|
_subscription_id: &str,
|
|
) -> Result<BillingEvent, BillingError> {
|
|
// Mock implementation: return a SubscriptionUpdated event with current state
|
|
// In a real mock we might want to store expectations, but for now we just return something plausible.
|
|
Ok(BillingEvent::SubscriptionUpdated {
|
|
tenant_id,
|
|
event_id: format!("reconcile-{}", Uuid::new_v4()),
|
|
status: SubscriptionStatus::Active,
|
|
plan: Plan::Pro,
|
|
current_period_end: "2099-12-31T23:59:59Z".to_string(),
|
|
cancel_at_period_end: false,
|
|
ts_ms: SystemTime::now()
|
|
.duration_since(SystemTime::UNIX_EPOCH)
|
|
.unwrap()
|
|
.as_millis() as u64,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl MockProvider {
|
|
pub fn get_checkout_url(tenant: Uuid) -> String {
|
|
format!("https://mock.stripe.com/checkout/{}", tenant)
|
|
}
|
|
}
|
|
|
|
// Unit tests for entitlement derivation, file-backed persistence and reconciliation.
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;

    // Lapsed subscriptions fall back to the restricted tier; active/trialing ones get plan limits.
    #[test]
    fn test_entitlement_derivation() {
        let e = Entitlements::derive(Some(&Plan::Free), Some(&SubscriptionStatus::PastDue));
        assert_eq!(e.max_deployments, 1);

        let e = Entitlements::derive(Some(&Plan::Pro), Some(&SubscriptionStatus::Active));
        assert_eq!(e.max_deployments, 10);
        assert!(e.s3_docs_enabled);

        let e = Entitlements::derive(Some(&Plan::Enterprise), Some(&SubscriptionStatus::Trialing));
        assert_eq!(e.max_deployments, 1000);
    }

    // An unknown tenant is unconfigured; after saving a record it is read back from disk.
    #[test]
    fn test_billing_state_roundtrip() {
        // Unique file name so parallel test runs do not share state.
        let mut path = temp_dir();
        path.push(format!("billing-{}.json", Uuid::new_v4()));

        let store = BillingStore::new(path.clone());
        let tenant_id = Uuid::new_v4();

        let resp = store.get_for_tenant(tenant_id);
        assert!(!resp.configured);
        assert_eq!(resp.entitlements.max_deployments, 1);

        let state = TenantBillingState {
            provider: "mock".to_string(),
            provider_customer_id: None,
            provider_subscription_id: None,
            provider_checkout_session_id: None,
            status: Some(SubscriptionStatus::Active),
            plan: Some(Plan::Pro),
            current_period_end: None,
            cancel_at_period_end: Some(false),
            processed_webhook_event_ids: vec![],
            updated_at: 0,
        };

        store.update_tenant_state(tenant_id, state).unwrap();

        let resp2 = store.get_for_tenant(tenant_id);
        assert!(resp2.configured);
        assert_eq!(resp2.provider.as_deref(), Some("mock"));
        assert_eq!(resp2.plan, Some(Plan::Pro));
        assert_eq!(resp2.entitlements.max_deployments, 10);

        let _ = fs::remove_file(path);
    }

    // Reconciliation replaces a stale PastDue status with the provider's (mock) Active status.
    #[tokio::test]
    async fn test_reconciliation_corrects_state() {
        let mut path = temp_dir();
        path.push(format!("billing-reconcile-{}.json", Uuid::new_v4()));
        let store = BillingStore::new(path.clone());
        let tenant_id = Uuid::new_v4();

        // 1. Initial state: PastDue. A subscription id is required or reconciliation skips the tenant.
        store
            .update_tenant_state(
                tenant_id,
                TenantBillingState {
                    provider: "mock".to_string(),
                    provider_customer_id: Some("cus_1".to_string()),
                    provider_subscription_id: Some("sub_1".to_string()),
                    provider_checkout_session_id: None,
                    status: Some(SubscriptionStatus::PastDue),
                    plan: Some(Plan::Pro),
                    current_period_end: None,
                    cancel_at_period_end: Some(false),
                    processed_webhook_event_ids: vec![],
                    updated_at: 100,
                },
            )
            .unwrap();

        // Minimal AppState; only `billing` and `billing_provider` matter for this test.
        let state = AppState {
            prometheus: crate::get_test_prometheus_handle(),
            auth: crate::AuthConfig { hs256_secret: None },
            jobs: crate::jobs::JobStore::default(),
            audit: crate::AuditStore::default(),
            tenant_locks: crate::job_engine::TenantLocks::default(),
            config_locks: crate::job_engine::ConfigLocks::default(),
            http: reqwest::Client::new(),
            placement: crate::placement::PlacementStore::new(temp_dir().join("placement.json")),
            billing: store.clone(),
            billing_provider: Arc::new(MockProvider),
            billing_enforcement_enabled: true,
            config: crate::config_registry::ConfigRegistry::new(None, None),
            fleet_services: vec![],
            swarm: crate::swarm::SwarmStore::new(temp_dir().join("swarm.json")),
            docs: None,
        };

        // 2. Run reconciliation. MockProvider returns Active status.
        reconcile_once(&state).await;

        // 3. Verify state is now Active
        let resp = store.get_for_tenant(tenant_id);
        assert_eq!(resp.status, Some(SubscriptionStatus::Active));

        let _ = fs::remove_file(path);
    }
}
|