Monorepo consolidation: workspace, shared types, transport plans, docker/swarm assets
Some checks failed
ci / rust (push) Failing after 2m34s
ci / ui (push) Failing after 30s

This commit is contained in:
2026-03-30 11:40:42 +03:00
parent 7e7041cf8b
commit 1298d9a3df
246 changed files with 55434 additions and 0 deletions

1562
gateway/src/admin_iam.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,790 @@
use axum::extract::Query;
use axum::extract::State;
use axum::http::StatusCode;
use axum::Json;
use serde::Deserialize;
use serde::Serialize;
use std::time::Duration;
use crate::authz;
use crate::authz::AuthzRejection;
use crate::authz::Principal;
use crate::routing::ServiceKind;
use crate::storage::StorageError;
use crate::AppState;
/// Admin routes for rebalance status, readiness gates and plan lifecycle.
/// All handlers require the `iam.platform_admin` right.
pub fn router() -> axum::Router<AppState> {
    use axum::routing::{get, post};
    axum::Router::new()
        .route("/status", get(status))
        .route("/gates", get(gates))
        .route("/plans", get(list_plans))
        .route("/plan", post(create_plan))
        .route("/apply", post(apply_plan))
        .route("/rollback", post(rollback_plan))
}
/// Query parameters for the routing resolve endpoint.
#[derive(Debug, Deserialize)]
pub struct ResolveQuery {
    pub tenant_id: String,
    /// Service kind as a string; parsed by `parse_kind`.
    pub kind: String,
}
/// Response body for a successful resolve call.
#[derive(Debug, Serialize)]
pub struct ResolveResponse {
    pub tenant_id: String,
    pub kind: ServiceKind,
    pub endpoint: String,
    /// Routing-table revision the endpoint was resolved against.
    pub revision: u64,
}
/// Common query shape for endpoints keyed only by tenant.
#[derive(Debug, Deserialize)]
struct TenantQuery {
    tenant_id: String,
}
/// Per-tenant placement snapshot returned by `status`.
#[derive(Debug, Serialize)]
struct StatusResponse {
    tenant_id: String,
    revision: u64,
    // Resolved endpoints; None when the tenant has no placement for that kind.
    aggregate: Option<String>,
    projection: Option<String>,
    runner: Option<String>,
}
/// Readiness flags returned by `gates`.
#[derive(Debug, Serialize)]
struct GatesResponse {
    tenant_id: String,
    aggregate_ready: bool,
    projection_ready: bool,
    runner_ready: bool,
}
/// Versioned persistence envelope wrapped around every stored record.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct Stored<T> {
    /// Schema version (see `crate::storage::SCHEMA_VERSION`).
    v: u32,
    data: T,
}
/// A persisted tenant-rebalance plan and its lifecycle state.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct RebalancePlan {
    plan_id: String,
    tenant_id: String,
    kind: ServiceKind,
    /// Endpoint the tenant was routed to when the plan was created, if any.
    from_endpoint: Option<String>,
    /// Target endpoint of the cutover.
    to_endpoint: Option<String>,
    // Lifecycle marker: "planned", "apply_requested" or "rollback_requested"
    // (set by create_plan / transition_plan_status).
    status: String,
    /// User id of the admin who created the plan.
    actor_id: String,
    created_at_ms: i64,
    updated_at_ms: i64,
}
/// Request body for plan creation.
#[derive(Debug, Deserialize)]
struct CreatePlanBody {
    tenant_id: String,
    kind: String,
    to_endpoint: Option<String>,
}
/// Request body for the apply/rollback plan actions.
#[derive(Debug, Deserialize)]
struct PlanActionBody {
    plan_id: String,
    tenant_id: String,
}
/// Query parameters for listing plans.
#[derive(Debug, Deserialize)]
struct ListPlansQuery {
    tenant_id: Option<String>,
    /// Max plans to return; defaults to 50, capped at 200.
    limit: Option<usize>,
}
/// Resolves the routed endpoint for a tenant/service-kind pair.
///
/// Requires the `iam.platform_admin` right. Returns 400 for an unknown
/// `kind`, 404 for an unknown tenant, and 500 for routing-table
/// inconsistencies (missing or empty shard).
pub async fn resolve(
    State(state): State<AppState>,
    principal: Principal,
    Query(q): Query<ResolveQuery>,
) -> Result<Json<ResolveResponse>, AuthzRejection> {
    require_platform_admin(&state.storage, &principal.user_id).await?;
    // An unparseable kind is a client error, not a server fault; this
    // matches how create_plan treats the same field.
    let kind = parse_kind(&q.kind).ok_or(AuthzRejection::BadRequest)?;
    let table = state.routing.snapshot().await;
    let endpoint = table.resolve(&q.tenant_id, kind).map_err(|e| match e {
        crate::routing::RoutingError::UnknownTenant => AuthzRejection::NotFound,
        crate::routing::RoutingError::MissingShard | crate::routing::RoutingError::EmptyShard => {
            AuthzRejection::Internal
        }
    })?;
    Ok(Json(ResolveResponse {
        tenant_id: q.tenant_id,
        kind,
        endpoint,
        revision: table.revision,
    }))
}
/// Reports the routing revision and each service's resolved endpoint for a
/// tenant. Unresolvable kinds show up as `null` rather than failing.
async fn status(
    State(state): State<AppState>,
    principal: Principal,
    Query(q): Query<TenantQuery>,
) -> Result<Json<StatusResponse>, AuthzRejection> {
    require_platform_admin(&state.storage, &principal.user_id).await?;
    let table = state.routing.snapshot().await;
    let response = StatusResponse {
        aggregate: table.resolve(&q.tenant_id, ServiceKind::Aggregate).ok(),
        projection: table.resolve(&q.tenant_id, ServiceKind::Projection).ok(),
        runner: table.resolve(&q.tenant_id, ServiceKind::Runner).ok(),
        revision: table.revision,
        tenant_id: q.tenant_id,
    };
    Ok(Json(response))
}
/// Readiness gates consulted before cutting a tenant over to a new
/// placement. A service with no routed endpoint is reported as not ready
/// instead of failing the request.
async fn gates(
    State(state): State<AppState>,
    principal: Principal,
    Query(q): Query<TenantQuery>,
) -> Result<Json<GatesResponse>, AuthzRejection> {
    require_platform_admin(&state.storage, &principal.user_id).await?;
    let projection_endpoint = state
        .routing
        .resolve(&q.tenant_id, ServiceKind::Projection)
        .await
        .ok();
    let runner_endpoint = state
        .routing
        .resolve(&q.tenant_id, ServiceKind::Runner)
        .await
        .ok();
    let aggregate_endpoint = state
        .routing
        .resolve(&q.tenant_id, ServiceKind::Aggregate)
        .await
        .ok();
    // Probe failures are treated as "not ready", never as request errors.
    let projection_ready = match projection_endpoint {
        Some(ep) => projection_gate_ready(&ep, &q.tenant_id)
            .await
            .unwrap_or(false),
        None => false,
    };
    let runner_ready = match runner_endpoint {
        Some(ep) => http_ready(&ep).await.unwrap_or(false),
        None => false,
    };
    let aggregate_ready = match aggregate_endpoint {
        Some(ep) => aggregate_ready(&ep).await.unwrap_or(false),
        None => false,
    };
    Ok(Json(GatesResponse {
        tenant_id: q.tenant_id,
        aggregate_ready,
        projection_ready,
        runner_ready,
    }))
}
/// Probes `<endpoint>/ready` with a 2s timeout; true when the probe returns
/// a 2xx status. Timeouts and transport errors surface as `Internal`.
async fn http_ready(endpoint: &str) -> Result<bool, AuthzRejection> {
    let base = endpoint.trim_end_matches('/');
    let request = crate::upstream::http_client().get(format!("{base}/ready")).send();
    let resp = tokio::time::timeout(Duration::from_secs(2), request)
        .await
        .map_err(|_| AuthzRejection::Internal)?
        .map_err(|_| AuthzRejection::Internal)?;
    Ok(resp.status().is_success())
}
/// Readiness probe for the aggregate service. Endpoints on port 50051
/// (presumably the gRPC port — confirm against deployment config) are
/// probed on port 8080 instead; everything else is probed as-is.
async fn aggregate_ready(endpoint: &str) -> Result<bool, AuthzRejection> {
    let target = if endpoint.contains(":50051") {
        endpoint.replace(":50051", ":8080")
    } else {
        endpoint.to_string()
    };
    http_ready(&target).await
}
/// Scrapes `<endpoint>/metrics` and gates on two conditions: the
/// `projection_ready` gauge is >= 1 and the tenant's max `projection_lag`
/// is within `GATEWAY_REBALANCE_PROJECTION_MAX_LAG` (default 0). A lag
/// metric that cannot be parsed counts as infinite lag.
async fn projection_gate_ready(endpoint: &str, tenant_id: &str) -> Result<bool, AuthzRejection> {
    let url = format!("{}/metrics", endpoint.trim_end_matches('/'));
    let client = crate::upstream::http_client();
    let resp = tokio::time::timeout(Duration::from_secs(2), client.get(url).send())
        .await
        .map_err(|_| AuthzRejection::Internal)?
        .map_err(|_| AuthzRejection::Internal)?;
    if !resp.status().is_success() {
        return Ok(false);
    }
    let body = resp.text().await.map_err(|_| AuthzRejection::Internal)?;
    let is_ready = parse_prom_gauge(&body, "projection_ready")
        .map_or(false, |v| v >= 1.0);
    if !is_ready {
        return Ok(false);
    }
    let threshold: u64 = std::env::var("GATEWAY_REBALANCE_PROJECTION_MAX_LAG")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(0);
    // Missing/unparseable lag is treated conservatively as u64::MAX.
    let lag = parse_projection_max_lag(&body, tenant_id).unwrap_or(u64::MAX);
    Ok(lag <= threshold)
}
/// Extracts the value of an unlabeled Prometheus gauge named exactly `name`
/// from text-format metrics output.
///
/// Fix: the original matched by `starts_with(name)`, so looking up
/// `projection_ready` could pick up `projection_ready_total` (or any other
/// metric sharing the prefix). Compare the first whitespace-delimited token
/// for equality instead; labeled series (`name{...}`) are excluded
/// automatically because their token differs from `name`.
fn parse_prom_gauge(metrics: &str, name: &str) -> Option<f64> {
    for line in metrics.lines() {
        let line = line.trim();
        // Skip comments (# HELP / # TYPE) and blank lines.
        if line.starts_with('#') || line.is_empty() {
            continue;
        }
        let mut it = line.split_whitespace();
        let Some(metric) = it.next() else { continue };
        if metric != name {
            continue;
        }
        // First matching line wins; a missing or malformed value yields None.
        return it.next()?.parse::<f64>().ok();
    }
    None
}
/// Returns the largest `projection_lag{...}` sample whose labels contain
/// `tenant_id="<tenant_id>"`, or None when no sample matches. A matching
/// line with an unparseable value aborts with None (caller treats that as
/// infinite lag).
fn parse_projection_max_lag(metrics: &str, tenant_id: &str) -> Option<u64> {
    let needle = format!("tenant_id=\"{}\"", tenant_id);
    let mut best: Option<u64> = None;
    for raw in metrics.lines() {
        let line = raw.trim();
        if !line.starts_with("projection_lag{") || !line.contains(&needle) {
            continue;
        }
        let value = line
            .split_whitespace()
            .nth(1)
            .and_then(|v| v.parse::<u64>().ok())?;
        best = Some(match best {
            Some(current) if current >= value => current,
            _ => value,
        });
    }
    best
}
/// Maps a user-supplied kind string ("aggregate" | "projection" | "runner",
/// case-insensitive, surrounding whitespace ignored) to a `ServiceKind`.
fn parse_kind(kind: &str) -> Option<ServiceKind> {
    let normalized = kind.trim().to_ascii_lowercase();
    match normalized.as_str() {
        "aggregate" => Some(ServiceKind::Aggregate),
        "projection" => Some(ServiceKind::Projection),
        "runner" => Some(ServiceKind::Runner),
        _ => None,
    }
}
/// Checks that `principal_id` holds the `iam.platform_admin` right under
/// the wildcard "*" tenant scope; errors propagate from `ensure_allowed`.
async fn require_platform_admin(
    storage: &crate::storage::GatewayStorage,
    principal_id: &str,
) -> Result<(), AuthzRejection> {
    authz::ensure_allowed(storage, principal_id, "*", "iam.platform_admin").await
}
/// Creates and persists a new rebalance plan in the "planned" state.
///
/// Rejects blank tenant ids, unknown kinds and missing/blank target
/// endpoints with 400; a key collision maps to 409. Returns the stored
/// plan on success.
async fn create_plan(
    State(state): State<AppState>,
    principal: Principal,
    Json(body): Json<CreatePlanBody>,
) -> Result<Json<RebalancePlan>, AuthzRejection> {
    require_platform_admin(&state.storage, &principal.user_id).await?;
    // Validate inputs before touching routing or storage.
    if body.tenant_id.trim().is_empty() {
        return Err(AuthzRejection::BadRequest);
    }
    let kind = parse_kind(&body.kind).ok_or(AuthzRejection::BadRequest)?;
    let to_endpoint = match body.to_endpoint {
        Some(ep) if !ep.trim().is_empty() => Some(ep),
        _ => return Err(AuthzRejection::BadRequest),
    };
    // Best-effort snapshot of the current placement for later rollback.
    let from_endpoint = state.routing.resolve(&body.tenant_id, kind).await.ok();
    let now = unix_ms();
    let plan = RebalancePlan {
        plan_id: uuid::Uuid::new_v4().to_string(),
        tenant_id: body.tenant_id.clone(),
        kind,
        from_endpoint,
        to_endpoint,
        status: "planned".to_string(),
        actor_id: principal.user_id,
        created_at_ms: now,
        updated_at_ms: now,
    };
    let payload = encode_stored(&plan).map_err(|_| AuthzRejection::Internal)?;
    let key = plan_key(&plan.tenant_id, &plan.plan_id);
    state
        .storage
        .audit_index
        .create(&key, payload)
        .await
        .map_err(|err| match err {
            StorageError::AlreadyExists => AuthzRejection::Conflict,
            _ => AuthzRejection::Internal,
        })?;
    Ok(Json(plan))
}
/// Marks a plan as "apply_requested"; responds 204 on success.
async fn apply_plan(
    State(state): State<AppState>,
    principal: Principal,
    Json(body): Json<PlanActionBody>,
) -> Result<StatusCode, AuthzRejection> {
    require_platform_admin(&state.storage, &principal.user_id).await?;
    transition_plan_status(&state, &body.tenant_id, &body.plan_id, "apply_requested")
        .await
        .map(|()| StatusCode::NO_CONTENT)
}
/// Marks a plan as "rollback_requested"; responds 204 on success.
async fn rollback_plan(
    State(state): State<AppState>,
    principal: Principal,
    Json(body): Json<PlanActionBody>,
) -> Result<StatusCode, AuthzRejection> {
    require_platform_admin(&state.storage, &principal.user_id).await?;
    transition_plan_status(&state, &body.tenant_id, &body.plan_id, "rollback_requested")
        .await
        .map(|()| StatusCode::NO_CONTENT)
}
/// Lists stored rebalance plans, optionally scoped to one tenant, in
/// descending key order, capped at `limit` (default 50, max 200). Keys
/// whose record vanished between listing and fetching are skipped.
async fn list_plans(
    State(state): State<AppState>,
    principal: Principal,
    Query(q): Query<ListPlansQuery>,
) -> Result<Json<Vec<RebalancePlan>>, AuthzRejection> {
    require_platform_admin(&state.storage, &principal.user_id).await?;
    let prefix = q
        .tenant_id
        .as_deref()
        .map(|t| format!("v1/rebalance/plans/{}/", t.trim()))
        .unwrap_or_else(|| "v1/rebalance/plans/".to_string());
    let mut keys = state
        .storage
        .audit_index
        .list_keys(&prefix)
        .await
        .map_err(|_| AuthzRejection::Internal)?;
    // Descending lexicographic order (keys are unique, stability irrelevant).
    keys.sort_unstable_by(|a, b| b.cmp(a));
    let limit = q.limit.unwrap_or(50).min(200);
    let mut plans = Vec::new();
    for key in keys.iter().take(limit) {
        let maybe_entry = state
            .storage
            .audit_index
            .get(key)
            .await
            .map_err(|_| AuthzRejection::Internal)?;
        if let Some(entry) = maybe_entry {
            let plan: RebalancePlan =
                decode_stored(&entry.value).map_err(|_| AuthzRejection::Internal)?;
            plans.push(plan);
        }
    }
    Ok(Json(plans))
}
/// CAS-updates the stored plan's `status` field to `next_status`.
///
/// Re-reads the record and retries up to 10 times on compare-and-swap
/// conflicts, then gives up with `Internal`. Returns `NotFound` when the
/// plan does not exist.
async fn transition_plan_status(
    state: &AppState,
    tenant_id: &str,
    plan_id: &str,
    next_status: &str,
) -> Result<(), AuthzRejection> {
    let key = plan_key(tenant_id, plan_id);
    let mut attempts_left = 10;
    while attempts_left > 0 {
        attempts_left -= 1;
        let entry = state
            .storage
            .audit_index
            .get(&key)
            .await
            .map_err(|_| AuthzRejection::Internal)?
            .ok_or(AuthzRejection::NotFound)?;
        let mut stored: Stored<RebalancePlan> =
            serde_json::from_slice(&entry.value).map_err(|_| AuthzRejection::Internal)?;
        stored.data.status = next_status.to_string();
        stored.data.updated_at_ms = unix_ms();
        let payload = serde_json::to_vec(&stored).map_err(|_| AuthzRejection::Internal)?;
        let outcome = state
            .storage
            .audit_index
            .update(&key, entry.revision, payload)
            .await;
        match outcome {
            Ok(_) => return Ok(()),
            // Concurrent writer won the race: re-read and retry.
            Err(StorageError::CasMismatch) => continue,
            Err(_) => return Err(AuthzRejection::Internal),
        }
    }
    // Persistent contention after all retries.
    Err(AuthzRejection::Internal)
}
/// Storage key for a rebalance plan, namespaced by tenant.
fn plan_key(tenant_id: &str, plan_id: &str) -> String {
    format!("v1/rebalance/plans/{}/{}", tenant_id, plan_id)
}
/// Serializes `data` inside the versioned `Stored` envelope used for all
/// persisted records in this module.
fn encode_stored<T: Serialize>(data: &T) -> Result<Vec<u8>, StorageError> {
    let envelope = Stored {
        v: crate::storage::SCHEMA_VERSION,
        data,
    };
    serde_json::to_vec(&envelope).map_err(|e| StorageError::Serde(e.to_string()))
}
/// Deserializes a record previously written via `encode_stored`, discarding
/// the version envelope.
fn decode_stored<T: for<'de> Deserialize<'de>>(bytes: &[u8]) -> Result<T, StorageError> {
    serde_json::from_slice::<Stored<T>>(bytes)
        .map(|stored| stored.data)
        .map_err(|e| StorageError::Serde(e.to_string()))
}
/// Milliseconds since the Unix epoch; yields 0 if the system clock reads
/// before the epoch (matching the original `unwrap_or_default`).
fn unix_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as i64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::authn;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tower::util::ServiceExt;

    /// Builds the full gateway app with in-memory storage, test authn keys
    /// and a fixed routing table derived from `cfg`.
    async fn test_app_with_routing(cfg: crate::routing::RoutingConfig) -> (axum::Router, AppState) {
        let metrics = crate::observability::init_metrics_for_tests();
        let source: Arc<dyn crate::routing::RoutingSource> =
            Arc::new(crate::routing::FixedSource::new(cfg));
        let routing = crate::routing::RouterState::new(source).await.unwrap();
        let storage = crate::storage::GatewayStorage::new_in_memory();
        let authn_cfg = crate::authn::AuthnConfig::for_tests();
        let state = crate::AppState {
            metrics,
            routing,
            storage,
            authn: authn_cfg,
        };
        let app = crate::app(state.clone());
        (app, state)
    }

    /// Signs up a fresh user over HTTP and returns (access_token, user_id);
    /// the user id comes from the verified token's `sub` claim.
    async fn signup_and_token(app: &axum::Router, cfg: &authn::AuthnConfig) -> (String, String) {
        let response = app
            .clone()
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/v1/auth/signup")
                    .header("content-type", "application/json")
                    .body(axum::body::Body::from(
                        r#"{"email":"a@b.com","password":"password123"}"#,
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let created: crate::authn::AuthResponse = serde_json::from_slice(&body).unwrap();
        let claims = cfg.verify_access_token(&created.access_token).unwrap();
        (created.access_token, claims.sub)
    }

    #[tokio::test]
    async fn resolve_requires_platform_admin() {
        let cfg = crate::routing::RoutingConfig::empty();
        let (app, state) = test_app_with_routing(cfg).await;
        let (token, user_id) = signup_and_token(&app, &state.authn).await;
        // Without the platform-admin role the endpoint must refuse access.
        let resp = app
            .clone()
            .oneshot(
                axum::http::Request::builder()
                    .method("GET")
                    .uri("/admin/routing/resolve?tenant_id=t1&kind=aggregate")
                    .header("authorization", format!("Bearer {token}"))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), axum::http::StatusCode::FORBIDDEN);
        crate::authz::put_role(
            &state.storage,
            "role-platform-admin",
            vec!["iam.platform_admin".to_string()],
        )
        .await
        .unwrap();
        crate::authz::assign_role(&state.storage, "*", &user_id, "role-platform-admin")
            .await
            .unwrap();
        // With the role granted the request is authorized; the empty routing
        // table then yields 404 for the unknown tenant.
        let resp = app
            .oneshot(
                axum::http::Request::builder()
                    .method("GET")
                    .uri("/admin/routing/resolve?tenant_id=t1&kind=aggregate")
                    .header("authorization", format!("Bearer {token}"))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), axum::http::StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn status_includes_revision() {
        // Empty placements but a fixed revision: status should echo it back.
        let cfg = crate::routing::RoutingConfig {
            revision: 42,
            aggregate_placement: HashMap::new(),
            projection_placement: HashMap::new(),
            runner_placement: HashMap::new(),
            aggregate_shards: HashMap::new(),
            projection_shards: HashMap::new(),
            runner_shards: HashMap::new(),
        };
        let (app, state) = test_app_with_routing(cfg).await;
        let (token, user_id) = signup_and_token(&app, &state.authn).await;
        crate::authz::put_role(
            &state.storage,
            "role-platform-admin",
            vec!["iam.platform_admin".to_string()],
        )
        .await
        .unwrap();
        crate::authz::assign_role(&state.storage, "*", &user_id, "role-platform-admin")
            .await
            .unwrap();
        let resp = app
            .oneshot(
                axum::http::Request::builder()
                    .method("GET")
                    .uri("/admin/rebalance/status?tenant_id=t1")
                    .header("authorization", format!("Bearer {token}"))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), axum::http::StatusCode::OK);
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        let value: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(value.get("revision").and_then(|v| v.as_u64()).unwrap(), 42);
    }

    #[tokio::test]
    async fn gates_prevent_cutover_when_projection_not_ready_or_lagging() {
        // Case 1: a metrics endpoint reporting projection_ready 0.
        let metrics_not_ready = axum::Router::new().route(
            "/metrics",
            axum::routing::get(|| async { (axum::http::StatusCode::OK, "projection_ready 0\n") }),
        );
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, metrics_not_ready).await.unwrap();
        });
        // Give the spawned server a moment to start accepting.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let endpoint = format!("http://{}", addr);
        let cfg = crate::routing::RoutingConfig {
            revision: 1,
            aggregate_placement: HashMap::new(),
            projection_placement: HashMap::from([("tenant-a".to_string(), "p".to_string())]),
            runner_placement: HashMap::new(),
            aggregate_shards: HashMap::new(),
            projection_shards: HashMap::from([("p".to_string(), vec![endpoint])]),
            runner_shards: HashMap::new(),
        };
        let (app, state) = test_app_with_routing(cfg).await;
        let (token, user_id) = signup_and_token(&app, &state.authn).await;
        crate::authz::put_role(
            &state.storage,
            "role-platform-admin",
            vec!["iam.platform_admin".to_string()],
        )
        .await
        .unwrap();
        crate::authz::assign_role(&state.storage, "*", &user_id, "role-platform-admin")
            .await
            .unwrap();
        let resp = app
            .clone()
            .oneshot(
                axum::http::Request::builder()
                    .method("GET")
                    .uri("/admin/rebalance/gates?tenant_id=tenant-a")
                    .header("authorization", format!("Bearer {token}"))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), axum::http::StatusCode::OK);
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        let value: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert!(!value
            .get("projection_ready")
            .and_then(|v| v.as_bool())
            .unwrap());
        // Case 2: ready but lagging beyond the configured threshold.
        let metrics_lagging = axum::Router::new().route(
            "/metrics",
            axum::routing::get(|| async {
                (
                    axum::http::StatusCode::OK,
                    "projection_ready 1\nprojection_lag{tenant_id=\"tenant-a\"} 5\n",
                )
            }),
        );
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, metrics_lagging).await.unwrap();
        });
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let endpoint = format!("http://{}", addr);
        // NOTE(review): process-global env var; may race with parallel tests.
        std::env::set_var("GATEWAY_REBALANCE_PROJECTION_MAX_LAG", "0");
        let cfg = crate::routing::RoutingConfig {
            revision: 2,
            aggregate_placement: HashMap::new(),
            projection_placement: HashMap::from([("tenant-a".to_string(), "p".to_string())]),
            runner_placement: HashMap::new(),
            aggregate_shards: HashMap::new(),
            projection_shards: HashMap::from([("p".to_string(), vec![endpoint])]),
            runner_shards: HashMap::new(),
        };
        // Fresh app/state; the earlier token is reused, which presumably
        // works because AuthnConfig::for_tests() is deterministic — confirm.
        let (app, state) = test_app_with_routing(cfg).await;
        crate::authz::put_role(
            &state.storage,
            "role-platform-admin",
            vec!["iam.platform_admin".to_string()],
        )
        .await
        .unwrap();
        crate::authz::assign_role(&state.storage, "*", &user_id, "role-platform-admin")
            .await
            .unwrap();
        let resp = app
            .oneshot(
                axum::http::Request::builder()
                    .method("GET")
                    .uri("/admin/rebalance/gates?tenant_id=tenant-a")
                    .header("authorization", format!("Bearer {token}"))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(resp.status(), axum::http::StatusCode::OK);
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        let value: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert!(!value
            .get("projection_ready")
            .and_then(|v| v.as_bool())
            .unwrap());
    }

    #[tokio::test]
    async fn plan_endpoints_require_platform_admin_and_persist_plans() {
        let cfg = crate::routing::RoutingConfig::empty();
        let (app, state) = test_app_with_routing(cfg).await;
        let (token, user_id) = signup_and_token(&app, &state.authn).await;
        // Plan creation is forbidden before the role grant.
        let forbidden = app
            .clone()
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/admin/rebalance/plan")
                    .header("authorization", format!("Bearer {token}"))
                    .header("content-type", "application/json")
                    .body(axum::body::Body::from(
                        r#"{"tenant_id":"tenant-a","kind":"projection","to_endpoint":"http://p"}"#,
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(forbidden.status(), axum::http::StatusCode::FORBIDDEN);
        crate::authz::put_role(
            &state.storage,
            "role-platform-admin",
            vec!["iam.platform_admin".to_string()],
        )
        .await
        .unwrap();
        crate::authz::assign_role(&state.storage, "*", &user_id, "role-platform-admin")
            .await
            .unwrap();
        let created = app
            .clone()
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/admin/rebalance/plan")
                    .header("authorization", format!("Bearer {token}"))
                    .header("content-type", "application/json")
                    .body(axum::body::Body::from(
                        r#"{"tenant_id":"tenant-a","kind":"projection","to_endpoint":"http://p"}"#,
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(created.status(), axum::http::StatusCode::OK);
        let body = axum::body::to_bytes(created.into_body(), usize::MAX)
            .await
            .unwrap();
        let plan: serde_json::Value = serde_json::from_slice(&body).unwrap();
        let plan_id = plan.get("plan_id").and_then(|v| v.as_str()).unwrap();
        // The created plan must be visible through the list endpoint.
        let listed = app
            .oneshot(
                axum::http::Request::builder()
                    .method("GET")
                    .uri("/admin/rebalance/plans?tenant_id=tenant-a&limit=10")
                    .header("authorization", format!("Bearer {token}"))
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(listed.status(), axum::http::StatusCode::OK);
        let body = axum::body::to_bytes(listed.into_body(), usize::MAX)
            .await
            .unwrap();
        let plans: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert!(plans
            .as_array()
            .unwrap()
            .iter()
            .any(|p| p.get("plan_id").and_then(|v| v.as_str()) == Some(plan_id)));
    }
}

1707
gateway/src/authn.rs Normal file

File diff suppressed because it is too large Load Diff

839
gateway/src/authz.rs Normal file
View File

@@ -0,0 +1,839 @@
use axum::extract::FromRef;
use axum::extract::FromRequestParts;
use axum::extract::Path;
use axum::extract::Request;
use axum::extract::State;
use axum::http::header;
use axum::http::request::Parts;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::response::Response;
use axum::routing::post;
use axum::Json;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use thiserror::Error;
use crate::grpc;
use crate::storage::GatewayStorage;
use crate::storage::StorageError;
use crate::AppState;
/// Tenant-facing routes: command submission and read-model queries.
pub fn router() -> axum::Router<AppState> {
    let command_route = post(submit_command_stub);
    let query_route = post(query_stub);
    axum::Router::new()
        .route("/commands/:aggregate_type/:aggregate_id", command_route)
        .route("/query/:view_type", query_route)
}
/// Authenticated caller identity extracted from a verified bearer token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Principal {
    // `sub` claim of the verified access token.
    pub user_id: String,
    pub session_id: String,
}
#[async_trait::async_trait]
impl<S> FromRequestParts<S> for Principal
where
S: Send + Sync,
AppState: FromRef<S>,
{
type Rejection = AuthzRejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let auth_header = parts
.headers
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.ok_or(AuthzRejection::Unauthorized)?;
let token = auth_header
.strip_prefix("Bearer ")
.ok_or(AuthzRejection::Unauthorized)?;
let app_state = AppState::from_ref(state);
let claims = app_state.authn.verify_access_token(token).map_err(|_| {
metrics::counter!("gateway_authn_token_verify_fail_total").increment(1);
AuthzRejection::Unauthorized
})?;
tracing::Span::current().record("principal_id", claims.sub.as_str());
Ok(Self {
user_id: claims.sub,
session_id: claims.session_id,
})
}
}
/// Tenant scope taken from the validated `x-tenant-id` request header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TenantId(pub String);
#[async_trait::async_trait]
impl<S> FromRequestParts<S> for TenantId
where
    S: Send + Sync,
{
    type Rejection = AuthzRejection;

    /// Reads and validates the `x-tenant-id` header: required, trimmed, and
    /// restricted to ASCII alphanumerics plus `-` and `_`.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let header_value = parts
            .headers
            .get("x-tenant-id")
            .and_then(|v| v.to_str().ok())
            .ok_or(AuthzRejection::MissingTenant)?;
        let tenant = header_value.trim();
        let is_valid = !tenant.is_empty()
            && tenant
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_');
        if !is_valid {
            return Err(AuthzRejection::InvalidTenant);
        }
        tracing::Span::current().record("tenant_id", tenant);
        Ok(TenantId(tenant.to_owned()))
    }
}
/// Error type shared by the authn/authz extractors and admin handlers;
/// mapped to an HTTP status by its `IntoResponse` impl below.
#[derive(Debug, Error)]
pub enum AuthzRejection {
    #[error("unauthorized")]
    Unauthorized,
    #[error("bad request")]
    BadRequest,
    #[error("missing x-tenant-id")]
    MissingTenant,
    #[error("invalid x-tenant-id")]
    InvalidTenant,
    #[error("forbidden")]
    Forbidden,
    #[error("not found")]
    NotFound,
    #[error("conflict")]
    Conflict,
    #[error("internal error")]
    Internal,
}
impl IntoResponse for AuthzRejection {
    /// Maps each rejection to its HTTP status; the `Display` impl (from
    /// `thiserror`) supplies the plain-text body.
    fn into_response(self) -> axum::response::Response {
        let status = match &self {
            AuthzRejection::Unauthorized => StatusCode::UNAUTHORIZED,
            AuthzRejection::BadRequest
            | AuthzRejection::MissingTenant
            | AuthzRejection::InvalidTenant => StatusCode::BAD_REQUEST,
            AuthzRejection::Forbidden => StatusCode::FORBIDDEN,
            AuthzRejection::NotFound => StatusCode::NOT_FOUND,
            AuthzRejection::Conflict => StatusCode::CONFLICT,
            AuthzRejection::Internal => StatusCode::INTERNAL_SERVER_ERROR,
        };
        (status, self.to_string()).into_response()
    }
}
/// A named role and the rights it grants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoleRecord {
    pub role_id: String,
    pub rights: Vec<String>,
}
/// Request body for HTTP command submission.
#[derive(Debug, Deserialize)]
struct HttpCommandRequest {
    // Optional client-chosen command id; a fresh uuid is generated when absent.
    command_id: Option<String>,
    payload: Value,
    metadata: Option<std::collections::HashMap<String, String>>,
}
/// Response envelope: the events produced by the command.
#[derive(Debug, Serialize)]
struct HttpCommandResponse {
    events: Vec<EventDto>,
}
/// JSON projection of an upstream gRPC event message.
#[derive(Debug, Serialize)]
struct EventDto {
    event_id: String,
    command_id: String,
    aggregate_id: String,
    aggregate_type: String,
    version: u64,
    event_type: String,
    payload_json: String,
    timestamp_rfc3339: String,
}
/// Submits a command to the tenant's aggregate service over gRPC and
/// returns the resulting events as JSON.
///
/// Requires the `command.submit` right for the tenant (or wildcard scope).
async fn submit_command_stub(
    State(state): State<AppState>,
    ctx: crate::RequestContext,
    principal: Principal,
    TenantId(tenant_id): TenantId,
    Path((aggregate_type, aggregate_id)): Path<(String, String)>,
    Json(body): Json<HttpCommandRequest>,
) -> Result<Json<HttpCommandResponse>, AuthzRejection> {
    ensure_allowed(
        &state.storage,
        &principal.user_id,
        &tenant_id,
        "command.submit",
    )
    .await?;
    // Fall back to a fresh uuid when the client did not pick a command id.
    let command_id = body
        .command_id
        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
    let request = grpc::proto::SubmitCommandRequest {
        tenant_id: tenant_id.clone(),
        command_id,
        aggregate_id,
        aggregate_type,
        payload_json: body.payload.to_string(),
        metadata: body.metadata.unwrap_or_default(),
    };
    let resp = grpc::submit_command_via_routing(&state.routing, request, &ctx)
        .await
        .map_err(|_| AuthzRejection::Internal)?;
    let mut events = Vec::with_capacity(resp.events.len());
    for e in resp.events {
        events.push(EventDto {
            event_id: e.event_id,
            command_id: e.command_id,
            aggregate_id: e.aggregate_id,
            aggregate_type: e.aggregate_type,
            version: e.version,
            event_type: e.event_type,
            payload_json: e.payload_json,
            timestamp_rfc3339: e.timestamp_rfc3339,
        });
    }
    Ok(Json(HttpCommandResponse { events }))
}
/// Proxies a read-model query to the tenant's projection service.
///
/// Requires the `query.execute` right. The upstream status is mirrored
/// (anything axum cannot represent collapses to 502) together with the raw
/// response body.
async fn query_stub(
    State(state): State<AppState>,
    ctx: crate::RequestContext,
    principal: Principal,
    TenantId(tenant_id): TenantId,
    Path(view_type): Path<String>,
    Json(payload): Json<Value>,
) -> Result<Response, AuthzRejection> {
    ensure_allowed(
        &state.storage,
        &principal.user_id,
        &tenant_id,
        "query.execute",
    )
    .await?;
    let upstream = state
        .routing
        .resolve(&tenant_id, crate::routing::ServiceKind::Projection)
        .await
        .map_err(|_| AuthzRejection::Internal)?;
    tracing::Span::current().record("upstream", upstream.as_str());
    let url = format!("{}/v1/query/{}", upstream.trim_end_matches('/'), view_type);
    let upstream_resp = crate::upstream::http_client()
        .post(url)
        .header("x-tenant-id", tenant_id)
        .header("x-correlation-id", ctx.correlation_id)
        .header("traceparent", ctx.traceparent)
        .json(&payload)
        .send()
        .await
        .map_err(|_| AuthzRejection::Internal)?;
    let status =
        StatusCode::from_u16(upstream_resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
    let bytes = upstream_resp
        .bytes()
        .await
        .map_err(|_| AuthzRejection::Internal)?;
    let mut response = Response::new(axum::body::Body::from(bytes));
    *response.status_mut() = status;
    Ok(response)
}
/// Proxies an admin request to the tenant's runner service, preserving
/// method, path suffix, query string, headers and body.
///
/// Requires the `runner.admin` right. The upstream status is mirrored
/// (unrepresentable codes become 502) along with the raw response body.
///
/// Fix: the original header-forwarding loop re-copied the client's own
/// `x-tenant-id` / `x-correlation-id` / `traceparent` headers after the
/// canonical values had been set. reqwest's `RequestBuilder::header`
/// appends rather than replaces, so upstream could receive duplicate —
/// and client-controlled — tenant/trace headers. Those names are now
/// skipped alongside `host`.
pub async fn runner_admin_proxy(
    State(state): State<AppState>,
    ctx: crate::RequestContext,
    principal: Principal,
    TenantId(tenant_id): TenantId,
    Path(path): Path<String>,
    request: Request,
) -> Result<Response, AuthzRejection> {
    ensure_allowed(
        &state.storage,
        &principal.user_id,
        &tenant_id,
        "runner.admin",
    )
    .await?;
    let upstream = state
        .routing
        .resolve(&tenant_id, crate::routing::ServiceKind::Runner)
        .await
        .map_err(|_| AuthzRejection::Internal)?;
    tracing::Span::current().record("upstream", upstream.as_str());
    let mut url = format!(
        "{}/admin/{}",
        upstream.trim_end_matches('/'),
        path.trim_start_matches('/')
    );
    if let Some(q) = request.uri().query() {
        url.push('?');
        url.push_str(q);
    }
    let method = request.method().clone();
    let headers = request.headers().clone();
    let body = axum::body::to_bytes(request.into_body(), usize::MAX)
        .await
        .map_err(|_| AuthzRejection::Internal)?;
    let client = crate::upstream::http_client();
    let mut req = client
        .request(method, url)
        .header("x-tenant-id", tenant_id)
        .header("x-correlation-id", ctx.correlation_id)
        .header("traceparent", ctx.traceparent)
        .body(body);
    for (k, v) in headers.iter() {
        // Skip the host header (reqwest derives it from the URL) and the
        // headers we set canonically above, to avoid duplicates upstream.
        let name = k.as_str();
        if k == header::HOST
            || name == "x-tenant-id"
            || name == "x-correlation-id"
            || name == "traceparent"
        {
            continue;
        }
        req = req.header(k, v);
    }
    let resp = req.send().await.map_err(|_| AuthzRejection::Internal)?;
    let status = StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
    let bytes = resp.bytes().await.map_err(|_| AuthzRejection::Internal)?;
    let mut out = Response::new(axum::body::Body::from(bytes));
    *out.status_mut() = status;
    Ok(out)
}
/// Grants access when any role assigned to `principal_id` — in the tenant's
/// scope or the wildcard "*" scope — carries `required_right`. Every
/// decision is recorded in the `gateway_authz_decisions_total` counter.
pub async fn ensure_allowed(
    storage: &GatewayStorage,
    principal_id: &str,
    tenant_id: &str,
    required_right: &str,
) -> Result<(), AuthzRejection> {
    // Single place to emit the decision metric for both outcomes.
    let record_decision = |result: &'static str| {
        metrics::counter!(
            "gateway_authz_decisions_total",
            "tenant" => tenant_id.to_string(),
            "right" => required_right.to_string(),
            "result" => result
        )
        .increment(1);
    };
    let mut roles = list_assigned_roles(storage, tenant_id, principal_id).await?;
    roles.extend(list_assigned_roles(storage, "*", principal_id).await?);
    if roles.is_empty() {
        record_decision("deny");
        return Err(AuthzRejection::Forbidden);
    }
    for role_id in roles {
        // Dangling assignments (role record deleted) are skipped silently.
        let Some(entry) = storage
            .roles
            .get(&role_key(&role_id))
            .await
            .map_err(|_| AuthzRejection::Internal)?
        else {
            continue;
        };
        let role: RoleRecord = decode_stored(&entry.value).map_err(|_| AuthzRejection::Internal)?;
        if role.rights.iter().any(|r| r == required_right) {
            record_decision("allow");
            return Ok(());
        }
    }
    record_decision("deny");
    Err(AuthzRejection::Forbidden)
}
/// Lists the role ids assigned to a principal within one tenant scope.
/// Assignment keys end in the role id, so only the final path segment is kept.
async fn list_assigned_roles(
    storage: &GatewayStorage,
    tenant_id: &str,
    principal_id: &str,
) -> Result<Vec<String>, AuthzRejection> {
    let keys = storage
        .assignments
        .list_keys(&assignment_prefix(tenant_id, principal_id))
        .await
        .map_err(|_| AuthzRejection::Internal)?;
    let mut roles = Vec::with_capacity(keys.len());
    for key in keys {
        if let Some(role_id) = key.rsplit('/').next() {
            roles.push(role_id.to_string());
        }
    }
    Ok(roles)
}
/// Storage key under which a role record lives.
fn role_key(role_id: &str) -> String {
    format!("v1/roles/{}", role_id)
}
/// Storage key for one role assignment (tenant → principal → role).
fn assignment_key(tenant_id: &str, principal_id: &str, role_id: &str) -> String {
    ["v1/assignments", tenant_id, principal_id, role_id].join("/")
}
/// Key prefix covering every role assigned to a principal in one tenant
/// scope; used for prefix listing.
fn assignment_prefix(tenant_id: &str, principal_id: &str) -> String {
    format!("v1/assignments/{}/{}/", tenant_id, principal_id)
}
/// Unwraps a record stored in the versioned `{"v": ..., "data": ...}`
/// envelope (the version field is ignored on read).
fn decode_stored<T: for<'de> Deserialize<'de>>(bytes: &[u8]) -> Result<T, StorageError> {
    #[derive(Deserialize)]
    struct Envelope<T> {
        data: T,
    }
    serde_json::from_slice::<Envelope<T>>(bytes)
        .map(|envelope| envelope.data)
        .map_err(|e| StorageError::Serde(e.to_string()))
}
/// Creates or overwrites a role record, wrapped in the versioned envelope
/// that `decode_stored` expects.
pub async fn put_role(
    storage: &GatewayStorage,
    role_id: &str,
    rights: Vec<String>,
) -> Result<(), StorageError> {
    let record = RoleRecord {
        role_id: role_id.to_owned(),
        rights,
    };
    let envelope = serde_json::json!({
        "v": crate::storage::SCHEMA_VERSION,
        "data": record
    });
    let payload = serde_json::to_vec(&envelope).map_err(|e| StorageError::Serde(e.to_string()))?;
    storage.roles.put(&role_key(role_id), payload).await?;
    Ok(())
}
/// Assigns a role to a principal in a tenant scope ("*" = all tenants).
/// The stored value is a marker; presence of the key is what grants the role.
pub async fn assign_role(
    storage: &GatewayStorage,
    tenant_id: &str,
    principal_id: &str,
    role_id: &str,
) -> Result<(), StorageError> {
    let key = assignment_key(tenant_id, principal_id, role_id);
    storage.assignments.put(&key, b"1".to_vec()).await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::authn;
    use std::sync::Arc;
    use tower::util::ServiceExt;
    // Builds a full gateway app with empty routing, in-memory storage, and
    // test authn config; returns the router plus the state for direct
    // storage/authn manipulation in tests.
    async fn test_app() -> (axum::Router, AppState) {
        let metrics = crate::observability::init_metrics_for_tests();
        let routing = crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(
            crate::routing::RoutingConfig::empty(),
        )))
        .await
        .unwrap();
        let storage = crate::storage::GatewayStorage::new_in_memory();
        let authn_cfg = crate::authn::AuthnConfig::for_tests();
        let state = crate::AppState {
            metrics,
            routing,
            storage,
            authn: authn_cfg,
        };
        let app = crate::app(state.clone());
        (app, state)
    }
    // Same as `test_app` but with a caller-supplied routing table, used by
    // tests that need a live upstream to forward to.
    async fn test_app_with_routing(cfg: crate::routing::RoutingConfig) -> (axum::Router, AppState) {
        let metrics = crate::observability::init_metrics_for_tests();
        let routing =
            crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(cfg)))
                .await
                .unwrap();
        let storage = crate::storage::GatewayStorage::new_in_memory();
        let authn_cfg = crate::authn::AuthnConfig::for_tests();
        let state = crate::AppState {
            metrics,
            routing,
            storage,
            authn: authn_cfg,
        };
        let app = crate::app(state.clone());
        (app, state)
    }
    // Creates a user via the signup endpoint and returns the bearer token
    // together with its verified claims (used to learn the principal id).
    async fn signup_and_get_claims(
        app: &axum::Router,
        cfg: &authn::AuthnConfig,
    ) -> (String, authn::AccessClaims) {
        let response = app
            .clone()
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/v1/auth/signup")
                    .header("content-type", "application/json")
                    .body(axum::body::Body::from(
                        r#"{"email":"a@b.com","password":"password123"}"#,
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let created: crate::authn::AuthResponse = serde_json::from_slice(&body).unwrap();
        let claims = cfg.verify_access_token(&created.access_token).unwrap();
        (created.access_token, claims)
    }
    // A command request without the x-tenant-id header must be rejected as a
    // client error even when the principal is otherwise fully authorized.
    #[tokio::test]
    async fn missing_tenant_header_returns_400() {
        let (app, state) = test_app().await;
        let (token, claims) = signup_and_get_claims(&app, &state.authn).await;
        put_role(
            &state.storage,
            "role-command",
            vec!["command.submit".to_string()],
        )
        .await
        .unwrap();
        assign_role(&state.storage, "tenant-a", &claims.sub, "role-command")
            .await
            .unwrap();
        let response = app
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/v1/commands/User/u1")
                    .header("authorization", format!("Bearer {token}"))
                    .header("content-type", "application/json")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }
    // A principal assigned a role in tenant-a must not be able to act in
    // tenant-b by simply sending a different x-tenant-id header.
    #[tokio::test]
    async fn tenant_spoofing_is_rejected() {
        let (app, state) = test_app().await;
        let (token, claims) = signup_and_get_claims(&app, &state.authn).await;
        put_role(
            &state.storage,
            "role-command",
            vec!["command.submit".to_string()],
        )
        .await
        .unwrap();
        assign_role(&state.storage, "tenant-a", &claims.sub, "role-command")
            .await
            .unwrap();
        let response = app
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/v1/commands/User/u1")
                    .header("authorization", format!("Bearer {token}"))
                    .header("x-tenant-id", "tenant-b")
                    .header("content-type", "application/json")
                    .body(axum::body::Body::from(r#"{"payload":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }
    // With a matching tenant and the command.submit right, the request passes
    // authz; the 500 comes from the empty routing table (no upstream), which
    // is exactly what distinguishes "authorized" from "forbidden" here.
    #[tokio::test]
    async fn role_assignment_enables_expected_action() {
        let (app, state) = test_app().await;
        let (token, claims) = signup_and_get_claims(&app, &state.authn).await;
        put_role(
            &state.storage,
            "role-command",
            vec!["command.submit".to_string()],
        )
        .await
        .unwrap();
        assign_role(&state.storage, "tenant-a", &claims.sub, "role-command")
            .await
            .unwrap();
        let response = app
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/v1/commands/User/u1")
                    .header("authorization", format!("Bearer {token}"))
                    .header("x-tenant-id", "tenant-a")
                    .header("content-type", "application/json")
                    .body(axum::body::Body::from(r#"{"payload":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
    // End-to-end: spins up a stub gRPC aggregate upstream and verifies the
    // HTTP command endpoint returns the same event shape as the gRPC path.
    #[tokio::test]
    async fn http_command_endpoint_returns_same_shape_as_grpc_response() {
        use crate::grpc::proto;
        use crate::routing::RoutingConfig;
        use std::collections::HashMap;
        #[derive(Default)]
        struct Upstream;
        #[async_trait::async_trait]
        impl proto::command_service_server::CommandService for Upstream {
            async fn submit_command(
                &self,
                request: tonic::Request<proto::SubmitCommandRequest>,
            ) -> Result<tonic::Response<proto::SubmitCommandResponse>, tonic::Status> {
                let req = request.into_inner();
                Ok(tonic::Response::new(proto::SubmitCommandResponse {
                    events: vec![proto::Event {
                        event_id: "e1".to_string(),
                        command_id: req.command_id,
                        aggregate_id: req.aggregate_id,
                        aggregate_type: req.aggregate_type,
                        version: 1,
                        event_type: "Created".to_string(),
                        payload_json: "{}".to_string(),
                        timestamp_rfc3339: "2020-01-01T00:00:00Z".to_string(),
                    }],
                }))
            }
        }
        // Bind to discover a free port, then release it so tonic can rebind.
        // NOTE(review): small reuse race window; acceptable for tests.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        drop(listener);
        tokio::spawn(async move {
            tonic::transport::Server::builder()
                .add_service(proto::command_service_server::CommandServiceServer::new(
                    Upstream,
                ))
                .serve(addr)
                .await
                .unwrap();
        });
        // Give the stub server a moment to start listening.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let upstream_url = format!("http://{}", addr);
        let cfg = RoutingConfig {
            revision: 1,
            aggregate_placement: HashMap::from([("tenant-a".to_string(), "a".to_string())]),
            projection_placement: HashMap::new(),
            runner_placement: HashMap::new(),
            aggregate_shards: HashMap::from([("a".to_string(), vec![upstream_url])]),
            projection_shards: HashMap::new(),
            runner_shards: HashMap::new(),
        };
        let (app, state) = test_app_with_routing(cfg).await;
        let (token, claims) = signup_and_get_claims(&app, &state.authn).await;
        put_role(
            &state.storage,
            "role-command",
            vec!["command.submit".to_string()],
        )
        .await
        .unwrap();
        assign_role(&state.storage, "tenant-a", &claims.sub, "role-command")
            .await
            .unwrap();
        let response = app
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/v1/commands/User/u1")
                    .header("authorization", format!("Bearer {token}"))
                    .header("x-tenant-id", "tenant-a")
                    .header("content-type", "application/json")
                    .body(axum::body::Body::from(r#"{"payload":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let value: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert!(
            value
                .get("events")
                .and_then(|v| v.as_array())
                .unwrap()
                .len()
                == 1
        );
        assert_eq!(
            value.get("events").unwrap()[0]
                .get("event_id")
                .and_then(|v| v.as_str())
                .unwrap(),
            "e1"
        );
    }
    // Verifies the query endpoint denies without the query.execute right,
    // then forwards (with correlation/trace headers) once authorized. The
    // stub projection returns 400 if those headers are missing, so an OK
    // response proves propagation.
    #[tokio::test]
    async fn query_endpoint_denies_unauthorized_and_forwards_when_authorized() {
        use crate::routing::RoutingConfig;
        use std::collections::HashMap;
        let projection_app = axum::Router::new().route(
            "/v1/query/TestView",
            post(|headers: axum::http::HeaderMap| async move {
                let correlation = headers
                    .get("x-correlation-id")
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("");
                let traceparent = headers
                    .get("traceparent")
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("");
                if correlation.trim().is_empty()
                    || crate::trace_id_from_traceparent(traceparent).is_none()
                {
                    return (StatusCode::BAD_REQUEST, "missing correlation");
                }
                (StatusCode::OK, r#"{"mode":"count"}"#)
            }),
        );
        let projection_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let projection_addr = projection_listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(projection_listener, projection_app)
                .await
                .unwrap();
        });
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let projection_url = format!("http://{}", projection_addr);
        let cfg = RoutingConfig {
            revision: 1,
            aggregate_placement: HashMap::new(),
            projection_placement: HashMap::from([("tenant-a".to_string(), "p".to_string())]),
            runner_placement: HashMap::new(),
            aggregate_shards: HashMap::new(),
            projection_shards: HashMap::from([("p".to_string(), vec![projection_url])]),
            runner_shards: HashMap::new(),
        };
        let (app, state) = test_app_with_routing(cfg).await;
        let (token, claims) = signup_and_get_claims(&app, &state.authn).await;
        let deny = app
            .clone()
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/v1/query/TestView")
                    .header("authorization", format!("Bearer {token}"))
                    .header("x-tenant-id", "tenant-a")
                    .header("content-type", "application/json")
                    .body(axum::body::Body::from(r#"{"uqf":"{}"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(deny.status(), StatusCode::FORBIDDEN);
        put_role(
            &state.storage,
            "role-query",
            vec!["query.execute".to_string()],
        )
        .await
        .unwrap();
        assign_role(&state.storage, "tenant-a", &claims.sub, "role-query")
            .await
            .unwrap();
        let ok = app
            .oneshot(
                axum::http::Request::builder()
                    .method("POST")
                    .uri("/v1/query/TestView")
                    .header("authorization", format!("Bearer {token}"))
                    .header("x-tenant-id", "tenant-a")
                    .header("content-type", "application/json")
                    .body(axum::body::Body::from(r#"{"uqf":"{}"}"#))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(ok.status(), StatusCode::OK);
        // The gateway must also echo correlation/trace identifiers back to
        // the HTTP caller.
        assert!(!ok
            .headers()
            .get("x-correlation-id")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
            .is_empty());
        assert!(crate::trace_id_from_traceparent(
            ok.headers()
                .get("traceparent")
                .and_then(|v| v.to_str().ok())
                .unwrap_or("")
        )
        .is_some());
    }
}

275
gateway/src/grpc.rs Normal file
View File

@@ -0,0 +1,275 @@
use crate::routing::RouterState;
use crate::routing::RoutingError;
use crate::routing::ServiceKind;
/// Generated protobuf/tonic types for the `aggregate.gateway.v1` package.
pub mod proto {
    tonic::include_proto!("aggregate.gateway.v1");
}
/// tonic server implementation that proxies commands to per-tenant
/// aggregate upstreams resolved via the routing table.
#[derive(Clone)]
pub struct GatewayCommandService {
    routing: RouterState,
}
impl GatewayCommandService {
    /// Creates a service that resolves upstream endpoints through `routing`.
    pub fn new(routing: RouterState) -> Self {
        Self { routing }
    }
}
#[async_trait::async_trait]
impl proto::command_service_server::CommandService for GatewayCommandService {
    /// gRPC entry point: validates the tenant, resolves the tenant's
    /// aggregate shard, and proxies the command upstream.
    ///
    /// Correlation and trace identifiers are taken from incoming metadata
    /// when present and well-formed, otherwise freshly generated; both are
    /// forwarded upstream and echoed on the response metadata.
    async fn submit_command(
        &self,
        request: tonic::Request<proto::SubmitCommandRequest>,
    ) -> Result<tonic::Response<proto::SubmitCommandResponse>, tonic::Status> {
        let correlation_id = request
            .metadata()
            .get("x-correlation-id")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.trim())
            .filter(|s| !s.is_empty())
            .map(|s| s.to_string())
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
        // Accept an incoming traceparent only if its trace-id parses;
        // otherwise mint a new one (version 00, sampled flag 01).
        let traceparent = request
            .metadata()
            .get("traceparent")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.trim())
            .filter(|s| !s.is_empty())
            .and_then(|s| {
                if crate::trace_id_from_traceparent(s).is_some() {
                    Some(s.to_string())
                } else {
                    None
                }
            })
            .unwrap_or_else(|| {
                let trace_id = uuid::Uuid::new_v4().simple().to_string();
                // 16 hex chars = 8-byte span id; slicing ASCII hex is safe.
                let span_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string();
                format!("00-{trace_id}-{span_id}-01")
            });
        let mut req = request.into_inner();
        let tenant_id = req.tenant_id.trim().to_string();
        if tenant_id.is_empty() {
            return Err(tonic::Status::invalid_argument("tenant_id is required"));
        }
        // Normalize the trimmed tenant back onto the forwarded request body.
        req.tenant_id = tenant_id.clone();
        let upstream = self
            .routing
            .resolve(&tenant_id, ServiceKind::Aggregate)
            .await
            .map_err(map_routing_error)?;
        // Record the chosen upstream on the current tracing span for logs.
        tracing::Span::current().record("upstream", upstream.as_str());
        let channel = crate::upstream::grpc_endpoint(&upstream)
            .map_err(|e| tonic::Status::unavailable(e.to_string()))?
            .connect()
            .await
            .map_err(|e| tonic::Status::unavailable(e.to_string()))?;
        let mut client = proto::command_service_client::CommandServiceClient::new(channel);
        let mut upstream_req = tonic::Request::new(req);
        if let Ok(v) = tonic::metadata::MetadataValue::try_from(tenant_id.as_str()) {
            upstream_req.metadata_mut().insert("x-tenant-id", v);
        }
        if let Ok(v) = tonic::metadata::MetadataValue::try_from(correlation_id.as_str()) {
            upstream_req.metadata_mut().insert("x-correlation-id", v);
        }
        if let Ok(v) = tonic::metadata::MetadataValue::try_from(traceparent.as_str()) {
            upstream_req.metadata_mut().insert("traceparent", v);
        }
        let mut resp = client.submit_command(upstream_req).await?;
        // Echo the identifiers back to the caller on response metadata.
        if let Ok(v) = tonic::metadata::MetadataValue::try_from(correlation_id.as_str()) {
            resp.metadata_mut().insert("x-correlation-id", v);
        }
        if let Ok(v) = tonic::metadata::MetadataValue::try_from(traceparent.as_str()) {
            resp.metadata_mut().insert("traceparent", v);
        }
        Ok(resp)
    }
}
/// Forwards a command to the tenant's aggregate shard using an
/// already-established HTTP request context instead of gRPC metadata.
///
/// Mirrors `GatewayCommandService::submit_command`, but correlation and
/// trace headers come from `ctx` (set by the middleware stack).
pub async fn submit_command_via_routing(
    routing: &RouterState,
    request: proto::SubmitCommandRequest,
    ctx: &crate::RequestContext,
) -> Result<proto::SubmitCommandResponse, tonic::Status> {
    let tenant_id = request.tenant_id.trim().to_string();
    if tenant_id.is_empty() {
        return Err(tonic::Status::invalid_argument("tenant_id is required"));
    }
    let upstream = routing
        .resolve(&tenant_id, ServiceKind::Aggregate)
        .await
        .map_err(map_routing_error)?;
    // Record the chosen upstream on the active tracing span for logs.
    tracing::Span::current().record("upstream", upstream.as_str());
    let endpoint = crate::upstream::grpc_endpoint(&upstream)
        .map_err(|e| tonic::Status::unavailable(e.to_string()))?;
    let channel = endpoint
        .connect()
        .await
        .map_err(|e| tonic::Status::unavailable(e.to_string()))?;
    let mut client = proto::command_service_client::CommandServiceClient::new(channel);
    let mut upstream_req = tonic::Request::new(request);
    if let Ok(v) = tonic::metadata::MetadataValue::try_from(tenant_id.as_str()) {
        upstream_req.metadata_mut().insert("x-tenant-id", v);
    }
    if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.correlation_id.as_str()) {
        upstream_req.metadata_mut().insert("x-correlation-id", v);
    }
    if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.traceparent.as_str()) {
        upstream_req.metadata_mut().insert("traceparent", v);
    }
    let response = client.submit_command(upstream_req).await?;
    Ok(response.into_inner())
}
/// Translates routing failures into gRPC statuses: an unknown tenant is a
/// caller-addressable `NOT_FOUND`, while shard gaps are transient
/// availability problems (`UNAVAILABLE`).
fn map_routing_error(err: RoutingError) -> tonic::Status {
    match err {
        RoutingError::UnknownTenant => tonic::Status::not_found("unknown tenant"),
        RoutingError::MissingShard => tonic::Status::unavailable(err.to_string()),
        RoutingError::EmptyShard => tonic::Status::unavailable(err.to_string()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::routing::RoutingConfig;
    use std::collections::HashMap;
    use std::sync::Arc;
    // Spins up a stub upstream that rejects any request missing tenant,
    // correlation, or traceparent metadata, then drives the gateway's gRPC
    // service against it and checks the response plus echoed metadata.
    #[tokio::test]
    async fn grpc_submit_command_forwards_tenant_metadata_and_returns_events() {
        use proto::command_service_server::CommandService;
        #[derive(Default)]
        struct Upstream;
        #[async_trait::async_trait]
        impl proto::command_service_server::CommandService for Upstream {
            async fn submit_command(
                &self,
                request: tonic::Request<proto::SubmitCommandRequest>,
            ) -> Result<tonic::Response<proto::SubmitCommandResponse>, tonic::Status> {
                // Metadata tenant must match the body tenant.
                let tenant_md = request
                    .metadata()
                    .get("x-tenant-id")
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("");
                if tenant_md != request.get_ref().tenant_id {
                    return Err(tonic::Status::failed_precondition(
                        "missing tenant metadata",
                    ));
                }
                let correlation = request
                    .metadata()
                    .get("x-correlation-id")
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("");
                if correlation.trim().is_empty() {
                    return Err(tonic::Status::failed_precondition(
                        "missing correlation metadata",
                    ));
                }
                let traceparent = request
                    .metadata()
                    .get("traceparent")
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("");
                if crate::trace_id_from_traceparent(traceparent).is_none() {
                    return Err(tonic::Status::failed_precondition("missing traceparent"));
                }
                let resp = proto::SubmitCommandResponse {
                    events: vec![proto::Event {
                        event_id: "e1".to_string(),
                        command_id: request.get_ref().command_id.clone(),
                        aggregate_id: request.get_ref().aggregate_id.clone(),
                        aggregate_type: request.get_ref().aggregate_type.clone(),
                        version: 1,
                        event_type: "Created".to_string(),
                        payload_json: "{}".to_string(),
                        timestamp_rfc3339: "2020-01-01T00:00:00Z".to_string(),
                    }],
                };
                Ok(tonic::Response::new(resp))
            }
        }
        // Bind to discover a free port, then release it so tonic can rebind.
        // NOTE(review): small reuse race window; acceptable for tests.
        let upstream_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let upstream_addr = upstream_listener.local_addr().unwrap();
        drop(upstream_listener);
        let upstream_url = format!("http://{}", upstream_addr);
        let upstream_task = tokio::spawn(async move {
            tonic::transport::Server::builder()
                .add_service(proto::command_service_server::CommandServiceServer::new(
                    Upstream,
                ))
                .serve(upstream_addr)
                .await
                .unwrap();
        });
        // Give the stub server a moment to start listening.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let cfg = RoutingConfig {
            revision: 1,
            aggregate_placement: HashMap::from([("tenant-a".to_string(), "a".to_string())]),
            projection_placement: HashMap::new(),
            runner_placement: HashMap::new(),
            aggregate_shards: HashMap::from([("a".to_string(), vec![upstream_url])]),
            projection_shards: HashMap::new(),
            runner_shards: HashMap::new(),
        };
        let routing =
            crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(cfg)))
                .await
                .unwrap();
        let svc = GatewayCommandService::new(routing);
        let request = proto::SubmitCommandRequest {
            tenant_id: "tenant-a".to_string(),
            command_id: "c1".to_string(),
            aggregate_id: "id1".to_string(),
            aggregate_type: "User".to_string(),
            payload_json: "{}".to_string(),
            metadata: HashMap::new(),
        };
        let resp = CommandService::submit_command(&svc, tonic::Request::new(request))
            .await
            .unwrap();
        // The gateway must echo correlation/trace identifiers on the
        // response metadata even though the caller supplied none.
        assert!(!resp
            .metadata()
            .get("x-correlation-id")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
            .is_empty());
        assert!(crate::trace_id_from_traceparent(
            resp.metadata()
                .get("traceparent")
                .and_then(|v| v.to_str().ok())
                .unwrap_or("")
        )
        .is_some());
        let resp = resp.into_inner();
        assert_eq!(resp.events.len(), 1);
        assert_eq!(resp.events[0].command_id, "c1");
        upstream_task.abort();
    }
}

541
gateway/src/lib.rs Normal file
View File

@@ -0,0 +1,541 @@
use std::time::Duration;
use std::time::Instant;
use axum::error_handling::HandleErrorLayer;
use axum::extract::MatchedPath;
use axum::extract::State;
use axum::http::request::Parts;
use axum::http::HeaderName;
use axum::http::HeaderValue;
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::IntoResponse;
use axum::routing::get;
use axum::BoxError;
use axum::Json;
use axum::Router;
use metrics_exporter_prometheus::PrometheusHandle;
use serde::Serialize;
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tower::timeout::TimeoutLayer;
use tower::Layer;
use tower::Service;
use tower::ServiceBuilder;
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::request_id::MakeRequestUuid;
use tower_http::request_id::PropagateRequestIdLayer;
use tower_http::request_id::SetRequestIdLayer;
use tower_http::trace::TraceLayer;
use tracing::Level;
/// Per-request identifiers extracted from headers for logging and
/// downstream propagation.
#[derive(Debug, Clone)]
pub struct RequestContext {
    // Set by `SetRequestIdLayer` (`x-request-id`); empty if absent.
    pub request_id: String,
    // Caller-supplied or generated by `EnsureCorrelationTrace`.
    pub correlation_id: String,
    // W3C `traceparent` header (normally normalized by the middleware).
    pub traceparent: String,
    // Trace-id portion parsed from `traceparent`; empty when unparsable.
    pub trace_id: String,
}
#[async_trait::async_trait]
impl<S> axum::extract::FromRequestParts<S> for RequestContext
where
    S: Send + Sync,
{
    type Rejection = StatusCode;
    /// Copies the tracing-related headers (normally populated earlier by the
    /// middleware stack) into an owned context. Never rejects: missing
    /// headers simply yield empty strings.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        fn header_or_empty(parts: &Parts, name: &str) -> String {
            parts
                .headers
                .get(name)
                .and_then(|v| v.to_str().ok())
                .unwrap_or("")
                .to_string()
        }
        let request_id = header_or_empty(parts, "x-request-id");
        let correlation_id = header_or_empty(parts, "x-correlation-id");
        let traceparent = header_or_empty(parts, "traceparent");
        let trace_id = trace_id_from_traceparent(&traceparent)
            .map(str::to_string)
            .unwrap_or_default();
        Ok(Self {
            request_id,
            correlation_id,
            traceparent,
            trace_id,
        })
    }
}
/// Shared application state cloned into every handler.
#[derive(Clone)]
pub struct AppState {
    // Prometheus render handle backing the `/metrics` endpoint.
    pub metrics: PrometheusHandle,
    // Tenant -> shard routing table (reloadable via /admin/routing/reload).
    pub routing: routing::RouterState,
    // Key-value storage used for IAM roles/assignments and other records.
    pub storage: storage::GatewayStorage,
    // Token signing/verification configuration.
    pub authn: authn::AuthnConfig,
}
/// Minimal `{"status": "ok"}` body for the health/readiness probes.
#[derive(Serialize)]
struct StatusResponse {
    status: &'static str,
}
/// Builds the gateway's HTTP router with all routes and middleware.
///
/// Middleware (in the `ServiceBuilder`) is listed outermost-first: error
/// handling, request-id generation/propagation, correlation/trace
/// normalization, tracing spans, a 1 MiB body limit, and a 30 s timeout.
pub fn app(state: AppState) -> Router {
    let request_id_header = HeaderName::from_static("x-request-id");
    Router::new()
        .route("/health", get(health))
        .route("/ready", get(ready))
        .route("/metrics", get(metrics))
        .nest("/v1/auth", authn::router())
        .nest("/v1", authz::router())
        // The IAM admin API is deliberately exposed at both paths.
        .nest("/admin/iam", admin_iam::router())
        .nest("/v1/admin/iam", admin_iam::router())
        .nest("/admin/rebalance", admin_rebalance::router())
        .route("/admin/routing", get(admin_routing))
        .route(
            "/admin/runner/*path",
            axum::routing::any(authz::runner_admin_proxy),
        )
        .route(
            "/admin/routing/reload",
            axum::routing::post(admin_routing_reload),
        )
        .route(
            "/admin/routing/resolve",
            axum::routing::get(admin_rebalance::resolve),
        )
        // Applied as a route layer so only matched routes are measured.
        .route_layer(axum::middleware::from_fn(track_http_metrics))
        .with_state(state)
        .layer(
            ServiceBuilder::new()
                // Converts middleware errors (e.g. timeouts) into responses.
                .layer(HandleErrorLayer::new(|error: BoxError| async move {
                    (StatusCode::REQUEST_TIMEOUT, error.to_string())
                }))
                .layer(SetRequestIdLayer::new(
                    request_id_header.clone(),
                    MakeRequestUuid,
                ))
                .layer(PropagateRequestIdLayer::new(request_id_header))
                // Guarantees well-formed x-correlation-id and traceparent.
                .layer(EnsureCorrelationTraceLayer)
                .layer(TraceLayer::new_for_http().make_span_with(
                    |request: &axum::http::Request<_>| {
                        let request_id = request
                            .headers()
                            .get("x-request-id")
                            .and_then(|v| v.to_str().ok())
                            .unwrap_or("");
                        let correlation_id = request
                            .headers()
                            .get("x-correlation-id")
                            .and_then(|v| v.to_str().ok())
                            .unwrap_or("");
                        let traceparent = request
                            .headers()
                            .get("traceparent")
                            .and_then(|v| v.to_str().ok())
                            .unwrap_or("");
                        let trace_id = trace_id_from_traceparent(traceparent).unwrap_or("");
                        // Route template only — never the raw query string.
                        let path = request_path_for_logging(request);
                        tracing::span!(
                            Level::INFO,
                            "http.request",
                            method = %request.method(),
                            path = %path,
                            request_id = request_id,
                            correlation_id = correlation_id,
                            trace_id = trace_id,
                            // Filled in later by handlers that know them.
                            tenant_id = tracing::field::Empty,
                            principal_id = tracing::field::Empty,
                            upstream = tracing::field::Empty,
                        )
                    },
                ))
                .layer(RequestBodyLimitLayer::new(1024 * 1024))
                .layer(TimeoutLayer::new(Duration::from_secs(30))),
        )
}
/// Layer that guarantees every request carries a non-empty
/// `x-correlation-id` and a parseable W3C `traceparent`, generating fresh
/// values when either is absent or malformed.
#[derive(Clone)]
struct EnsureCorrelationTraceLayer;
/// Middleware service produced by `EnsureCorrelationTraceLayer`.
#[derive(Clone)]
struct EnsureCorrelationTrace<S> {
    inner: S,
}
impl<S> Layer<S> for EnsureCorrelationTraceLayer {
    type Service = EnsureCorrelationTrace<S>;
    fn layer(&self, inner: S) -> Self::Service {
        Self::Service { inner }
    }
}
impl<S, ReqBody, ResBody> Service<axum::http::Request<ReqBody>> for EnsureCorrelationTrace<S>
where
    S: Service<axum::http::Request<ReqBody>, Response = axum::http::Response<ResBody>>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type Response = axum::http::Response<ResBody>;
    type Error = S::Error;
    type Future =
        Pin<Box<dyn Future<Output = Result<axum::http::Response<ResBody>, S::Error>> + Send>>;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }
    /// Normalizes `x-correlation-id` and `traceparent` on the request
    /// (generating fresh values when missing or malformed), forwards the
    /// request, and mirrors both headers onto the response if the handler
    /// did not set them itself.
    fn call(&mut self, mut req: axum::http::Request<ReqBody>) -> Self::Future {
        let correlation_id = req
            .headers()
            .get("x-correlation-id")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.trim())
            .filter(|s| !s.is_empty())
            .map(|s| s.to_string())
            .unwrap_or_else(generate_correlation_id);
        // Accept the incoming traceparent only if its trace-id parses;
        // otherwise start a new trace.
        let traceparent = req
            .headers()
            .get("traceparent")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.trim())
            .filter(|s| !s.is_empty())
            .and_then(|s| {
                if trace_id_from_traceparent(s).is_some() {
                    Some(s.to_string())
                } else {
                    None
                }
            })
            .unwrap_or_else(generate_traceparent);
        if let Ok(v) = HeaderValue::from_str(&correlation_id) {
            req.headers_mut().insert("x-correlation-id", v);
        }
        if let Ok(v) = HeaderValue::from_str(&traceparent) {
            req.headers_mut().insert("traceparent", v);
        }
        // Move the service that `poll_ready` drove to readiness into the
        // future and leave a fresh clone behind; calling the clone directly
        // could dispatch to a service that was never polled ready (see the
        // `tower::Service` docs on cloning inner services).
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);
        Box::pin(async move {
            let mut resp = inner.call(req).await?;
            // Echo identifiers unless a downstream layer already set them.
            if resp.headers().get("x-correlation-id").is_none() {
                if let Ok(v) = HeaderValue::from_str(&correlation_id) {
                    resp.headers_mut().insert("x-correlation-id", v);
                }
            }
            if resp.headers().get("traceparent").is_none() {
                if let Ok(v) = HeaderValue::from_str(&traceparent) {
                    resp.headers_mut().insert("traceparent", v);
                }
            }
            Ok(resp)
        })
    }
}
/// Fresh correlation id: a random UUID in hyphenated form.
fn generate_correlation_id() -> String {
    uuid::Uuid::new_v4().to_string()
}
/// Builds a fresh W3C `traceparent`: version `00`, a random 32-hex-char
/// trace-id, a random 16-hex-char span-id, sampled flag `01`.
fn generate_traceparent() -> String {
    let trace_id = uuid::Uuid::new_v4().simple().to_string();
    let span_hex = uuid::Uuid::new_v4().simple().to_string();
    // Slicing ASCII hex output at 16 bytes is always char-boundary safe.
    format!("00-{}-{}-01", trace_id, &span_hex[..16])
}
/// Extracts the trace-id portion of a `traceparent` value, or `None` when
/// malformed. Delegates to the shared crate so all services parse the
/// header identically.
pub(crate) fn trace_id_from_traceparent(traceparent: &str) -> Option<&str> {
    shared::trace_id_from_traceparent(traceparent)
}
/// Axum middleware recording request count and latency per
/// (method, route template, status) label set.
async fn track_http_metrics(
    req: axum::http::Request<axum::body::Body>,
    next: Next,
) -> axum::response::Response {
    let method = req.method().to_string();
    // Prefer the matched route template (low label cardinality); fall back
    // to the raw path for unmatched requests.
    let path = req
        .extensions()
        .get::<MatchedPath>()
        .map(|p| p.as_str().to_string())
        .unwrap_or_else(|| req.uri().path().to_string());
    let start = Instant::now();
    let response = next.run(req).await;
    let status = response.status().as_u16().to_string();
    let elapsed = start.elapsed().as_secs_f64();
    metrics::counter!(
        "gateway_http_requests_total",
        "method" => method.clone(),
        "path" => path.clone(),
        "status" => status.clone()
    )
    .increment(1);
    metrics::histogram!(
        "gateway_http_request_duration_seconds",
        "method" => method,
        "path" => path,
        "status" => status
    )
    .record(elapsed);
    response
}
/// Path used in log spans: the matched route template when available
/// (e.g. `/v1/commands/:type/:id`), otherwise the raw URI path. Query
/// strings are never included, so secrets in query params do not leak.
fn request_path_for_logging<B>(req: &axum::http::Request<B>) -> String {
    match req.extensions().get::<MatchedPath>() {
        Some(matched) => matched.as_str().to_string(),
        None => req.uri().path().to_string(),
    }
}
/// Liveness probe: always reports ok and counts invocations.
async fn health() -> impl IntoResponse {
    metrics::counter!("gateway_health_requests_total").increment(1);
    Json(StatusResponse { status: "ok" })
}
/// Readiness probe. NOTE(review): unconditionally ok — it does not check
/// routing or storage health; confirm that is the intended contract.
async fn ready() -> impl IntoResponse {
    metrics::counter!("gateway_ready_requests_total").increment(1);
    Json(StatusResponse { status: "ok" })
}
/// Renders the Prometheus exposition text for scraping.
async fn metrics(State(state): State<AppState>) -> impl IntoResponse {
    state.metrics.render()
}
/// Returns the current routing table snapshot as JSON (admin debugging).
async fn admin_routing(State(state): State<AppState>) -> impl IntoResponse {
    Json(state.routing.snapshot().await)
}
/// Forces a reload of the routing table from its source; 204 on success,
/// 500 with the error text otherwise.
async fn admin_routing_reload(State(state): State<AppState>) -> impl IntoResponse {
    match state.routing.reload().await {
        Ok(()) => StatusCode::NO_CONTENT.into_response(),
        Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
    }
}
// Empty placeholder module (reserved for future HTTP-specific helpers).
pub mod http {}
pub mod admin_iam;
pub mod admin_rebalance;
pub mod authn;
pub mod authz;
pub mod grpc;
pub mod routing;
pub mod upstream;
// Empty placeholder module (reserved for future configuration loading).
pub mod config {}
pub mod storage;
/// Tracing and metrics bootstrap for the gateway binary and tests.
pub mod observability {
    use edge_logger_client::Config as EdgeLoggerConfig;
    use edge_logger_client::EdgeLoggerLayer;
    use metrics_exporter_prometheus::PrometheusBuilder;
    use metrics_exporter_prometheus::PrometheusHandle;
    use std::time::Duration;
    use tracing_subscriber::prelude::*;
    /// Installs a JSON tracing subscriber filtered by `RUST_LOG`
    /// (default "info"), optionally shipping logs via the edge logger.
    pub fn init_tracing() {
        let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string());
        let env_filter = tracing_subscriber::EnvFilter::new(filter);
        let fmt_layer = tracing_subscriber::fmt::layer().json();
        let edge_layer = edge_logger_layer_from_env("gateway");
        let registry = tracing_subscriber::registry()
            .with(env_filter)
            .with(fmt_layer);
        // Errors (e.g. a subscriber already installed) are deliberately
        // ignored so repeated initialization is harmless.
        let _ = match edge_layer {
            Some(layer) => registry.with(layer).try_init(),
            None => registry.try_init(),
        };
    }
    /// Installs the global Prometheus recorder; panics if one exists.
    pub fn init_metrics() -> PrometheusHandle {
        PrometheusBuilder::new()
            .install_recorder()
            .expect("failed to install Prometheus recorder")
    }
    /// Builds a recorder handle WITHOUT installing it globally, so many
    /// tests can each hold their own handle.
    pub fn init_metrics_for_tests() -> PrometheusHandle {
        PrometheusBuilder::new().build_recorder().handle()
    }
    /// Builds the edge-logger layer from EDGE_LOGGER_* environment
    /// variables, or `None` when log shipping is not configured.
    fn edge_logger_layer_from_env(service_name: &str) -> Option<EdgeLoggerLayer> {
        let enabled = std::env::var("EDGE_LOGGER_ENABLED")
            .ok()
            .map(|v| matches!(v.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
            .unwrap_or(false);
        let socket_path = std::env::var("EDGE_LOGGER_SOCKET_PATH").ok();
        // NOTE(review): an explicit socket path enables shipping even when
        // EDGE_LOGGER_ENABLED is unset/false — confirm this is intended.
        if !enabled && socket_path.is_none() {
            return None;
        }
        let environment = std::env::var("EDGE_LOGGER_ENVIRONMENT")
            .or_else(|_| std::env::var("ENVIRONMENT"))
            .unwrap_or_else(|_| "production".to_string());
        let tenant_id =
            std::env::var("EDGE_LOGGER_TENANT_ID").unwrap_or_else(|_| "default".to_string());
        let batch_size = std::env::var("EDGE_LOGGER_BATCH_SIZE")
            .ok()
            .and_then(|v| v.parse::<usize>().ok())
            .unwrap_or(100);
        let flush_interval = std::env::var("EDGE_LOGGER_FLUSH_INTERVAL_MS")
            .ok()
            .and_then(|v| v.parse::<u64>().ok())
            .map(Duration::from_millis)
            .unwrap_or(Duration::from_secs(1));
        Some(EdgeLoggerLayer::new(EdgeLoggerConfig {
            socket_path: socket_path
                .unwrap_or_else(|| "/var/run/edge-logger/logger.sock".to_string()),
            service: service_name.to_string(),
            environment,
            tenant_id,
            batch_size,
            flush_interval,
        }))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::OnceLock;
    use tower::util::ServiceExt;
    fn assert_send_sync<T: Send + Sync>() {}
    // Compile-time guarantee that AppState can be shared across tasks.
    #[test]
    fn app_state_is_send_sync() {
        assert_send_sync::<AppState>();
    }
    // The liveness endpoint must respond 200 with a minimal app.
    #[tokio::test]
    async fn health_returns_200() {
        let metrics = crate::observability::init_metrics_for_tests();
        let routing = crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(
            crate::routing::RoutingConfig::empty(),
        )))
        .await
        .unwrap();
        let storage = crate::storage::GatewayStorage::new_in_memory();
        let authn = crate::authn::AuthnConfig::for_tests();
        let app = app(AppState {
            metrics,
            routing,
            storage,
            authn,
        });
        let response = app
            .oneshot(
                axum::http::Request::builder()
                    .method("GET")
                    .uri("/health")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), axum::http::StatusCode::OK);
    }
    // Guards the deployment asset shipped alongside the crate against
    // YAML syntax errors.
    #[test]
    fn docker_stack_yml_is_valid_yaml() {
        let raw = std::fs::read_to_string("../swarm/stacks/platform.yml").unwrap();
        let parsed: serde_yaml::Value = serde_yaml::from_str(&raw).unwrap();
        assert!(parsed.as_mapping().is_some());
    }
    // Uses a process-global recorder (installed once via OnceLock) so the
    // counter written by one request is visible to the /metrics render.
    #[tokio::test]
    async fn metrics_include_http_request_counters() {
        static HANDLE: OnceLock<PrometheusHandle> = OnceLock::new();
        let metrics = HANDLE
            .get_or_init(|| {
                metrics_exporter_prometheus::PrometheusBuilder::new()
                    .install_recorder()
                    .unwrap()
            })
            .clone();
        let routing = crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(
            crate::routing::RoutingConfig::empty(),
        )))
        .await
        .unwrap();
        let storage = crate::storage::GatewayStorage::new_in_memory();
        let authn = crate::authn::AuthnConfig::for_tests();
        let app = app(AppState {
            metrics,
            routing,
            storage,
            authn,
        });
        let _ = app
            .clone()
            .oneshot(
                axum::http::Request::builder()
                    .method("GET")
                    .uri("/health")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        let resp = app
            .oneshot(
                axum::http::Request::builder()
                    .method("GET")
                    .uri("/metrics")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        let rendered = String::from_utf8_lossy(&body);
        assert!(rendered.contains("gateway_http_requests_total"));
    }
    // Logged paths must never contain query strings (they can hold secrets
    // such as OAuth codes).
    #[test]
    fn request_path_for_logging_does_not_include_query() {
        let req = axum::http::Request::builder()
            .method("GET")
            .uri("/v1/auth/oidc/google/callback?code=supersecret&state=x")
            .body(axum::body::Body::empty())
            .unwrap();
        let path = request_path_for_logging(&req);
        assert_eq!(path, "/v1/auth/oidc/google/callback");
        assert!(!path.contains("supersecret"));
    }
}

130
gateway/src/main.rs Normal file
View File

@@ -0,0 +1,130 @@
use std::net::SocketAddr;
use std::sync::Arc;
use gateway::observability;
use gateway::routing;
use gateway::storage;
use gateway::AppState;
/// Gateway entrypoint: initializes observability, storage, and routing,
/// serves HTTP and gRPC concurrently, and coordinates graceful shutdown
/// on SIGINT/SIGTERM.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    observability::init_tracing();
    let metrics = observability::init_metrics();
    let authn = gateway::authn::AuthnConfig::from_env();
    // Build metadata baked in at compile time (falls back to dev/unknown).
    let build_version = option_env!("GATEWAY_BUILD_VERSION").unwrap_or("dev");
    let build_sha = option_env!("GATEWAY_BUILD_SHA").unwrap_or("unknown");
    tracing::info!(build_version, build_sha, "gateway starting");
    let addr: SocketAddr = std::env::var("GATEWAY_ADDR")
        .unwrap_or_else(|_| "0.0.0.0:8080".to_string())
        .parse()?;
    let storage_path =
        std::env::var("GATEWAY_STORAGE_PATH").unwrap_or_else(|_| "./data/gateway.mdbx".to_string());
    if let Some(parent) = std::path::Path::new(&storage_path).parent() {
        let _ = std::fs::create_dir_all(parent);
    }
    // Best-effort durable storage: silently fall back to in-memory when the
    // on-disk store cannot be opened.
    // NOTE(review): the fallback hides disk errors — consider logging them.
    let storage = storage::GatewayStorage::open_edge_storage(storage_path, "gateway")
        .unwrap_or_else(|_| storage::GatewayStorage::new_in_memory());
    // Routing source precedence: static file > NATS KV > empty fixed table.
    let routing_source: Arc<dyn routing::RoutingSource> =
        if let Ok(path) = std::env::var("GATEWAY_ROUTING_FILE") {
            Arc::new(routing::StaticFileSource::new(path))
        } else if let (Ok(nats_url), Ok(bucket), Ok(key)) = (
            std::env::var("GATEWAY_ROUTING_NATS_URL"),
            std::env::var("GATEWAY_ROUTING_NATS_BUCKET"),
            std::env::var("GATEWAY_ROUTING_NATS_KEY"),
        ) {
            Arc::new(routing::NatsKvSource::connect(nats_url, bucket, key).await?)
        } else {
            Arc::new(routing::FixedSource::new(routing::RoutingConfig::empty()))
        };
    let routing = routing::RouterState::new(routing_source).await?;
    // Keep the background reload watcher alive for the process lifetime.
    let _routing_watcher = routing.start_watcher();
    let grpc_addr: SocketAddr = std::env::var("GATEWAY_GRPC_ADDR")
        .unwrap_or_else(|_| "0.0.0.0:8081".to_string())
        .parse()?;
    let state = AppState {
        metrics,
        routing,
        storage,
        authn,
    };
    let app = gateway::app(state.clone());
    let listener = tokio::net::TcpListener::bind(addr).await?;
    tracing::info!(%addr, "gateway listening");
    tracing::info!(%grpc_addr, "gateway grpc listening");
    // Broadcast channel fans the shutdown signal out to both servers.
    let (shutdown_tx, _shutdown_rx) = tokio::sync::broadcast::channel::<()>(2);
    let shutdown_task = {
        let shutdown_tx = shutdown_tx.clone();
        tokio::spawn(async move {
            shutdown_signal().await;
            let _ = shutdown_tx.send(());
        })
    };
    let http_task = {
        let mut shutdown_rx = shutdown_tx.subscribe();
        tokio::spawn(async move {
            axum::serve(listener, app)
                .with_graceful_shutdown(async move {
                    let _ = shutdown_rx.recv().await;
                })
                .await
                .unwrap();
        })
    };
    let grpc_task = {
        let mut shutdown_rx = shutdown_tx.subscribe();
        let svc = gateway::grpc::GatewayCommandService::new(state.routing.clone());
        tokio::spawn(async move {
            tonic::transport::Server::builder()
                .add_service(
                    gateway::grpc::proto::command_service_server::CommandServiceServer::new(svc),
                )
                .serve_with_shutdown(grpc_addr, async move {
                    let _ = shutdown_rx.recv().await;
                })
                .await
                .unwrap();
        })
    };
    // NOTE(review): if one server task exits early (e.g. a bind failure
    // panic), the select completes but we still await the signal task
    // below, so the process lingers until a signal — confirm intended.
    tokio::select! {
        _ = http_task => {},
        _ = grpc_task => {},
    }
    let _ = shutdown_task.await;
    Ok(())
}
/// Resolves once a shutdown request arrives: Ctrl-C on every platform, or
/// SIGTERM in addition on Unix.
async fn shutdown_signal() {
    let interrupt = async {
        let _ = tokio::signal::ctrl_c().await;
    };
    #[cfg(unix)]
    let terminate = async {
        // Registering the handler can only fail at startup; treat it as fatal.
        let mut stream =
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                .expect("failed to register SIGTERM handler");
        stream.recv().await;
    };
    // On non-Unix targets there is no SIGTERM; this branch never resolves.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();
    tokio::select! {
        _ = interrupt => {},
        _ = terminate => {},
    }
}

456
gateway/src/routing.rs Normal file
View File

@@ -0,0 +1,456 @@
use std::collections::HashMap;
use std::sync::Arc;
use futures::StreamExt;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
/// The class of downstream service a request is routed to.
///
/// Serialized in snake_case (`aggregate`, `projection`, `runner`) for use in
/// routing configs and request parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceKind {
    Aggregate,
    Projection,
    Runner,
}
/// Wire/storage form of the routing directory.
///
/// `*_placement` maps tenant id -> shard id and `*_shards` maps shard id ->
/// endpoint URLs, one pair of maps per [`ServiceKind`]. `revision` is the
/// config's version number (0 for [`RoutingConfig::empty`]).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RoutingConfig {
    pub revision: u64,
    // tenant id -> shard id, per service kind
    pub aggregate_placement: HashMap<String, String>,
    pub projection_placement: HashMap<String, String>,
    pub runner_placement: HashMap<String, String>,
    // shard id -> endpoint URLs, per service kind
    pub aggregate_shards: HashMap<String, Vec<String>>,
    pub projection_shards: HashMap<String, Vec<String>>,
    pub runner_shards: HashMap<String, Vec<String>>,
}
impl RoutingConfig {
pub fn empty() -> Self {
Self {
revision: 0,
aggregate_placement: HashMap::new(),
projection_placement: HashMap::new(),
runner_placement: HashMap::new(),
aggregate_shards: HashMap::new(),
projection_shards: HashMap::new(),
runner_shards: HashMap::new(),
}
}
}
/// Read-side view of a [`RoutingConfig`].
///
/// The lookup maps are private so all resolution goes through
/// [`RoutingTable::resolve`]; only the revision is exposed.
#[derive(Debug, Clone, Serialize)]
pub struct RoutingTable {
    pub revision: u64,
    // tenant id -> shard id, per service kind
    aggregate_placement: HashMap<String, String>,
    projection_placement: HashMap<String, String>,
    runner_placement: HashMap<String, String>,
    // shard id -> endpoint URLs, per service kind
    aggregate_shards: HashMap<String, Vec<String>>,
    projection_shards: HashMap<String, Vec<String>>,
    runner_shards: HashMap<String, Vec<String>>,
}
impl From<RoutingConfig> for RoutingTable {
fn from(value: RoutingConfig) -> Self {
Self {
revision: value.revision,
aggregate_placement: value.aggregate_placement,
projection_placement: value.projection_placement,
runner_placement: value.runner_placement,
aggregate_shards: value.aggregate_shards,
projection_shards: value.projection_shards,
runner_shards: value.runner_shards,
}
}
}
/// Typed resolution failures, one per step where the lookup can break down:
/// tenant not placed, placed shard absent from the directory, or shard
/// present but with no endpoints.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum RoutingError {
    #[error("unknown tenant")]
    UnknownTenant,
    #[error("missing shard directory entry")]
    MissingShard,
    #[error("no endpoints for shard")]
    EmptyShard,
}
/// Shared routing state: an atomically swappable snapshot of the current
/// [`RoutingTable`] plus the [`RoutingSource`] it is (re)loaded from.
///
/// Cloning is cheap (two `Arc`s); all clones observe the same table.
#[derive(Clone)]
pub struct RouterState {
    // RwLock<Arc<...>> so readers clone an Arc snapshot and drop the lock
    // immediately, while reload swaps the Arc in a single write.
    table: Arc<tokio::sync::RwLock<Arc<RoutingTable>>>,
    source: Arc<dyn RoutingSource>,
}
// Manual Debug: `dyn RoutingSource` carries no Debug bound, so derive is not
// available here.
impl std::fmt::Debug for RouterState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RouterState").finish_non_exhaustive()
    }
}
impl RouterState {
pub async fn new(source: Arc<dyn RoutingSource>) -> Result<Self, RoutingSourceError> {
let cfg = source.load().await?;
Ok(Self {
table: Arc::new(tokio::sync::RwLock::new(Arc::new(cfg.into()))),
source,
})
}
pub async fn snapshot(&self) -> Arc<RoutingTable> {
self.table.read().await.clone()
}
pub async fn reload(&self) -> Result<(), RoutingSourceError> {
let cfg = self.source.load().await?;
let next = Arc::new(RoutingTable::from(cfg));
*self.table.write().await = next;
Ok(())
}
pub fn start_watcher(&self) -> tokio::task::JoinHandle<()> {
let this = self.clone();
tokio::spawn(async move {
let mut stream = match this.source.watch().await {
Ok(s) => s,
Err(_) => return,
};
while let Some(msg) = stream.next().await {
if msg.is_err() {
continue;
}
let _ = this.reload().await;
}
})
}
pub async fn resolve(
&self,
tenant_id: &str,
kind: ServiceKind,
) -> Result<String, RoutingError> {
let table = self.snapshot().await;
let result = table.resolve(tenant_id, kind);
metrics::counter!(
"gateway_routing_resolutions_total",
"kind" => kind_label(kind),
"result" => if result.is_ok() { "ok" } else { "err" }
)
.increment(1);
result
}
}
/// Stable, low-cardinality metric label for a service kind.
fn kind_label(kind: ServiceKind) -> &'static str {
    match kind {
        ServiceKind::Aggregate => "aggregate",
        ServiceKind::Projection => "projection",
        ServiceKind::Runner => "runner",
    }
}
impl RoutingTable {
    /// Resolves `tenant_id`/`kind` to a single endpoint URL.
    ///
    /// Lookup is tenant -> shard id -> endpoint list, taking the first
    /// endpoint; each missing step maps to its own [`RoutingError`] variant.
    pub fn resolve(&self, tenant_id: &str, kind: ServiceKind) -> Result<String, RoutingError> {
        // Select both maps for this kind once instead of matching twice.
        let (placement, shards) = match kind {
            ServiceKind::Aggregate => (&self.aggregate_placement, &self.aggregate_shards),
            ServiceKind::Projection => (&self.projection_placement, &self.projection_shards),
            ServiceKind::Runner => (&self.runner_placement, &self.runner_shards),
        };
        let shard_id = placement.get(tenant_id).ok_or(RoutingError::UnknownTenant)?;
        let endpoints = shards.get(shard_id).ok_or(RoutingError::MissingShard)?;
        endpoints.first().cloned().ok_or(RoutingError::EmptyShard)
    }
}
/// Failures while loading or watching a routing source, split into transport
/// (`Source`) and parse (`Decode`) causes; payloads are stringified errors.
#[derive(Debug, Error)]
pub enum RoutingSourceError {
    #[error("source error: {0}")]
    Source(String),
    #[error("decode error: {0}")]
    Decode(String),
}
/// A provider of routing configuration.
///
/// `load` returns the full current config. `watch` returns a stream that
/// yields a unit item whenever the config may have changed; the caller is
/// expected to re-`load` in response (see `RouterState::start_watcher`).
#[async_trait::async_trait]
pub trait RoutingSource: Send + Sync {
    async fn load(&self) -> Result<RoutingConfig, RoutingSourceError>;
    async fn watch(
        &self,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), RoutingSourceError>> + Send>>,
        RoutingSourceError,
    >;
}
/// Routing source that always serves the same in-memory config.
///
/// Used when neither a file nor a NATS source is configured.
#[derive(Clone)]
pub struct FixedSource {
    cfg: RoutingConfig,
}
impl FixedSource {
    pub fn new(cfg: RoutingConfig) -> Self {
        Self { cfg }
    }
}
#[async_trait::async_trait]
impl RoutingSource for FixedSource {
    /// Returns a clone of the fixed config.
    async fn load(&self) -> Result<RoutingConfig, RoutingSourceError> {
        Ok(self.cfg.clone())
    }
    /// A fixed config never changes: the watch stream is empty, so the
    /// watcher task exits immediately and no reloads are triggered.
    async fn watch(
        &self,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), RoutingSourceError>> + Send>>,
        RoutingSourceError,
    > {
        Ok(Box::pin(futures::stream::empty()))
    }
}
/// Routing source that re-reads a config file at `path` on every `load`.
#[derive(Clone)]
pub struct StaticFileSource {
    path: String,
}
impl StaticFileSource {
    pub fn new(path: impl Into<String>) -> Self {
        Self { path: path.into() }
    }
}
#[async_trait::async_trait]
impl RoutingSource for StaticFileSource {
    /// Reads and parses the config file. `.json` paths are parsed as JSON;
    /// any other extension is treated as YAML.
    async fn load(&self) -> Result<RoutingConfig, RoutingSourceError> {
        let raw = tokio::fs::read_to_string(&self.path)
            .await
            .map_err(|e| RoutingSourceError::Source(e.to_string()))?;
        if self.path.ends_with(".json") {
            serde_json::from_str::<RoutingConfig>(&raw)
                .map_err(|e| RoutingSourceError::Decode(e.to_string()))
        } else {
            // Deserialize YAML directly instead of round-tripping through
            // serde_yaml::Value -> serde_json::Value; the intermediate
            // conversion added no validation, only extra failure points.
            serde_yaml::from_str::<RoutingConfig>(&raw)
                .map_err(|e| RoutingSourceError::Decode(e.to_string()))
        }
    }
    /// The file is never watched; changes require an external reload trigger
    /// (the empty stream ends the watcher task immediately).
    async fn watch(
        &self,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), RoutingSourceError>> + Send>>,
        RoutingSourceError,
    > {
        Ok(Box::pin(futures::stream::empty()))
    }
}
/// Routing source backed by a key in a NATS JetStream key-value bucket.
#[derive(Clone)]
pub struct NatsKvSource {
    kv: async_nats::jetstream::kv::Store,
    // Key within the bucket that holds the JSON-encoded RoutingConfig.
    key: String,
}
impl NatsKvSource {
    /// Connects to `nats_url` and binds the `bucket` KV store, creating the
    /// bucket with default settings when it does not already exist.
    ///
    /// NOTE(review): the get-then-create sequence is not atomic; if another
    /// process creates the bucket between the two calls, `create_key_value`
    /// can fail and that error is returned rather than retried — confirm
    /// this race is acceptable for concurrent startups.
    pub async fn connect(
        nats_url: impl Into<String>,
        bucket: impl Into<String>,
        key: impl Into<String>,
    ) -> Result<Self, RoutingSourceError> {
        let nats_url = nats_url.into();
        let bucket = bucket.into();
        let key = key.into();
        let client = async_nats::connect(nats_url)
            .await
            .map_err(|e| RoutingSourceError::Source(e.to_string()))?;
        let jetstream = async_nats::jetstream::new(client);
        // Prefer binding an existing bucket; fall back to creating it.
        let kv = match jetstream.get_key_value(&bucket).await {
            Ok(kv) => kv,
            Err(_) => jetstream
                .create_key_value(async_nats::jetstream::kv::Config {
                    bucket: bucket.clone(),
                    ..Default::default()
                })
                .await
                .map_err(|e| RoutingSourceError::Source(e.to_string()))?,
        };
        Ok(Self { kv, key })
    }
}
#[async_trait::async_trait]
impl RoutingSource for NatsKvSource {
    /// Reads the configured key. A missing key is treated as an empty config
    /// rather than an error; present values must be JSON-encoded
    /// `RoutingConfig`.
    async fn load(&self) -> Result<RoutingConfig, RoutingSourceError> {
        let entry = self
            .kv
            .entry(&self.key)
            .await
            .map_err(|e| RoutingSourceError::Source(e.to_string()))?;
        let Some(entry) = entry else {
            return Ok(RoutingConfig::empty());
        };
        serde_json::from_slice::<RoutingConfig>(&entry.value)
            .map_err(|e| RoutingSourceError::Decode(e.to_string()))
    }
    /// Watches the key: `Put` operations yield a change signal, while
    /// `Delete`/`Purge` are filtered out, so removing the key does not
    /// trigger a reload. Watch transport errors are forwarded to the caller.
    async fn watch(
        &self,
    ) -> Result<
        std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), RoutingSourceError>> + Send>>,
        RoutingSourceError,
    > {
        let key = self.key.clone();
        let watch = self
            .kv
            .watch(&key)
            .await
            .map_err(|e| RoutingSourceError::Source(e.to_string()))?;
        Ok(Box::pin(watch.filter_map(|entry| async move {
            match entry {
                Ok(entry) => match entry.operation {
                    async_nats::jetstream::kv::Operation::Put => Some(Ok(())),
                    async_nats::jetstream::kv::Operation::Delete
                    | async_nats::jetstream::kv::Operation::Purge => None,
                },
                Err(e) => Some(Err(RoutingSourceError::Source(e.to_string()))),
            }
        })))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Compile-time bound check: calling this only type-checks Send + Sync.
    fn assert_send_sync<T: Send + Sync>() {}
    #[test]
    fn router_state_is_send_sync() {
        // RouterState is cloned into spawned tasks; it must stay Send + Sync.
        assert_send_sync::<RouterState>();
    }
    // Each service kind resolves through its own placement and shard maps.
    #[tokio::test]
    async fn resolves_endpoints_for_tenant_service_kind() {
        let cfg = RoutingConfig {
            revision: 1,
            aggregate_placement: HashMap::from([("t1".to_string(), "a".to_string())]),
            projection_placement: HashMap::from([("t1".to_string(), "p".to_string())]),
            runner_placement: HashMap::from([("t1".to_string(), "r".to_string())]),
            aggregate_shards: HashMap::from([("a".to_string(), vec!["http://a".to_string()])]),
            projection_shards: HashMap::from([("p".to_string(), vec!["http://p".to_string()])]),
            runner_shards: HashMap::from([("r".to_string(), vec!["http://r".to_string()])]),
        };
        let source: Arc<dyn RoutingSource> = Arc::new(TestSource::new(cfg));
        let router = RouterState::new(source).await.unwrap();
        assert_eq!(
            router.resolve("t1", ServiceKind::Aggregate).await.unwrap(),
            "http://a"
        );
        assert_eq!(
            router.resolve("t1", ServiceKind::Projection).await.unwrap(),
            "http://p"
        );
        assert_eq!(
            router.resolve("t1", ServiceKind::Runner).await.unwrap(),
            "http://r"
        );
    }
    // A tenant absent from placement yields the typed UnknownTenant error.
    #[tokio::test]
    async fn unknown_tenant_is_typed_error() {
        let source: Arc<dyn RoutingSource> = Arc::new(TestSource::new(RoutingConfig::empty()));
        let router = RouterState::new(source).await.unwrap();
        let err = router
            .resolve("missing", ServiceKind::Aggregate)
            .await
            .unwrap_err();
        assert_eq!(err, RoutingError::UnknownTenant);
    }
    // reload() must swap the table as a unit: resolutions before the reload
    // see the old endpoint, resolutions after see the new one.
    #[tokio::test]
    async fn hot_reload_swaps_table_atomically() {
        let cfg1 = RoutingConfig {
            revision: 1,
            aggregate_placement: HashMap::from([("t1".to_string(), "a".to_string())]),
            projection_placement: HashMap::new(),
            runner_placement: HashMap::new(),
            aggregate_shards: HashMap::from([("a".to_string(), vec!["http://a1".to_string()])]),
            projection_shards: HashMap::new(),
            runner_shards: HashMap::new(),
        };
        let cfg2 = RoutingConfig {
            revision: 2,
            aggregate_placement: HashMap::from([("t1".to_string(), "a".to_string())]),
            projection_placement: HashMap::new(),
            runner_placement: HashMap::new(),
            aggregate_shards: HashMap::from([("a".to_string(), vec!["http://a2".to_string()])]),
            projection_shards: HashMap::new(),
            runner_shards: HashMap::new(),
        };
        let test_source = Arc::new(TestSource::new(cfg1));
        let router = RouterState::new(test_source.clone()).await.unwrap();
        let before = router.resolve("t1", ServiceKind::Aggregate).await.unwrap();
        assert_eq!(before, "http://a1");
        test_source.set(cfg2).await;
        router.reload().await.unwrap();
        let after = router.resolve("t1", ServiceKind::Aggregate).await.unwrap();
        assert_eq!(after, "http://a2");
    }
    // In-memory RoutingSource with a mutable config, used to drive reloads.
    #[derive(Clone)]
    struct TestSource {
        cfg: Arc<tokio::sync::RwLock<RoutingConfig>>,
    }
    impl TestSource {
        fn new(cfg: RoutingConfig) -> Self {
            Self {
                cfg: Arc::new(tokio::sync::RwLock::new(cfg)),
            }
        }
        // Replaces the config that the next load() will return.
        async fn set(&self, cfg: RoutingConfig) {
            *self.cfg.write().await = cfg;
        }
    }
    #[async_trait::async_trait]
    impl RoutingSource for TestSource {
        async fn load(&self) -> Result<RoutingConfig, RoutingSourceError> {
            Ok(self.cfg.read().await.clone())
        }
        // No change notifications; tests call reload() explicitly.
        async fn watch(
            &self,
        ) -> Result<
            std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), RoutingSourceError>> + Send>>,
            RoutingSourceError,
        > {
            Ok(Box::pin(futures::stream::empty()))
        }
    }
}

1015
gateway/src/storage.rs Normal file

File diff suppressed because it is too large Load Diff

99
gateway/src/upstream.rs Normal file
View File

@@ -0,0 +1,99 @@
use std::sync::OnceLock;
use std::time::Duration;
pub fn http_client() -> &'static reqwest::Client {
static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
CLIENT.get_or_init(|| {
let mut builder = reqwest::Client::builder().timeout(Duration::from_secs(10));
if let Some(ca_pem) = env_or_file(
"GATEWAY_INTERNAL_CA_CERT_PEM",
"GATEWAY_INTERNAL_CA_CERT_PEM_FILE",
) {
if let Ok(cert) = reqwest::Certificate::from_pem(ca_pem.as_bytes()) {
builder = builder.add_root_certificate(cert);
}
}
if let Some(identity_pem) = env_or_file(
"GATEWAY_INTERNAL_IDENTITY_PEM",
"GATEWAY_INTERNAL_IDENTITY_PEM_FILE",
) {
if let Ok(identity) = reqwest::Identity::from_pem(identity_pem.as_bytes()) {
builder = builder.identity(identity);
}
}
builder.build().expect("failed to build reqwest client")
})
}
/// Builds a tonic endpoint for `url` with a 10s request timeout.
///
/// TLS is enabled when the URL uses the `https` scheme or when
/// `GATEWAY_INTERNAL_GRPC_TLS` is set to a truthy value (`1`/`true`/`yes`),
/// and only if `grpc_tls_config` yields any TLS material.
pub fn grpc_endpoint(url: &str) -> Result<tonic::transport::Endpoint, tonic::transport::Error> {
    let mut endpoint =
        tonic::transport::Endpoint::from_shared(url.to_string())?.timeout(Duration::from_secs(10));
    let tls_forced_by_env = std::env::var("GATEWAY_INTERNAL_GRPC_TLS")
        .ok()
        .map(|v| matches!(v.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
        .unwrap_or(false);
    if url.starts_with("https://") || tls_forced_by_env {
        if let Some(tls) = grpc_tls_config() {
            endpoint = endpoint.tls_config(tls)?;
        }
    }
    Ok(endpoint)
}
/// Assembles client TLS material for internal gRPC from the environment.
///
/// Returns `None` when neither a CA certificate nor a complete client
/// cert+key pair is configured; a cert without its key (or vice versa) is
/// ignored.
fn grpc_tls_config() -> Option<tonic::transport::ClientTlsConfig> {
    let ca_pem = env_or_file(
        "GATEWAY_INTERNAL_GRPC_CA_CERT_PEM",
        "GATEWAY_INTERNAL_GRPC_CA_CERT_PEM_FILE",
    );
    // The identity requires both halves; `zip` drops incomplete pairs.
    let identity_pair = env_or_file(
        "GATEWAY_INTERNAL_GRPC_CLIENT_CERT_PEM",
        "GATEWAY_INTERNAL_GRPC_CLIENT_CERT_PEM_FILE",
    )
    .zip(env_or_file(
        "GATEWAY_INTERNAL_GRPC_CLIENT_KEY_PEM",
        "GATEWAY_INTERNAL_GRPC_CLIENT_KEY_PEM_FILE",
    ));
    if ca_pem.is_none() && identity_pair.is_none() {
        return None;
    }
    let mut tls = tonic::transport::ClientTlsConfig::new();
    if let Some(ca) = ca_pem {
        tls = tls.ca_certificate(tonic::transport::Certificate::from_pem(ca));
    }
    if let Some((cert_pem, key_pem)) = identity_pair {
        tls = tls.identity(tonic::transport::Identity::from_pem(cert_pem, key_pem));
    }
    Some(tls)
}
/// Reads a secret either from the file named by `file_env_key` or, failing
/// that, directly from `env_key`.
///
/// The file takes precedence; unreadable files, unset vars, and
/// whitespace-only values all fall through. Returned values are trimmed.
fn env_or_file(env_key: &str, file_env_key: &str) -> Option<String> {
    let non_empty = |raw: String| {
        let trimmed = raw.trim().to_string();
        (!trimmed.is_empty()).then_some(trimmed)
    };
    std::env::var(file_env_key)
        .ok()
        .and_then(|path| std::fs::read_to_string(path).ok())
        .and_then(non_empty)
        .or_else(|| std::env::var(env_key).ok().and_then(non_empty))
}