Monorepo consolidation: workspace, shared types, transport plans, docker/swam assets
This commit is contained in:
1562
gateway/src/admin_iam.rs
Normal file
1562
gateway/src/admin_iam.rs
Normal file
File diff suppressed because it is too large
Load Diff
790
gateway/src/admin_rebalance.rs
Normal file
790
gateway/src/admin_rebalance.rs
Normal file
@@ -0,0 +1,790 @@
|
||||
use axum::extract::Query;
|
||||
use axum::extract::State;
|
||||
use axum::http::StatusCode;
|
||||
use axum::Json;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::authz;
|
||||
use crate::authz::AuthzRejection;
|
||||
use crate::authz::Principal;
|
||||
use crate::routing::ServiceKind;
|
||||
use crate::storage::StorageError;
|
||||
use crate::AppState;
|
||||
|
||||
pub fn router() -> axum::Router<AppState> {
|
||||
axum::Router::new()
|
||||
.route("/status", axum::routing::get(status))
|
||||
.route("/gates", axum::routing::get(gates))
|
||||
.route("/plans", axum::routing::get(list_plans))
|
||||
.route("/plan", axum::routing::post(create_plan))
|
||||
.route("/apply", axum::routing::post(apply_plan))
|
||||
.route("/rollback", axum::routing::post(rollback_plan))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ResolveQuery {
|
||||
pub tenant_id: String,
|
||||
pub kind: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ResolveResponse {
|
||||
pub tenant_id: String,
|
||||
pub kind: ServiceKind,
|
||||
pub endpoint: String,
|
||||
pub revision: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TenantQuery {
|
||||
tenant_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct StatusResponse {
|
||||
tenant_id: String,
|
||||
revision: u64,
|
||||
aggregate: Option<String>,
|
||||
projection: Option<String>,
|
||||
runner: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GatesResponse {
|
||||
tenant_id: String,
|
||||
aggregate_ready: bool,
|
||||
projection_ready: bool,
|
||||
runner_ready: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
struct Stored<T> {
|
||||
v: u32,
|
||||
data: T,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
struct RebalancePlan {
|
||||
plan_id: String,
|
||||
tenant_id: String,
|
||||
kind: ServiceKind,
|
||||
from_endpoint: Option<String>,
|
||||
to_endpoint: Option<String>,
|
||||
status: String,
|
||||
actor_id: String,
|
||||
created_at_ms: i64,
|
||||
updated_at_ms: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CreatePlanBody {
|
||||
tenant_id: String,
|
||||
kind: String,
|
||||
to_endpoint: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct PlanActionBody {
|
||||
plan_id: String,
|
||||
tenant_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ListPlansQuery {
|
||||
tenant_id: Option<String>,
|
||||
limit: Option<usize>,
|
||||
}
|
||||
|
||||
pub async fn resolve(
|
||||
State(state): State<AppState>,
|
||||
principal: Principal,
|
||||
Query(q): Query<ResolveQuery>,
|
||||
) -> Result<Json<ResolveResponse>, AuthzRejection> {
|
||||
require_platform_admin(&state.storage, &principal.user_id).await?;
|
||||
|
||||
let kind = parse_kind(&q.kind).ok_or(AuthzRejection::Internal)?;
|
||||
let table = state.routing.snapshot().await;
|
||||
let endpoint = table.resolve(&q.tenant_id, kind).map_err(|e| match e {
|
||||
crate::routing::RoutingError::UnknownTenant => AuthzRejection::NotFound,
|
||||
crate::routing::RoutingError::MissingShard | crate::routing::RoutingError::EmptyShard => {
|
||||
AuthzRejection::Internal
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(Json(ResolveResponse {
|
||||
tenant_id: q.tenant_id,
|
||||
kind,
|
||||
endpoint,
|
||||
revision: table.revision,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn status(
|
||||
State(state): State<AppState>,
|
||||
principal: Principal,
|
||||
Query(q): Query<TenantQuery>,
|
||||
) -> Result<Json<StatusResponse>, AuthzRejection> {
|
||||
require_platform_admin(&state.storage, &principal.user_id).await?;
|
||||
|
||||
let table = state.routing.snapshot().await;
|
||||
let aggregate = table.resolve(&q.tenant_id, ServiceKind::Aggregate).ok();
|
||||
let projection = table.resolve(&q.tenant_id, ServiceKind::Projection).ok();
|
||||
let runner = table.resolve(&q.tenant_id, ServiceKind::Runner).ok();
|
||||
|
||||
Ok(Json(StatusResponse {
|
||||
tenant_id: q.tenant_id,
|
||||
revision: table.revision,
|
||||
aggregate,
|
||||
projection,
|
||||
runner,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn gates(
|
||||
State(state): State<AppState>,
|
||||
principal: Principal,
|
||||
Query(q): Query<TenantQuery>,
|
||||
) -> Result<Json<GatesResponse>, AuthzRejection> {
|
||||
require_platform_admin(&state.storage, &principal.user_id).await?;
|
||||
|
||||
let projection_endpoint = state
|
||||
.routing
|
||||
.resolve(&q.tenant_id, ServiceKind::Projection)
|
||||
.await
|
||||
.ok();
|
||||
let runner_endpoint = state
|
||||
.routing
|
||||
.resolve(&q.tenant_id, ServiceKind::Runner)
|
||||
.await
|
||||
.ok();
|
||||
let aggregate_endpoint = state
|
||||
.routing
|
||||
.resolve(&q.tenant_id, ServiceKind::Aggregate)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
let projection_ready = if let Some(ep) = projection_endpoint {
|
||||
projection_gate_ready(&ep, &q.tenant_id)
|
||||
.await
|
||||
.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let runner_ready = if let Some(ep) = runner_endpoint {
|
||||
http_ready(&ep).await.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let aggregate_ready = if let Some(ep) = aggregate_endpoint {
|
||||
aggregate_ready(&ep).await.unwrap_or(false)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
Ok(Json(GatesResponse {
|
||||
tenant_id: q.tenant_id,
|
||||
aggregate_ready,
|
||||
projection_ready,
|
||||
runner_ready,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn http_ready(endpoint: &str) -> Result<bool, AuthzRejection> {
|
||||
let url = format!("{}/ready", endpoint.trim_end_matches('/'));
|
||||
let client = crate::upstream::http_client();
|
||||
let resp = tokio::time::timeout(Duration::from_secs(2), client.get(url).send())
|
||||
.await
|
||||
.map_err(|_| AuthzRejection::Internal)?
|
||||
.map_err(|_| AuthzRejection::Internal)?;
|
||||
Ok(resp.status().is_success())
|
||||
}
|
||||
|
||||
async fn aggregate_ready(endpoint: &str) -> Result<bool, AuthzRejection> {
|
||||
if endpoint.contains(":50051") {
|
||||
let http_ep = endpoint.replace(":50051", ":8080");
|
||||
return http_ready(&http_ep).await;
|
||||
}
|
||||
http_ready(endpoint).await
|
||||
}
|
||||
|
||||
async fn projection_gate_ready(endpoint: &str, tenant_id: &str) -> Result<bool, AuthzRejection> {
|
||||
let url = format!("{}/metrics", endpoint.trim_end_matches('/'));
|
||||
let client = crate::upstream::http_client();
|
||||
let resp = tokio::time::timeout(Duration::from_secs(2), client.get(url).send())
|
||||
.await
|
||||
.map_err(|_| AuthzRejection::Internal)?
|
||||
.map_err(|_| AuthzRejection::Internal)?;
|
||||
if !resp.status().is_success() {
|
||||
return Ok(false);
|
||||
}
|
||||
let text = resp.text().await.map_err(|_| AuthzRejection::Internal)?;
|
||||
|
||||
let ready = parse_prom_gauge(&text, "projection_ready").unwrap_or(0.0) >= 1.0;
|
||||
if !ready {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let max_lag = parse_projection_max_lag(&text, tenant_id).unwrap_or(u64::MAX);
|
||||
let threshold = std::env::var("GATEWAY_REBALANCE_PROJECTION_MAX_LAG")
|
||||
.ok()
|
||||
.and_then(|v| v.parse::<u64>().ok())
|
||||
.unwrap_or(0);
|
||||
Ok(max_lag <= threshold)
|
||||
}
|
||||
|
||||
fn parse_prom_gauge(metrics: &str, name: &str) -> Option<f64> {
|
||||
for line in metrics.lines() {
|
||||
let line = line.trim();
|
||||
if line.starts_with('#') || line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if line.starts_with(name) && !line.contains('{') {
|
||||
let mut it = line.split_whitespace();
|
||||
let _ = it.next()?;
|
||||
return it.next()?.parse::<f64>().ok();
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn parse_projection_max_lag(metrics: &str, tenant_id: &str) -> Option<u64> {
|
||||
let mut max: Option<u64> = None;
|
||||
for line in metrics.lines() {
|
||||
let line = line.trim();
|
||||
if !line.starts_with("projection_lag{") {
|
||||
continue;
|
||||
}
|
||||
if !line.contains(&format!("tenant_id=\"{}\"", tenant_id)) {
|
||||
continue;
|
||||
}
|
||||
let value = line
|
||||
.split_whitespace()
|
||||
.nth(1)
|
||||
.and_then(|v| v.parse::<u64>().ok())?;
|
||||
max = Some(max.map(|m| m.max(value)).unwrap_or(value));
|
||||
}
|
||||
max
|
||||
}
|
||||
|
||||
fn parse_kind(kind: &str) -> Option<ServiceKind> {
|
||||
match kind.trim().to_ascii_lowercase().as_str() {
|
||||
"aggregate" => Some(ServiceKind::Aggregate),
|
||||
"projection" => Some(ServiceKind::Projection),
|
||||
"runner" => Some(ServiceKind::Runner),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn require_platform_admin(
|
||||
storage: &crate::storage::GatewayStorage,
|
||||
principal_id: &str,
|
||||
) -> Result<(), AuthzRejection> {
|
||||
authz::ensure_allowed(storage, principal_id, "*", "iam.platform_admin").await
|
||||
}
|
||||
|
||||
async fn create_plan(
|
||||
State(state): State<AppState>,
|
||||
principal: Principal,
|
||||
Json(body): Json<CreatePlanBody>,
|
||||
) -> Result<Json<RebalancePlan>, AuthzRejection> {
|
||||
require_platform_admin(&state.storage, &principal.user_id).await?;
|
||||
|
||||
if body.tenant_id.trim().is_empty() {
|
||||
return Err(AuthzRejection::BadRequest);
|
||||
}
|
||||
let kind = parse_kind(&body.kind).ok_or(AuthzRejection::BadRequest)?;
|
||||
let to_endpoint = body.to_endpoint.filter(|s| !s.trim().is_empty());
|
||||
if to_endpoint.is_none() {
|
||||
return Err(AuthzRejection::BadRequest);
|
||||
}
|
||||
|
||||
let from_endpoint = state.routing.resolve(&body.tenant_id, kind).await.ok();
|
||||
let plan_id = uuid::Uuid::new_v4().to_string();
|
||||
let now_ms = unix_ms();
|
||||
|
||||
let plan = RebalancePlan {
|
||||
plan_id: plan_id.clone(),
|
||||
tenant_id: body.tenant_id.clone(),
|
||||
kind,
|
||||
from_endpoint,
|
||||
to_endpoint,
|
||||
status: "planned".to_string(),
|
||||
actor_id: principal.user_id,
|
||||
created_at_ms: now_ms,
|
||||
updated_at_ms: now_ms,
|
||||
};
|
||||
|
||||
let key = plan_key(&plan.tenant_id, &plan.plan_id);
|
||||
state
|
||||
.storage
|
||||
.audit_index
|
||||
.create(
|
||||
&key,
|
||||
encode_stored(&plan).map_err(|_| AuthzRejection::Internal)?,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
StorageError::AlreadyExists => AuthzRejection::Conflict,
|
||||
_ => AuthzRejection::Internal,
|
||||
})?;
|
||||
|
||||
Ok(Json(plan))
|
||||
}
|
||||
|
||||
async fn apply_plan(
|
||||
State(state): State<AppState>,
|
||||
principal: Principal,
|
||||
Json(body): Json<PlanActionBody>,
|
||||
) -> Result<StatusCode, AuthzRejection> {
|
||||
require_platform_admin(&state.storage, &principal.user_id).await?;
|
||||
transition_plan_status(&state, &body.tenant_id, &body.plan_id, "apply_requested").await?;
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
async fn rollback_plan(
|
||||
State(state): State<AppState>,
|
||||
principal: Principal,
|
||||
Json(body): Json<PlanActionBody>,
|
||||
) -> Result<StatusCode, AuthzRejection> {
|
||||
require_platform_admin(&state.storage, &principal.user_id).await?;
|
||||
transition_plan_status(&state, &body.tenant_id, &body.plan_id, "rollback_requested").await?;
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
async fn list_plans(
|
||||
State(state): State<AppState>,
|
||||
principal: Principal,
|
||||
Query(q): Query<ListPlansQuery>,
|
||||
) -> Result<Json<Vec<RebalancePlan>>, AuthzRejection> {
|
||||
require_platform_admin(&state.storage, &principal.user_id).await?;
|
||||
let prefix = match &q.tenant_id {
|
||||
Some(t) => format!("v1/rebalance/plans/{}/", t.trim()),
|
||||
None => "v1/rebalance/plans/".to_string(),
|
||||
};
|
||||
let mut keys = state
|
||||
.storage
|
||||
.audit_index
|
||||
.list_keys(&prefix)
|
||||
.await
|
||||
.map_err(|_| AuthzRejection::Internal)?;
|
||||
keys.sort();
|
||||
keys.reverse();
|
||||
|
||||
let limit = q.limit.unwrap_or(50).min(200);
|
||||
let mut out = Vec::new();
|
||||
for key in keys.into_iter().take(limit) {
|
||||
let entry = state
|
||||
.storage
|
||||
.audit_index
|
||||
.get(&key)
|
||||
.await
|
||||
.map_err(|_| AuthzRejection::Internal)?;
|
||||
let Some(entry) = entry else {
|
||||
continue;
|
||||
};
|
||||
let plan: RebalancePlan =
|
||||
decode_stored(&entry.value).map_err(|_| AuthzRejection::Internal)?;
|
||||
out.push(plan);
|
||||
}
|
||||
Ok(Json(out))
|
||||
}
|
||||
|
||||
async fn transition_plan_status(
|
||||
state: &AppState,
|
||||
tenant_id: &str,
|
||||
plan_id: &str,
|
||||
next_status: &str,
|
||||
) -> Result<(), AuthzRejection> {
|
||||
let key = plan_key(tenant_id, plan_id);
|
||||
for _ in 0..10 {
|
||||
let entry = state
|
||||
.storage
|
||||
.audit_index
|
||||
.get(&key)
|
||||
.await
|
||||
.map_err(|_| AuthzRejection::Internal)?
|
||||
.ok_or(AuthzRejection::NotFound)?;
|
||||
|
||||
let mut plan: Stored<RebalancePlan> =
|
||||
serde_json::from_slice(&entry.value).map_err(|_| AuthzRejection::Internal)?;
|
||||
plan.data.status = next_status.to_string();
|
||||
plan.data.updated_at_ms = unix_ms();
|
||||
let payload = serde_json::to_vec(&plan).map_err(|_| AuthzRejection::Internal)?;
|
||||
|
||||
match state
|
||||
.storage
|
||||
.audit_index
|
||||
.update(&key, entry.revision, payload)
|
||||
.await
|
||||
{
|
||||
Ok(_) => return Ok(()),
|
||||
Err(StorageError::CasMismatch) => continue,
|
||||
Err(_) => return Err(AuthzRejection::Internal),
|
||||
}
|
||||
}
|
||||
Err(AuthzRejection::Internal)
|
||||
}
|
||||
|
||||
fn plan_key(tenant_id: &str, plan_id: &str) -> String {
|
||||
format!("v1/rebalance/plans/{tenant_id}/{plan_id}")
|
||||
}
|
||||
|
||||
fn encode_stored<T: Serialize>(data: &T) -> Result<Vec<u8>, StorageError> {
|
||||
serde_json::to_vec(&Stored {
|
||||
v: crate::storage::SCHEMA_VERSION,
|
||||
data,
|
||||
})
|
||||
.map_err(|e| StorageError::Serde(e.to_string()))
|
||||
}
|
||||
|
||||
fn decode_stored<T: for<'de> Deserialize<'de>>(bytes: &[u8]) -> Result<T, StorageError> {
|
||||
let stored: Stored<T> =
|
||||
serde_json::from_slice(bytes).map_err(|e| StorageError::Serde(e.to_string()))?;
|
||||
Ok(stored.data)
|
||||
}
|
||||
|
||||
fn unix_ms() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as i64
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::authn;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tower::util::ServiceExt;
|
||||
|
||||
async fn test_app_with_routing(cfg: crate::routing::RoutingConfig) -> (axum::Router, AppState) {
|
||||
let metrics = crate::observability::init_metrics_for_tests();
|
||||
let source: Arc<dyn crate::routing::RoutingSource> =
|
||||
Arc::new(crate::routing::FixedSource::new(cfg));
|
||||
let routing = crate::routing::RouterState::new(source).await.unwrap();
|
||||
let storage = crate::storage::GatewayStorage::new_in_memory();
|
||||
let authn_cfg = crate::authn::AuthnConfig::for_tests();
|
||||
let state = crate::AppState {
|
||||
metrics,
|
||||
routing,
|
||||
storage,
|
||||
authn: authn_cfg,
|
||||
};
|
||||
let app = crate::app(state.clone());
|
||||
(app, state)
|
||||
}
|
||||
|
||||
async fn signup_and_token(app: &axum::Router, cfg: &authn::AuthnConfig) -> (String, String) {
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/auth/signup")
|
||||
.header("content-type", "application/json")
|
||||
.body(axum::body::Body::from(
|
||||
r#"{"email":"a@b.com","password":"password123"}"#,
|
||||
))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let created: crate::authn::AuthResponse = serde_json::from_slice(&body).unwrap();
|
||||
let claims = cfg.verify_access_token(&created.access_token).unwrap();
|
||||
(created.access_token, claims.sub)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn resolve_requires_platform_admin() {
|
||||
let cfg = crate::routing::RoutingConfig::empty();
|
||||
let (app, state) = test_app_with_routing(cfg).await;
|
||||
let (token, user_id) = signup_and_token(&app, &state.authn).await;
|
||||
|
||||
let resp = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("GET")
|
||||
.uri("/admin/routing/resolve?tenant_id=t1&kind=aggregate")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), axum::http::StatusCode::FORBIDDEN);
|
||||
|
||||
crate::authz::put_role(
|
||||
&state.storage,
|
||||
"role-platform-admin",
|
||||
vec!["iam.platform_admin".to_string()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
crate::authz::assign_role(&state.storage, "*", &user_id, "role-platform-admin")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let resp = app
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("GET")
|
||||
.uri("/admin/routing/resolve?tenant_id=t1&kind=aggregate")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), axum::http::StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn status_includes_revision() {
|
||||
let cfg = crate::routing::RoutingConfig {
|
||||
revision: 42,
|
||||
aggregate_placement: HashMap::new(),
|
||||
projection_placement: HashMap::new(),
|
||||
runner_placement: HashMap::new(),
|
||||
aggregate_shards: HashMap::new(),
|
||||
projection_shards: HashMap::new(),
|
||||
runner_shards: HashMap::new(),
|
||||
};
|
||||
let (app, state) = test_app_with_routing(cfg).await;
|
||||
let (token, user_id) = signup_and_token(&app, &state.authn).await;
|
||||
|
||||
crate::authz::put_role(
|
||||
&state.storage,
|
||||
"role-platform-admin",
|
||||
vec!["iam.platform_admin".to_string()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
crate::authz::assign_role(&state.storage, "*", &user_id, "role-platform-admin")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let resp = app
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("GET")
|
||||
.uri("/admin/rebalance/status?tenant_id=t1")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), axum::http::StatusCode::OK);
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let value: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
assert_eq!(value.get("revision").and_then(|v| v.as_u64()).unwrap(), 42);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn gates_prevent_cutover_when_projection_not_ready_or_lagging() {
|
||||
let metrics_not_ready = axum::Router::new().route(
|
||||
"/metrics",
|
||||
axum::routing::get(|| async { (axum::http::StatusCode::OK, "projection_ready 0\n") }),
|
||||
);
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, metrics_not_ready).await.unwrap();
|
||||
});
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
let endpoint = format!("http://{}", addr);
|
||||
|
||||
let cfg = crate::routing::RoutingConfig {
|
||||
revision: 1,
|
||||
aggregate_placement: HashMap::new(),
|
||||
projection_placement: HashMap::from([("tenant-a".to_string(), "p".to_string())]),
|
||||
runner_placement: HashMap::new(),
|
||||
aggregate_shards: HashMap::new(),
|
||||
projection_shards: HashMap::from([("p".to_string(), vec![endpoint])]),
|
||||
runner_shards: HashMap::new(),
|
||||
};
|
||||
let (app, state) = test_app_with_routing(cfg).await;
|
||||
let (token, user_id) = signup_and_token(&app, &state.authn).await;
|
||||
|
||||
crate::authz::put_role(
|
||||
&state.storage,
|
||||
"role-platform-admin",
|
||||
vec!["iam.platform_admin".to_string()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
crate::authz::assign_role(&state.storage, "*", &user_id, "role-platform-admin")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let resp = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("GET")
|
||||
.uri("/admin/rebalance/gates?tenant_id=tenant-a")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), axum::http::StatusCode::OK);
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let value: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
assert!(!value
|
||||
.get("projection_ready")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap());
|
||||
|
||||
let metrics_lagging = axum::Router::new().route(
|
||||
"/metrics",
|
||||
axum::routing::get(|| async {
|
||||
(
|
||||
axum::http::StatusCode::OK,
|
||||
"projection_ready 1\nprojection_lag{tenant_id=\"tenant-a\"} 5\n",
|
||||
)
|
||||
}),
|
||||
);
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, metrics_lagging).await.unwrap();
|
||||
});
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
let endpoint = format!("http://{}", addr);
|
||||
|
||||
std::env::set_var("GATEWAY_REBALANCE_PROJECTION_MAX_LAG", "0");
|
||||
|
||||
let cfg = crate::routing::RoutingConfig {
|
||||
revision: 2,
|
||||
aggregate_placement: HashMap::new(),
|
||||
projection_placement: HashMap::from([("tenant-a".to_string(), "p".to_string())]),
|
||||
runner_placement: HashMap::new(),
|
||||
aggregate_shards: HashMap::new(),
|
||||
projection_shards: HashMap::from([("p".to_string(), vec![endpoint])]),
|
||||
runner_shards: HashMap::new(),
|
||||
};
|
||||
let (app, state) = test_app_with_routing(cfg).await;
|
||||
crate::authz::put_role(
|
||||
&state.storage,
|
||||
"role-platform-admin",
|
||||
vec!["iam.platform_admin".to_string()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
crate::authz::assign_role(&state.storage, "*", &user_id, "role-platform-admin")
|
||||
.await
|
||||
.unwrap();
|
||||
let resp = app
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("GET")
|
||||
.uri("/admin/rebalance/gates?tenant_id=tenant-a")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), axum::http::StatusCode::OK);
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let value: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
assert!(!value
|
||||
.get("projection_ready")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn plan_endpoints_require_platform_admin_and_persist_plans() {
|
||||
let cfg = crate::routing::RoutingConfig::empty();
|
||||
let (app, state) = test_app_with_routing(cfg).await;
|
||||
let (token, user_id) = signup_and_token(&app, &state.authn).await;
|
||||
|
||||
let forbidden = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("POST")
|
||||
.uri("/admin/rebalance/plan")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.header("content-type", "application/json")
|
||||
.body(axum::body::Body::from(
|
||||
r#"{"tenant_id":"tenant-a","kind":"projection","to_endpoint":"http://p"}"#,
|
||||
))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(forbidden.status(), axum::http::StatusCode::FORBIDDEN);
|
||||
|
||||
crate::authz::put_role(
|
||||
&state.storage,
|
||||
"role-platform-admin",
|
||||
vec!["iam.platform_admin".to_string()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
crate::authz::assign_role(&state.storage, "*", &user_id, "role-platform-admin")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let created = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("POST")
|
||||
.uri("/admin/rebalance/plan")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.header("content-type", "application/json")
|
||||
.body(axum::body::Body::from(
|
||||
r#"{"tenant_id":"tenant-a","kind":"projection","to_endpoint":"http://p"}"#,
|
||||
))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(created.status(), axum::http::StatusCode::OK);
|
||||
let body = axum::body::to_bytes(created.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let plan: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
let plan_id = plan.get("plan_id").and_then(|v| v.as_str()).unwrap();
|
||||
|
||||
let listed = app
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("GET")
|
||||
.uri("/admin/rebalance/plans?tenant_id=tenant-a&limit=10")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(listed.status(), axum::http::StatusCode::OK);
|
||||
let body = axum::body::to_bytes(listed.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let plans: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
assert!(plans
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.any(|p| p.get("plan_id").and_then(|v| v.as_str()) == Some(plan_id)));
|
||||
}
|
||||
}
|
||||
1707
gateway/src/authn.rs
Normal file
1707
gateway/src/authn.rs
Normal file
File diff suppressed because it is too large
Load Diff
839
gateway/src/authz.rs
Normal file
839
gateway/src/authz.rs
Normal file
@@ -0,0 +1,839 @@
|
||||
use axum::extract::FromRef;
|
||||
use axum::extract::FromRequestParts;
|
||||
use axum::extract::Path;
|
||||
use axum::extract::Request;
|
||||
use axum::extract::State;
|
||||
use axum::http::header;
|
||||
use axum::http::request::Parts;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::response::Response;
|
||||
use axum::routing::post;
|
||||
use axum::Json;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::grpc;
|
||||
use crate::storage::GatewayStorage;
|
||||
use crate::storage::StorageError;
|
||||
use crate::AppState;
|
||||
|
||||
pub fn router() -> axum::Router<AppState> {
|
||||
axum::Router::new()
|
||||
.route(
|
||||
"/commands/:aggregate_type/:aggregate_id",
|
||||
post(submit_command_stub),
|
||||
)
|
||||
.route("/query/:view_type", post(query_stub))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Principal {
|
||||
pub user_id: String,
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<S> FromRequestParts<S> for Principal
|
||||
where
|
||||
S: Send + Sync,
|
||||
AppState: FromRef<S>,
|
||||
{
|
||||
type Rejection = AuthzRejection;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let auth_header = parts
|
||||
.headers
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or(AuthzRejection::Unauthorized)?;
|
||||
|
||||
let token = auth_header
|
||||
.strip_prefix("Bearer ")
|
||||
.ok_or(AuthzRejection::Unauthorized)?;
|
||||
|
||||
let app_state = AppState::from_ref(state);
|
||||
let claims = app_state.authn.verify_access_token(token).map_err(|_| {
|
||||
metrics::counter!("gateway_authn_token_verify_fail_total").increment(1);
|
||||
AuthzRejection::Unauthorized
|
||||
})?;
|
||||
|
||||
tracing::Span::current().record("principal_id", claims.sub.as_str());
|
||||
Ok(Self {
|
||||
user_id: claims.sub,
|
||||
session_id: claims.session_id,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct TenantId(pub String);
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<S> FromRequestParts<S> for TenantId
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = AuthzRejection;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let raw = parts
|
||||
.headers
|
||||
.get("x-tenant-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or(AuthzRejection::MissingTenant)?;
|
||||
|
||||
let tenant = raw.trim();
|
||||
if tenant.is_empty()
|
||||
|| !tenant
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
|
||||
{
|
||||
return Err(AuthzRejection::InvalidTenant);
|
||||
}
|
||||
|
||||
tracing::Span::current().record("tenant_id", tenant);
|
||||
Ok(TenantId(tenant.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthzRejection {
|
||||
#[error("unauthorized")]
|
||||
Unauthorized,
|
||||
#[error("bad request")]
|
||||
BadRequest,
|
||||
#[error("missing x-tenant-id")]
|
||||
MissingTenant,
|
||||
#[error("invalid x-tenant-id")]
|
||||
InvalidTenant,
|
||||
#[error("forbidden")]
|
||||
Forbidden,
|
||||
#[error("not found")]
|
||||
NotFound,
|
||||
#[error("conflict")]
|
||||
Conflict,
|
||||
#[error("internal error")]
|
||||
Internal,
|
||||
}
|
||||
|
||||
impl IntoResponse for AuthzRejection {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
AuthzRejection::Unauthorized => {
|
||||
(StatusCode::UNAUTHORIZED, self.to_string()).into_response()
|
||||
}
|
||||
AuthzRejection::BadRequest => {
|
||||
(StatusCode::BAD_REQUEST, self.to_string()).into_response()
|
||||
}
|
||||
AuthzRejection::MissingTenant | AuthzRejection::InvalidTenant => {
|
||||
(StatusCode::BAD_REQUEST, self.to_string()).into_response()
|
||||
}
|
||||
AuthzRejection::Forbidden => (StatusCode::FORBIDDEN, self.to_string()).into_response(),
|
||||
AuthzRejection::NotFound => (StatusCode::NOT_FOUND, self.to_string()).into_response(),
|
||||
AuthzRejection::Conflict => (StatusCode::CONFLICT, self.to_string()).into_response(),
|
||||
AuthzRejection::Internal => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RoleRecord {
|
||||
pub role_id: String,
|
||||
pub rights: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct HttpCommandRequest {
|
||||
command_id: Option<String>,
|
||||
payload: Value,
|
||||
metadata: Option<std::collections::HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct HttpCommandResponse {
|
||||
events: Vec<EventDto>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct EventDto {
|
||||
event_id: String,
|
||||
command_id: String,
|
||||
aggregate_id: String,
|
||||
aggregate_type: String,
|
||||
version: u64,
|
||||
event_type: String,
|
||||
payload_json: String,
|
||||
timestamp_rfc3339: String,
|
||||
}
|
||||
|
||||
async fn submit_command_stub(
|
||||
State(state): State<AppState>,
|
||||
ctx: crate::RequestContext,
|
||||
principal: Principal,
|
||||
TenantId(tenant_id): TenantId,
|
||||
Path((aggregate_type, aggregate_id)): Path<(String, String)>,
|
||||
Json(body): Json<HttpCommandRequest>,
|
||||
) -> Result<Json<HttpCommandResponse>, AuthzRejection> {
|
||||
ensure_allowed(
|
||||
&state.storage,
|
||||
&principal.user_id,
|
||||
&tenant_id,
|
||||
"command.submit",
|
||||
)
|
||||
.await?;
|
||||
|
||||
let command_id = body
|
||||
.command_id
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
|
||||
|
||||
let metadata = body.metadata.unwrap_or_default();
|
||||
let request = grpc::proto::SubmitCommandRequest {
|
||||
tenant_id: tenant_id.clone(),
|
||||
command_id,
|
||||
aggregate_id,
|
||||
aggregate_type,
|
||||
payload_json: body.payload.to_string(),
|
||||
metadata,
|
||||
};
|
||||
|
||||
let resp = grpc::submit_command_via_routing(&state.routing, request, &ctx)
|
||||
.await
|
||||
.map_err(|_| AuthzRejection::Internal)?;
|
||||
|
||||
let events = resp
|
||||
.events
|
||||
.into_iter()
|
||||
.map(|e| EventDto {
|
||||
event_id: e.event_id,
|
||||
command_id: e.command_id,
|
||||
aggregate_id: e.aggregate_id,
|
||||
aggregate_type: e.aggregate_type,
|
||||
version: e.version,
|
||||
event_type: e.event_type,
|
||||
payload_json: e.payload_json,
|
||||
timestamp_rfc3339: e.timestamp_rfc3339,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Json(HttpCommandResponse { events }))
|
||||
}
|
||||
|
||||
async fn query_stub(
|
||||
State(state): State<AppState>,
|
||||
ctx: crate::RequestContext,
|
||||
principal: Principal,
|
||||
TenantId(tenant_id): TenantId,
|
||||
Path(view_type): Path<String>,
|
||||
Json(payload): Json<Value>,
|
||||
) -> Result<Response, AuthzRejection> {
|
||||
ensure_allowed(
|
||||
&state.storage,
|
||||
&principal.user_id,
|
||||
&tenant_id,
|
||||
"query.execute",
|
||||
)
|
||||
.await?;
|
||||
|
||||
let upstream = state
|
||||
.routing
|
||||
.resolve(&tenant_id, crate::routing::ServiceKind::Projection)
|
||||
.await
|
||||
.map_err(|_| AuthzRejection::Internal)?;
|
||||
tracing::Span::current().record("upstream", upstream.as_str());
|
||||
|
||||
let url = format!("{}/v1/query/{}", upstream.trim_end_matches('/'), view_type);
|
||||
|
||||
let client = crate::upstream::http_client();
|
||||
let resp = client
|
||||
.post(url)
|
||||
.header("x-tenant-id", tenant_id)
|
||||
.header("x-correlation-id", ctx.correlation_id)
|
||||
.header("traceparent", ctx.traceparent)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|_| AuthzRejection::Internal)?;
|
||||
|
||||
let status = StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
let bytes = resp.bytes().await.map_err(|_| AuthzRejection::Internal)?;
|
||||
let mut out = Response::new(axum::body::Body::from(bytes));
|
||||
*out.status_mut() = status;
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
pub async fn runner_admin_proxy(
|
||||
State(state): State<AppState>,
|
||||
ctx: crate::RequestContext,
|
||||
principal: Principal,
|
||||
TenantId(tenant_id): TenantId,
|
||||
Path(path): Path<String>,
|
||||
request: Request,
|
||||
) -> Result<Response, AuthzRejection> {
|
||||
ensure_allowed(
|
||||
&state.storage,
|
||||
&principal.user_id,
|
||||
&tenant_id,
|
||||
"runner.admin",
|
||||
)
|
||||
.await?;
|
||||
|
||||
let upstream = state
|
||||
.routing
|
||||
.resolve(&tenant_id, crate::routing::ServiceKind::Runner)
|
||||
.await
|
||||
.map_err(|_| AuthzRejection::Internal)?;
|
||||
tracing::Span::current().record("upstream", upstream.as_str());
|
||||
|
||||
let mut url = format!(
|
||||
"{}/admin/{}",
|
||||
upstream.trim_end_matches('/'),
|
||||
path.trim_start_matches('/')
|
||||
);
|
||||
if let Some(q) = request.uri().query() {
|
||||
url.push('?');
|
||||
url.push_str(q);
|
||||
}
|
||||
|
||||
let method = request.method().clone();
|
||||
let headers = request.headers().clone();
|
||||
let body = axum::body::to_bytes(request.into_body(), usize::MAX)
|
||||
.await
|
||||
.map_err(|_| AuthzRejection::Internal)?;
|
||||
|
||||
let client = crate::upstream::http_client();
|
||||
let mut req = client
|
||||
.request(method, url)
|
||||
.header("x-tenant-id", tenant_id)
|
||||
.header("x-correlation-id", ctx.correlation_id)
|
||||
.header("traceparent", ctx.traceparent)
|
||||
.body(body);
|
||||
|
||||
for (k, v) in headers.iter() {
|
||||
if k == header::HOST {
|
||||
continue;
|
||||
}
|
||||
req = req.header(k, v);
|
||||
}
|
||||
|
||||
let resp = req.send().await.map_err(|_| AuthzRejection::Internal)?;
|
||||
let status = StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
let bytes = resp.bytes().await.map_err(|_| AuthzRejection::Internal)?;
|
||||
|
||||
let mut out = Response::new(axum::body::Body::from(bytes));
|
||||
*out.status_mut() = status;
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
pub async fn ensure_allowed(
|
||||
storage: &GatewayStorage,
|
||||
principal_id: &str,
|
||||
tenant_id: &str,
|
||||
required_right: &str,
|
||||
) -> Result<(), AuthzRejection> {
|
||||
let mut roles = list_assigned_roles(storage, tenant_id, principal_id).await?;
|
||||
roles.extend(list_assigned_roles(storage, "*", principal_id).await?);
|
||||
|
||||
if roles.is_empty() {
|
||||
metrics::counter!(
|
||||
"gateway_authz_decisions_total",
|
||||
"tenant" => tenant_id.to_string(),
|
||||
"right" => required_right.to_string(),
|
||||
"result" => "deny"
|
||||
)
|
||||
.increment(1);
|
||||
return Err(AuthzRejection::Forbidden);
|
||||
}
|
||||
|
||||
for role_id in roles {
|
||||
let key = role_key(&role_id);
|
||||
let entry = storage
|
||||
.roles
|
||||
.get(&key)
|
||||
.await
|
||||
.map_err(|_| AuthzRejection::Internal)?;
|
||||
let Some(entry) = entry else {
|
||||
continue;
|
||||
};
|
||||
let role: RoleRecord = decode_stored(&entry.value).map_err(|_| AuthzRejection::Internal)?;
|
||||
if role.rights.iter().any(|r| r == required_right) {
|
||||
metrics::counter!(
|
||||
"gateway_authz_decisions_total",
|
||||
"tenant" => tenant_id.to_string(),
|
||||
"right" => required_right.to_string(),
|
||||
"result" => "allow"
|
||||
)
|
||||
.increment(1);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
metrics::counter!(
|
||||
"gateway_authz_decisions_total",
|
||||
"tenant" => tenant_id.to_string(),
|
||||
"right" => required_right.to_string(),
|
||||
"result" => "deny"
|
||||
)
|
||||
.increment(1);
|
||||
Err(AuthzRejection::Forbidden)
|
||||
}
|
||||
|
||||
async fn list_assigned_roles(
|
||||
storage: &GatewayStorage,
|
||||
tenant_id: &str,
|
||||
principal_id: &str,
|
||||
) -> Result<Vec<String>, AuthzRejection> {
|
||||
let prefix = assignment_prefix(tenant_id, principal_id);
|
||||
let keys = storage
|
||||
.assignments
|
||||
.list_keys(&prefix)
|
||||
.await
|
||||
.map_err(|_| AuthzRejection::Internal)?;
|
||||
Ok(keys
|
||||
.into_iter()
|
||||
.filter_map(|k| k.rsplit('/').next().map(|s| s.to_string()))
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn role_key(role_id: &str) -> String {
|
||||
format!("v1/roles/{role_id}")
|
||||
}
|
||||
|
||||
fn assignment_key(tenant_id: &str, principal_id: &str, role_id: &str) -> String {
|
||||
format!("v1/assignments/{tenant_id}/{principal_id}/{role_id}")
|
||||
}
|
||||
|
||||
fn assignment_prefix(tenant_id: &str, principal_id: &str) -> String {
|
||||
format!("v1/assignments/{tenant_id}/{principal_id}/")
|
||||
}
|
||||
|
||||
fn decode_stored<T: for<'de> Deserialize<'de>>(bytes: &[u8]) -> Result<T, StorageError> {
|
||||
#[derive(Deserialize)]
|
||||
struct Stored<T> {
|
||||
data: T,
|
||||
}
|
||||
let stored: Stored<T> =
|
||||
serde_json::from_slice(bytes).map_err(|e| StorageError::Serde(e.to_string()))?;
|
||||
Ok(stored.data)
|
||||
}
|
||||
|
||||
pub async fn put_role(
|
||||
storage: &GatewayStorage,
|
||||
role_id: &str,
|
||||
rights: Vec<String>,
|
||||
) -> Result<(), StorageError> {
|
||||
let record = RoleRecord {
|
||||
role_id: role_id.to_string(),
|
||||
rights,
|
||||
};
|
||||
let payload = serde_json::to_vec(&serde_json::json!({
|
||||
"v": crate::storage::SCHEMA_VERSION,
|
||||
"data": record
|
||||
}))
|
||||
.map_err(|e| StorageError::Serde(e.to_string()))?;
|
||||
|
||||
storage.roles.put(&role_key(role_id), payload).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn assign_role(
|
||||
storage: &GatewayStorage,
|
||||
tenant_id: &str,
|
||||
principal_id: &str,
|
||||
role_id: &str,
|
||||
) -> Result<(), StorageError> {
|
||||
storage
|
||||
.assignments
|
||||
.put(
|
||||
&assignment_key(tenant_id, principal_id, role_id),
|
||||
b"1".to_vec(),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::authn;
|
||||
use std::sync::Arc;
|
||||
use tower::util::ServiceExt;
|
||||
|
||||
async fn test_app() -> (axum::Router, AppState) {
|
||||
let metrics = crate::observability::init_metrics_for_tests();
|
||||
let routing = crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(
|
||||
crate::routing::RoutingConfig::empty(),
|
||||
)))
|
||||
.await
|
||||
.unwrap();
|
||||
let storage = crate::storage::GatewayStorage::new_in_memory();
|
||||
let authn_cfg = crate::authn::AuthnConfig::for_tests();
|
||||
let state = crate::AppState {
|
||||
metrics,
|
||||
routing,
|
||||
storage,
|
||||
authn: authn_cfg,
|
||||
};
|
||||
let app = crate::app(state.clone());
|
||||
(app, state)
|
||||
}
|
||||
|
||||
async fn test_app_with_routing(cfg: crate::routing::RoutingConfig) -> (axum::Router, AppState) {
|
||||
let metrics = crate::observability::init_metrics_for_tests();
|
||||
let routing =
|
||||
crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(cfg)))
|
||||
.await
|
||||
.unwrap();
|
||||
let storage = crate::storage::GatewayStorage::new_in_memory();
|
||||
let authn_cfg = crate::authn::AuthnConfig::for_tests();
|
||||
let state = crate::AppState {
|
||||
metrics,
|
||||
routing,
|
||||
storage,
|
||||
authn: authn_cfg,
|
||||
};
|
||||
let app = crate::app(state.clone());
|
||||
(app, state)
|
||||
}
|
||||
|
||||
async fn signup_and_get_claims(
|
||||
app: &axum::Router,
|
||||
cfg: &authn::AuthnConfig,
|
||||
) -> (String, authn::AccessClaims) {
|
||||
let response = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/auth/signup")
|
||||
.header("content-type", "application/json")
|
||||
.body(axum::body::Body::from(
|
||||
r#"{"email":"a@b.com","password":"password123"}"#,
|
||||
))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let created: crate::authn::AuthResponse = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
let claims = cfg.verify_access_token(&created.access_token).unwrap();
|
||||
(created.access_token, claims)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_tenant_header_returns_400() {
|
||||
let (app, state) = test_app().await;
|
||||
let (token, claims) = signup_and_get_claims(&app, &state.authn).await;
|
||||
|
||||
put_role(
|
||||
&state.storage,
|
||||
"role-command",
|
||||
vec!["command.submit".to_string()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assign_role(&state.storage, "tenant-a", &claims.sub, "role-command")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/commands/User/u1")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.header("content-type", "application/json")
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tenant_spoofing_is_rejected() {
|
||||
let (app, state) = test_app().await;
|
||||
let (token, claims) = signup_and_get_claims(&app, &state.authn).await;
|
||||
|
||||
put_role(
|
||||
&state.storage,
|
||||
"role-command",
|
||||
vec!["command.submit".to_string()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assign_role(&state.storage, "tenant-a", &claims.sub, "role-command")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/commands/User/u1")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.header("x-tenant-id", "tenant-b")
|
||||
.header("content-type", "application/json")
|
||||
.body(axum::body::Body::from(r#"{"payload":{}}"#))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn role_assignment_enables_expected_action() {
|
||||
let (app, state) = test_app().await;
|
||||
let (token, claims) = signup_and_get_claims(&app, &state.authn).await;
|
||||
|
||||
put_role(
|
||||
&state.storage,
|
||||
"role-command",
|
||||
vec!["command.submit".to_string()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assign_role(&state.storage, "tenant-a", &claims.sub, "role-command")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/commands/User/u1")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.header("x-tenant-id", "tenant-a")
|
||||
.header("content-type", "application/json")
|
||||
.body(axum::body::Body::from(r#"{"payload":{}}"#))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn http_command_endpoint_returns_same_shape_as_grpc_response() {
|
||||
use crate::grpc::proto;
|
||||
use crate::routing::RoutingConfig;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Default)]
|
||||
struct Upstream;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl proto::command_service_server::CommandService for Upstream {
|
||||
async fn submit_command(
|
||||
&self,
|
||||
request: tonic::Request<proto::SubmitCommandRequest>,
|
||||
) -> Result<tonic::Response<proto::SubmitCommandResponse>, tonic::Status> {
|
||||
let req = request.into_inner();
|
||||
Ok(tonic::Response::new(proto::SubmitCommandResponse {
|
||||
events: vec![proto::Event {
|
||||
event_id: "e1".to_string(),
|
||||
command_id: req.command_id,
|
||||
aggregate_id: req.aggregate_id,
|
||||
aggregate_type: req.aggregate_type,
|
||||
version: 1,
|
||||
event_type: "Created".to_string(),
|
||||
payload_json: "{}".to_string(),
|
||||
timestamp_rfc3339: "2020-01-01T00:00:00Z".to_string(),
|
||||
}],
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
drop(listener);
|
||||
tokio::spawn(async move {
|
||||
tonic::transport::Server::builder()
|
||||
.add_service(proto::command_service_server::CommandServiceServer::new(
|
||||
Upstream,
|
||||
))
|
||||
.serve(addr)
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
|
||||
let upstream_url = format!("http://{}", addr);
|
||||
let cfg = RoutingConfig {
|
||||
revision: 1,
|
||||
aggregate_placement: HashMap::from([("tenant-a".to_string(), "a".to_string())]),
|
||||
projection_placement: HashMap::new(),
|
||||
runner_placement: HashMap::new(),
|
||||
aggregate_shards: HashMap::from([("a".to_string(), vec![upstream_url])]),
|
||||
projection_shards: HashMap::new(),
|
||||
runner_shards: HashMap::new(),
|
||||
};
|
||||
|
||||
let (app, state) = test_app_with_routing(cfg).await;
|
||||
let (token, claims) = signup_and_get_claims(&app, &state.authn).await;
|
||||
|
||||
put_role(
|
||||
&state.storage,
|
||||
"role-command",
|
||||
vec!["command.submit".to_string()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assign_role(&state.storage, "tenant-a", &claims.sub, "role-command")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/commands/User/u1")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.header("x-tenant-id", "tenant-a")
|
||||
.header("content-type", "application/json")
|
||||
.body(axum::body::Body::from(r#"{"payload":{}}"#))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let value: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
assert!(
|
||||
value
|
||||
.get("events")
|
||||
.and_then(|v| v.as_array())
|
||||
.unwrap()
|
||||
.len()
|
||||
== 1
|
||||
);
|
||||
assert_eq!(
|
||||
value.get("events").unwrap()[0]
|
||||
.get("event_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap(),
|
||||
"e1"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn query_endpoint_denies_unauthorized_and_forwards_when_authorized() {
|
||||
use crate::routing::RoutingConfig;
|
||||
use std::collections::HashMap;
|
||||
|
||||
let projection_app = axum::Router::new().route(
|
||||
"/v1/query/TestView",
|
||||
post(|headers: axum::http::HeaderMap| async move {
|
||||
let correlation = headers
|
||||
.get("x-correlation-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
let traceparent = headers
|
||||
.get("traceparent")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
if correlation.trim().is_empty()
|
||||
|| crate::trace_id_from_traceparent(traceparent).is_none()
|
||||
{
|
||||
return (StatusCode::BAD_REQUEST, "missing correlation");
|
||||
}
|
||||
(StatusCode::OK, r#"{"mode":"count"}"#)
|
||||
}),
|
||||
);
|
||||
let projection_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let projection_addr = projection_listener.local_addr().unwrap();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(projection_listener, projection_app)
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
let projection_url = format!("http://{}", projection_addr);
|
||||
|
||||
let cfg = RoutingConfig {
|
||||
revision: 1,
|
||||
aggregate_placement: HashMap::new(),
|
||||
projection_placement: HashMap::from([("tenant-a".to_string(), "p".to_string())]),
|
||||
runner_placement: HashMap::new(),
|
||||
aggregate_shards: HashMap::new(),
|
||||
projection_shards: HashMap::from([("p".to_string(), vec![projection_url])]),
|
||||
runner_shards: HashMap::new(),
|
||||
};
|
||||
|
||||
let (app, state) = test_app_with_routing(cfg).await;
|
||||
let (token, claims) = signup_and_get_claims(&app, &state.authn).await;
|
||||
|
||||
let deny = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/query/TestView")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.header("x-tenant-id", "tenant-a")
|
||||
.header("content-type", "application/json")
|
||||
.body(axum::body::Body::from(r#"{"uqf":"{}"}"#))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(deny.status(), StatusCode::FORBIDDEN);
|
||||
|
||||
put_role(
|
||||
&state.storage,
|
||||
"role-query",
|
||||
vec!["query.execute".to_string()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assign_role(&state.storage, "tenant-a", &claims.sub, "role-query")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let ok = app
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/query/TestView")
|
||||
.header("authorization", format!("Bearer {token}"))
|
||||
.header("x-tenant-id", "tenant-a")
|
||||
.header("content-type", "application/json")
|
||||
.body(axum::body::Body::from(r#"{"uqf":"{}"}"#))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(ok.status(), StatusCode::OK);
|
||||
assert!(!ok
|
||||
.headers()
|
||||
.get("x-correlation-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.is_empty());
|
||||
assert!(crate::trace_id_from_traceparent(
|
||||
ok.headers()
|
||||
.get("traceparent")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
)
|
||||
.is_some());
|
||||
}
|
||||
}
|
||||
275
gateway/src/grpc.rs
Normal file
275
gateway/src/grpc.rs
Normal file
@@ -0,0 +1,275 @@
|
||||
use crate::routing::RouterState;
|
||||
use crate::routing::RoutingError;
|
||||
use crate::routing::ServiceKind;
|
||||
|
||||
pub mod proto {
|
||||
tonic::include_proto!("aggregate.gateway.v1");
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct GatewayCommandService {
|
||||
routing: RouterState,
|
||||
}
|
||||
|
||||
impl GatewayCommandService {
|
||||
pub fn new(routing: RouterState) -> Self {
|
||||
Self { routing }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl proto::command_service_server::CommandService for GatewayCommandService {
|
||||
async fn submit_command(
|
||||
&self,
|
||||
request: tonic::Request<proto::SubmitCommandRequest>,
|
||||
) -> Result<tonic::Response<proto::SubmitCommandResponse>, tonic::Status> {
|
||||
let correlation_id = request
|
||||
.metadata()
|
||||
.get("x-correlation-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.trim())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
|
||||
|
||||
let traceparent = request
|
||||
.metadata()
|
||||
.get("traceparent")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.trim())
|
||||
.filter(|s| !s.is_empty())
|
||||
.and_then(|s| {
|
||||
if crate::trace_id_from_traceparent(s).is_some() {
|
||||
Some(s.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
let trace_id = uuid::Uuid::new_v4().simple().to_string();
|
||||
let span_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string();
|
||||
format!("00-{trace_id}-{span_id}-01")
|
||||
});
|
||||
|
||||
let mut req = request.into_inner();
|
||||
|
||||
let tenant_id = req.tenant_id.trim().to_string();
|
||||
if tenant_id.is_empty() {
|
||||
return Err(tonic::Status::invalid_argument("tenant_id is required"));
|
||||
}
|
||||
req.tenant_id = tenant_id.clone();
|
||||
|
||||
let upstream = self
|
||||
.routing
|
||||
.resolve(&tenant_id, ServiceKind::Aggregate)
|
||||
.await
|
||||
.map_err(map_routing_error)?;
|
||||
tracing::Span::current().record("upstream", upstream.as_str());
|
||||
|
||||
let channel = crate::upstream::grpc_endpoint(&upstream)
|
||||
.map_err(|e| tonic::Status::unavailable(e.to_string()))?
|
||||
.connect()
|
||||
.await
|
||||
.map_err(|e| tonic::Status::unavailable(e.to_string()))?;
|
||||
let mut client = proto::command_service_client::CommandServiceClient::new(channel);
|
||||
|
||||
let mut upstream_req = tonic::Request::new(req);
|
||||
if let Ok(v) = tonic::metadata::MetadataValue::try_from(tenant_id.as_str()) {
|
||||
upstream_req.metadata_mut().insert("x-tenant-id", v);
|
||||
}
|
||||
if let Ok(v) = tonic::metadata::MetadataValue::try_from(correlation_id.as_str()) {
|
||||
upstream_req.metadata_mut().insert("x-correlation-id", v);
|
||||
}
|
||||
if let Ok(v) = tonic::metadata::MetadataValue::try_from(traceparent.as_str()) {
|
||||
upstream_req.metadata_mut().insert("traceparent", v);
|
||||
}
|
||||
|
||||
let mut resp = client.submit_command(upstream_req).await?;
|
||||
if let Ok(v) = tonic::metadata::MetadataValue::try_from(correlation_id.as_str()) {
|
||||
resp.metadata_mut().insert("x-correlation-id", v);
|
||||
}
|
||||
if let Ok(v) = tonic::metadata::MetadataValue::try_from(traceparent.as_str()) {
|
||||
resp.metadata_mut().insert("traceparent", v);
|
||||
}
|
||||
Ok(resp)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn submit_command_via_routing(
|
||||
routing: &RouterState,
|
||||
request: proto::SubmitCommandRequest,
|
||||
ctx: &crate::RequestContext,
|
||||
) -> Result<proto::SubmitCommandResponse, tonic::Status> {
|
||||
let tenant_id = request.tenant_id.trim().to_string();
|
||||
if tenant_id.is_empty() {
|
||||
return Err(tonic::Status::invalid_argument("tenant_id is required"));
|
||||
}
|
||||
|
||||
let upstream = routing
|
||||
.resolve(&tenant_id, ServiceKind::Aggregate)
|
||||
.await
|
||||
.map_err(map_routing_error)?;
|
||||
tracing::Span::current().record("upstream", upstream.as_str());
|
||||
|
||||
let channel = crate::upstream::grpc_endpoint(&upstream)
|
||||
.map_err(|e| tonic::Status::unavailable(e.to_string()))?
|
||||
.connect()
|
||||
.await
|
||||
.map_err(|e| tonic::Status::unavailable(e.to_string()))?;
|
||||
let mut client = proto::command_service_client::CommandServiceClient::new(channel);
|
||||
|
||||
let mut upstream_req = tonic::Request::new(request);
|
||||
if let Ok(v) = tonic::metadata::MetadataValue::try_from(tenant_id.as_str()) {
|
||||
upstream_req.metadata_mut().insert("x-tenant-id", v);
|
||||
}
|
||||
if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.correlation_id.as_str()) {
|
||||
upstream_req.metadata_mut().insert("x-correlation-id", v);
|
||||
}
|
||||
if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.traceparent.as_str()) {
|
||||
upstream_req.metadata_mut().insert("traceparent", v);
|
||||
}
|
||||
|
||||
let resp = client.submit_command(upstream_req).await?;
|
||||
Ok(resp.into_inner())
|
||||
}
|
||||
|
||||
fn map_routing_error(err: RoutingError) -> tonic::Status {
|
||||
match err {
|
||||
RoutingError::UnknownTenant => tonic::Status::not_found("unknown tenant"),
|
||||
RoutingError::MissingShard | RoutingError::EmptyShard => {
|
||||
tonic::Status::unavailable(err.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::routing::RoutingConfig;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn grpc_submit_command_forwards_tenant_metadata_and_returns_events() {
|
||||
use proto::command_service_server::CommandService;
|
||||
|
||||
#[derive(Default)]
|
||||
struct Upstream;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl proto::command_service_server::CommandService for Upstream {
|
||||
async fn submit_command(
|
||||
&self,
|
||||
request: tonic::Request<proto::SubmitCommandRequest>,
|
||||
) -> Result<tonic::Response<proto::SubmitCommandResponse>, tonic::Status> {
|
||||
let tenant_md = request
|
||||
.metadata()
|
||||
.get("x-tenant-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
if tenant_md != request.get_ref().tenant_id {
|
||||
return Err(tonic::Status::failed_precondition(
|
||||
"missing tenant metadata",
|
||||
));
|
||||
}
|
||||
let correlation = request
|
||||
.metadata()
|
||||
.get("x-correlation-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
if correlation.trim().is_empty() {
|
||||
return Err(tonic::Status::failed_precondition(
|
||||
"missing correlation metadata",
|
||||
));
|
||||
}
|
||||
let traceparent = request
|
||||
.metadata()
|
||||
.get("traceparent")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
if crate::trace_id_from_traceparent(traceparent).is_none() {
|
||||
return Err(tonic::Status::failed_precondition("missing traceparent"));
|
||||
}
|
||||
|
||||
let resp = proto::SubmitCommandResponse {
|
||||
events: vec![proto::Event {
|
||||
event_id: "e1".to_string(),
|
||||
command_id: request.get_ref().command_id.clone(),
|
||||
aggregate_id: request.get_ref().aggregate_id.clone(),
|
||||
aggregate_type: request.get_ref().aggregate_type.clone(),
|
||||
version: 1,
|
||||
event_type: "Created".to_string(),
|
||||
payload_json: "{}".to_string(),
|
||||
timestamp_rfc3339: "2020-01-01T00:00:00Z".to_string(),
|
||||
}],
|
||||
};
|
||||
Ok(tonic::Response::new(resp))
|
||||
}
|
||||
}
|
||||
|
||||
let upstream_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let upstream_addr = upstream_listener.local_addr().unwrap();
|
||||
drop(upstream_listener);
|
||||
let upstream_url = format!("http://{}", upstream_addr);
|
||||
|
||||
let upstream_task = tokio::spawn(async move {
|
||||
tonic::transport::Server::builder()
|
||||
.add_service(proto::command_service_server::CommandServiceServer::new(
|
||||
Upstream,
|
||||
))
|
||||
.serve(upstream_addr)
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
|
||||
let cfg = RoutingConfig {
|
||||
revision: 1,
|
||||
aggregate_placement: HashMap::from([("tenant-a".to_string(), "a".to_string())]),
|
||||
projection_placement: HashMap::new(),
|
||||
runner_placement: HashMap::new(),
|
||||
aggregate_shards: HashMap::from([("a".to_string(), vec![upstream_url])]),
|
||||
projection_shards: HashMap::new(),
|
||||
runner_shards: HashMap::new(),
|
||||
};
|
||||
|
||||
let routing =
|
||||
crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(cfg)))
|
||||
.await
|
||||
.unwrap();
|
||||
let svc = GatewayCommandService::new(routing);
|
||||
|
||||
let request = proto::SubmitCommandRequest {
|
||||
tenant_id: "tenant-a".to_string(),
|
||||
command_id: "c1".to_string(),
|
||||
aggregate_id: "id1".to_string(),
|
||||
aggregate_type: "User".to_string(),
|
||||
payload_json: "{}".to_string(),
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
|
||||
let resp = CommandService::submit_command(&svc, tonic::Request::new(request))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!resp
|
||||
.metadata()
|
||||
.get("x-correlation-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.is_empty());
|
||||
assert!(crate::trace_id_from_traceparent(
|
||||
resp.metadata()
|
||||
.get("traceparent")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
)
|
||||
.is_some());
|
||||
let resp = resp.into_inner();
|
||||
|
||||
assert_eq!(resp.events.len(), 1);
|
||||
assert_eq!(resp.events[0].command_id, "c1");
|
||||
|
||||
upstream_task.abort();
|
||||
}
|
||||
}
|
||||
541
gateway/src/lib.rs
Normal file
541
gateway/src/lib.rs
Normal file
@@ -0,0 +1,541 @@
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use axum::error_handling::HandleErrorLayer;
|
||||
use axum::extract::MatchedPath;
|
||||
use axum::extract::State;
|
||||
use axum::http::request::Parts;
|
||||
use axum::http::HeaderName;
|
||||
use axum::http::HeaderValue;
|
||||
use axum::http::StatusCode;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::get;
|
||||
use axum::BoxError;
|
||||
use axum::Json;
|
||||
use axum::Router;
|
||||
use metrics_exporter_prometheus::PrometheusHandle;
|
||||
use serde::Serialize;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use tower::timeout::TimeoutLayer;
|
||||
use tower::Layer;
|
||||
use tower::Service;
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::limit::RequestBodyLimitLayer;
|
||||
use tower_http::request_id::MakeRequestUuid;
|
||||
use tower_http::request_id::PropagateRequestIdLayer;
|
||||
use tower_http::request_id::SetRequestIdLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing::Level;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RequestContext {
|
||||
pub request_id: String,
|
||||
pub correlation_id: String,
|
||||
pub traceparent: String,
|
||||
pub trace_id: String,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<S> axum::extract::FromRequestParts<S> for RequestContext
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = StatusCode;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
let request_id = parts
|
||||
.headers
|
||||
.get("x-request-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let correlation_id = parts
|
||||
.headers
|
||||
.get("x-correlation-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let traceparent = parts
|
||||
.headers
|
||||
.get("traceparent")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let trace_id = trace_id_from_traceparent(&traceparent)
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(Self {
|
||||
request_id,
|
||||
correlation_id,
|
||||
traceparent,
|
||||
trace_id,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub metrics: PrometheusHandle,
|
||||
pub routing: routing::RouterState,
|
||||
pub storage: storage::GatewayStorage,
|
||||
pub authn: authn::AuthnConfig,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct StatusResponse {
|
||||
status: &'static str,
|
||||
}
|
||||
|
||||
pub fn app(state: AppState) -> Router {
|
||||
let request_id_header = HeaderName::from_static("x-request-id");
|
||||
|
||||
Router::new()
|
||||
.route("/health", get(health))
|
||||
.route("/ready", get(ready))
|
||||
.route("/metrics", get(metrics))
|
||||
.nest("/v1/auth", authn::router())
|
||||
.nest("/v1", authz::router())
|
||||
.nest("/admin/iam", admin_iam::router())
|
||||
.nest("/v1/admin/iam", admin_iam::router())
|
||||
.nest("/admin/rebalance", admin_rebalance::router())
|
||||
.route("/admin/routing", get(admin_routing))
|
||||
.route(
|
||||
"/admin/runner/*path",
|
||||
axum::routing::any(authz::runner_admin_proxy),
|
||||
)
|
||||
.route(
|
||||
"/admin/routing/reload",
|
||||
axum::routing::post(admin_routing_reload),
|
||||
)
|
||||
.route(
|
||||
"/admin/routing/resolve",
|
||||
axum::routing::get(admin_rebalance::resolve),
|
||||
)
|
||||
.route_layer(axum::middleware::from_fn(track_http_metrics))
|
||||
.with_state(state)
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(HandleErrorLayer::new(|error: BoxError| async move {
|
||||
(StatusCode::REQUEST_TIMEOUT, error.to_string())
|
||||
}))
|
||||
.layer(SetRequestIdLayer::new(
|
||||
request_id_header.clone(),
|
||||
MakeRequestUuid,
|
||||
))
|
||||
.layer(PropagateRequestIdLayer::new(request_id_header))
|
||||
.layer(EnsureCorrelationTraceLayer)
|
||||
.layer(TraceLayer::new_for_http().make_span_with(
|
||||
|request: &axum::http::Request<_>| {
|
||||
let request_id = request
|
||||
.headers()
|
||||
.get("x-request-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
let correlation_id = request
|
||||
.headers()
|
||||
.get("x-correlation-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
let traceparent = request
|
||||
.headers()
|
||||
.get("traceparent")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
let trace_id = trace_id_from_traceparent(traceparent).unwrap_or("");
|
||||
let path = request_path_for_logging(request);
|
||||
|
||||
tracing::span!(
|
||||
Level::INFO,
|
||||
"http.request",
|
||||
method = %request.method(),
|
||||
path = %path,
|
||||
request_id = request_id,
|
||||
correlation_id = correlation_id,
|
||||
trace_id = trace_id,
|
||||
tenant_id = tracing::field::Empty,
|
||||
principal_id = tracing::field::Empty,
|
||||
upstream = tracing::field::Empty,
|
||||
)
|
||||
},
|
||||
))
|
||||
.layer(RequestBodyLimitLayer::new(1024 * 1024))
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(30))),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct EnsureCorrelationTraceLayer;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct EnsureCorrelationTrace<S> {
|
||||
inner: S,
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for EnsureCorrelationTraceLayer {
|
||||
type Service = EnsureCorrelationTrace<S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
Self::Service { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, ReqBody, ResBody> Service<axum::http::Request<ReqBody>> for EnsureCorrelationTrace<S>
|
||||
where
|
||||
S: Service<axum::http::Request<ReqBody>, Response = axum::http::Response<ResBody>>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ 'static,
|
||||
S::Future: Send + 'static,
|
||||
S::Error: Send + 'static,
|
||||
ReqBody: Send + 'static,
|
||||
ResBody: Send + 'static,
|
||||
{
|
||||
type Response = axum::http::Response<ResBody>;
|
||||
type Error = S::Error;
|
||||
type Future =
|
||||
Pin<Box<dyn Future<Output = Result<axum::http::Response<ResBody>, S::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, mut req: axum::http::Request<ReqBody>) -> Self::Future {
|
||||
let correlation_id = req
|
||||
.headers()
|
||||
.get("x-correlation-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.trim())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(generate_correlation_id);
|
||||
|
||||
let traceparent = req
|
||||
.headers()
|
||||
.get("traceparent")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.trim())
|
||||
.filter(|s| !s.is_empty())
|
||||
.and_then(|s| {
|
||||
if trace_id_from_traceparent(s).is_some() {
|
||||
Some(s.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or_else(generate_traceparent);
|
||||
|
||||
if let Ok(v) = HeaderValue::from_str(&correlation_id) {
|
||||
req.headers_mut().insert("x-correlation-id", v);
|
||||
}
|
||||
if let Ok(v) = HeaderValue::from_str(&traceparent) {
|
||||
req.headers_mut().insert("traceparent", v);
|
||||
}
|
||||
|
||||
let mut inner = self.inner.clone();
|
||||
Box::pin(async move {
|
||||
let mut resp = inner.call(req).await?;
|
||||
if resp.headers().get("x-correlation-id").is_none() {
|
||||
if let Ok(v) = HeaderValue::from_str(&correlation_id) {
|
||||
resp.headers_mut().insert("x-correlation-id", v);
|
||||
}
|
||||
}
|
||||
if resp.headers().get("traceparent").is_none() {
|
||||
if let Ok(v) = HeaderValue::from_str(&traceparent) {
|
||||
resp.headers_mut().insert("traceparent", v);
|
||||
}
|
||||
}
|
||||
Ok(resp)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_correlation_id() -> String {
|
||||
uuid::Uuid::new_v4().to_string()
|
||||
}
|
||||
|
||||
fn generate_traceparent() -> String {
|
||||
let trace_id = uuid::Uuid::new_v4().simple().to_string();
|
||||
let span_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string();
|
||||
format!("00-{trace_id}-{span_id}-01")
|
||||
}
|
||||
|
||||
pub(crate) fn trace_id_from_traceparent(traceparent: &str) -> Option<&str> {
|
||||
shared::trace_id_from_traceparent(traceparent)
|
||||
}
|
||||
|
||||
async fn track_http_metrics(
|
||||
req: axum::http::Request<axum::body::Body>,
|
||||
next: Next,
|
||||
) -> axum::response::Response {
|
||||
let method = req.method().to_string();
|
||||
let path = req
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.map(|p| p.as_str().to_string())
|
||||
.unwrap_or_else(|| req.uri().path().to_string());
|
||||
let start = Instant::now();
|
||||
|
||||
let response = next.run(req).await;
|
||||
|
||||
let status = response.status().as_u16().to_string();
|
||||
let elapsed = start.elapsed().as_secs_f64();
|
||||
|
||||
metrics::counter!(
|
||||
"gateway_http_requests_total",
|
||||
"method" => method.clone(),
|
||||
"path" => path.clone(),
|
||||
"status" => status.clone()
|
||||
)
|
||||
.increment(1);
|
||||
metrics::histogram!(
|
||||
"gateway_http_request_duration_seconds",
|
||||
"method" => method,
|
||||
"path" => path,
|
||||
"status" => status
|
||||
)
|
||||
.record(elapsed);
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
fn request_path_for_logging<B>(req: &axum::http::Request<B>) -> String {
|
||||
req.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.map(|p| p.as_str().to_string())
|
||||
.unwrap_or_else(|| req.uri().path().to_string())
|
||||
}
|
||||
|
||||
async fn health() -> impl IntoResponse {
|
||||
metrics::counter!("gateway_health_requests_total").increment(1);
|
||||
Json(StatusResponse { status: "ok" })
|
||||
}
|
||||
|
||||
async fn ready() -> impl IntoResponse {
|
||||
metrics::counter!("gateway_ready_requests_total").increment(1);
|
||||
Json(StatusResponse { status: "ok" })
|
||||
}
|
||||
|
||||
async fn metrics(State(state): State<AppState>) -> impl IntoResponse {
|
||||
state.metrics.render()
|
||||
}
|
||||
|
||||
async fn admin_routing(State(state): State<AppState>) -> impl IntoResponse {
|
||||
Json(state.routing.snapshot().await)
|
||||
}
|
||||
|
||||
async fn admin_routing_reload(State(state): State<AppState>) -> impl IntoResponse {
|
||||
match state.routing.reload().await {
|
||||
Ok(()) => StatusCode::NO_CONTENT.into_response(),
|
||||
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
pub mod http {}
|
||||
pub mod admin_iam;
|
||||
pub mod admin_rebalance;
|
||||
pub mod authn;
|
||||
pub mod authz;
|
||||
pub mod grpc;
|
||||
pub mod routing;
|
||||
pub mod upstream;
|
||||
pub mod config {}
|
||||
pub mod storage;
|
||||
|
||||
pub mod observability {
|
||||
use edge_logger_client::Config as EdgeLoggerConfig;
|
||||
use edge_logger_client::EdgeLoggerLayer;
|
||||
use metrics_exporter_prometheus::PrometheusBuilder;
|
||||
use metrics_exporter_prometheus::PrometheusHandle;
|
||||
use std::time::Duration;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
pub fn init_tracing() {
|
||||
let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string());
|
||||
let env_filter = tracing_subscriber::EnvFilter::new(filter);
|
||||
|
||||
let fmt_layer = tracing_subscriber::fmt::layer().json();
|
||||
let edge_layer = edge_logger_layer_from_env("gateway");
|
||||
|
||||
let registry = tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(fmt_layer);
|
||||
let _ = match edge_layer {
|
||||
Some(layer) => registry.with(layer).try_init(),
|
||||
None => registry.try_init(),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn init_metrics() -> PrometheusHandle {
|
||||
PrometheusBuilder::new()
|
||||
.install_recorder()
|
||||
.expect("failed to install Prometheus recorder")
|
||||
}
|
||||
|
||||
pub fn init_metrics_for_tests() -> PrometheusHandle {
|
||||
PrometheusBuilder::new().build_recorder().handle()
|
||||
}
|
||||
|
||||
fn edge_logger_layer_from_env(service_name: &str) -> Option<EdgeLoggerLayer> {
|
||||
let enabled = std::env::var("EDGE_LOGGER_ENABLED")
|
||||
.ok()
|
||||
.map(|v| matches!(v.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
|
||||
.unwrap_or(false);
|
||||
|
||||
let socket_path = std::env::var("EDGE_LOGGER_SOCKET_PATH").ok();
|
||||
if !enabled && socket_path.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let environment = std::env::var("EDGE_LOGGER_ENVIRONMENT")
|
||||
.or_else(|_| std::env::var("ENVIRONMENT"))
|
||||
.unwrap_or_else(|_| "production".to_string());
|
||||
|
||||
let tenant_id =
|
||||
std::env::var("EDGE_LOGGER_TENANT_ID").unwrap_or_else(|_| "default".to_string());
|
||||
|
||||
let batch_size = std::env::var("EDGE_LOGGER_BATCH_SIZE")
|
||||
.ok()
|
||||
.and_then(|v| v.parse::<usize>().ok())
|
||||
.unwrap_or(100);
|
||||
|
||||
let flush_interval = std::env::var("EDGE_LOGGER_FLUSH_INTERVAL_MS")
|
||||
.ok()
|
||||
.and_then(|v| v.parse::<u64>().ok())
|
||||
.map(Duration::from_millis)
|
||||
.unwrap_or(Duration::from_secs(1));
|
||||
|
||||
Some(EdgeLoggerLayer::new(EdgeLoggerConfig {
|
||||
socket_path: socket_path
|
||||
.unwrap_or_else(|| "/var/run/edge-logger/logger.sock".to_string()),
|
||||
service: service_name.to_string(),
|
||||
environment,
|
||||
tenant_id,
|
||||
batch_size,
|
||||
flush_interval,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use tower::util::ServiceExt;
|
||||
|
||||
fn assert_send_sync<T: Send + Sync>() {}
|
||||
|
||||
#[test]
|
||||
fn app_state_is_send_sync() {
|
||||
assert_send_sync::<AppState>();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn health_returns_200() {
|
||||
let metrics = crate::observability::init_metrics_for_tests();
|
||||
let routing = crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(
|
||||
crate::routing::RoutingConfig::empty(),
|
||||
)))
|
||||
.await
|
||||
.unwrap();
|
||||
let storage = crate::storage::GatewayStorage::new_in_memory();
|
||||
let authn = crate::authn::AuthnConfig::for_tests();
|
||||
let app = app(AppState {
|
||||
metrics,
|
||||
routing,
|
||||
storage,
|
||||
authn,
|
||||
});
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("GET")
|
||||
.uri("/health")
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.status(), axum::http::StatusCode::OK);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn docker_stack_yml_is_valid_yaml() {
|
||||
let raw = std::fs::read_to_string("../swarm/stacks/platform.yml").unwrap();
|
||||
let parsed: serde_yaml::Value = serde_yaml::from_str(&raw).unwrap();
|
||||
assert!(parsed.as_mapping().is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn metrics_include_http_request_counters() {
|
||||
static HANDLE: OnceLock<PrometheusHandle> = OnceLock::new();
|
||||
let metrics = HANDLE
|
||||
.get_or_init(|| {
|
||||
metrics_exporter_prometheus::PrometheusBuilder::new()
|
||||
.install_recorder()
|
||||
.unwrap()
|
||||
})
|
||||
.clone();
|
||||
|
||||
let routing = crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(
|
||||
crate::routing::RoutingConfig::empty(),
|
||||
)))
|
||||
.await
|
||||
.unwrap();
|
||||
let storage = crate::storage::GatewayStorage::new_in_memory();
|
||||
let authn = crate::authn::AuthnConfig::for_tests();
|
||||
let app = app(AppState {
|
||||
metrics,
|
||||
routing,
|
||||
storage,
|
||||
authn,
|
||||
});
|
||||
|
||||
let _ = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("GET")
|
||||
.uri("/health")
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let resp = app
|
||||
.oneshot(
|
||||
axum::http::Request::builder()
|
||||
.method("GET")
|
||||
.uri("/metrics")
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let rendered = String::from_utf8_lossy(&body);
|
||||
assert!(rendered.contains("gateway_http_requests_total"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_path_for_logging_does_not_include_query() {
|
||||
let req = axum::http::Request::builder()
|
||||
.method("GET")
|
||||
.uri("/v1/auth/oidc/google/callback?code=supersecret&state=x")
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
let path = request_path_for_logging(&req);
|
||||
assert_eq!(path, "/v1/auth/oidc/google/callback");
|
||||
assert!(!path.contains("supersecret"));
|
||||
}
|
||||
}
|
||||
130
gateway/src/main.rs
Normal file
130
gateway/src/main.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use gateway::observability;
|
||||
use gateway::routing;
|
||||
use gateway::storage;
|
||||
use gateway::AppState;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
observability::init_tracing();
|
||||
let metrics = observability::init_metrics();
|
||||
let authn = gateway::authn::AuthnConfig::from_env();
|
||||
|
||||
let build_version = option_env!("GATEWAY_BUILD_VERSION").unwrap_or("dev");
|
||||
let build_sha = option_env!("GATEWAY_BUILD_SHA").unwrap_or("unknown");
|
||||
tracing::info!(build_version, build_sha, "gateway starting");
|
||||
|
||||
let addr: SocketAddr = std::env::var("GATEWAY_ADDR")
|
||||
.unwrap_or_else(|_| "0.0.0.0:8080".to_string())
|
||||
.parse()?;
|
||||
|
||||
let storage_path =
|
||||
std::env::var("GATEWAY_STORAGE_PATH").unwrap_or_else(|_| "./data/gateway.mdbx".to_string());
|
||||
if let Some(parent) = std::path::Path::new(&storage_path).parent() {
|
||||
let _ = std::fs::create_dir_all(parent);
|
||||
}
|
||||
let storage = storage::GatewayStorage::open_edge_storage(storage_path, "gateway")
|
||||
.unwrap_or_else(|_| storage::GatewayStorage::new_in_memory());
|
||||
|
||||
let routing_source: Arc<dyn routing::RoutingSource> =
|
||||
if let Ok(path) = std::env::var("GATEWAY_ROUTING_FILE") {
|
||||
Arc::new(routing::StaticFileSource::new(path))
|
||||
} else if let (Ok(nats_url), Ok(bucket), Ok(key)) = (
|
||||
std::env::var("GATEWAY_ROUTING_NATS_URL"),
|
||||
std::env::var("GATEWAY_ROUTING_NATS_BUCKET"),
|
||||
std::env::var("GATEWAY_ROUTING_NATS_KEY"),
|
||||
) {
|
||||
Arc::new(routing::NatsKvSource::connect(nats_url, bucket, key).await?)
|
||||
} else {
|
||||
Arc::new(routing::FixedSource::new(routing::RoutingConfig::empty()))
|
||||
};
|
||||
|
||||
let routing = routing::RouterState::new(routing_source).await?;
|
||||
let _routing_watcher = routing.start_watcher();
|
||||
|
||||
let grpc_addr: SocketAddr = std::env::var("GATEWAY_GRPC_ADDR")
|
||||
.unwrap_or_else(|_| "0.0.0.0:8081".to_string())
|
||||
.parse()?;
|
||||
|
||||
let state = AppState {
|
||||
metrics,
|
||||
routing,
|
||||
storage,
|
||||
authn,
|
||||
};
|
||||
|
||||
let app = gateway::app(state.clone());
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
tracing::info!(%addr, "gateway listening");
|
||||
|
||||
tracing::info!(%grpc_addr, "gateway grpc listening");
|
||||
|
||||
let (shutdown_tx, _shutdown_rx) = tokio::sync::broadcast::channel::<()>(2);
|
||||
let shutdown_task = {
|
||||
let shutdown_tx = shutdown_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
shutdown_signal().await;
|
||||
let _ = shutdown_tx.send(());
|
||||
})
|
||||
};
|
||||
|
||||
let http_task = {
|
||||
let mut shutdown_rx = shutdown_tx.subscribe();
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(async move {
|
||||
let _ = shutdown_rx.recv().await;
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
})
|
||||
};
|
||||
|
||||
let grpc_task = {
|
||||
let mut shutdown_rx = shutdown_tx.subscribe();
|
||||
let svc = gateway::grpc::GatewayCommandService::new(state.routing.clone());
|
||||
tokio::spawn(async move {
|
||||
tonic::transport::Server::builder()
|
||||
.add_service(
|
||||
gateway::grpc::proto::command_service_server::CommandServiceServer::new(svc),
|
||||
)
|
||||
.serve_with_shutdown(grpc_addr, async move {
|
||||
let _ = shutdown_rx.recv().await;
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
})
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
_ = http_task => {},
|
||||
_ = grpc_task => {},
|
||||
}
|
||||
let _ = shutdown_task.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
let _ = tokio::signal::ctrl_c().await;
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||
.expect("failed to register SIGTERM handler");
|
||||
sigterm.recv().await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
_ = ctrl_c => {},
|
||||
_ = terminate => {},
|
||||
}
|
||||
}
|
||||
456
gateway/src/routing.rs
Normal file
456
gateway/src/routing.rs
Normal file
@@ -0,0 +1,456 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ServiceKind {
|
||||
Aggregate,
|
||||
Projection,
|
||||
Runner,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct RoutingConfig {
|
||||
pub revision: u64,
|
||||
|
||||
pub aggregate_placement: HashMap<String, String>,
|
||||
pub projection_placement: HashMap<String, String>,
|
||||
pub runner_placement: HashMap<String, String>,
|
||||
|
||||
pub aggregate_shards: HashMap<String, Vec<String>>,
|
||||
pub projection_shards: HashMap<String, Vec<String>>,
|
||||
pub runner_shards: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
impl RoutingConfig {
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
revision: 0,
|
||||
aggregate_placement: HashMap::new(),
|
||||
projection_placement: HashMap::new(),
|
||||
runner_placement: HashMap::new(),
|
||||
aggregate_shards: HashMap::new(),
|
||||
projection_shards: HashMap::new(),
|
||||
runner_shards: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct RoutingTable {
|
||||
pub revision: u64,
|
||||
aggregate_placement: HashMap<String, String>,
|
||||
projection_placement: HashMap<String, String>,
|
||||
runner_placement: HashMap<String, String>,
|
||||
aggregate_shards: HashMap<String, Vec<String>>,
|
||||
projection_shards: HashMap<String, Vec<String>>,
|
||||
runner_shards: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
impl From<RoutingConfig> for RoutingTable {
|
||||
fn from(value: RoutingConfig) -> Self {
|
||||
Self {
|
||||
revision: value.revision,
|
||||
aggregate_placement: value.aggregate_placement,
|
||||
projection_placement: value.projection_placement,
|
||||
runner_placement: value.runner_placement,
|
||||
aggregate_shards: value.aggregate_shards,
|
||||
projection_shards: value.projection_shards,
|
||||
runner_shards: value.runner_shards,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error, Clone, PartialEq, Eq)]
|
||||
pub enum RoutingError {
|
||||
#[error("unknown tenant")]
|
||||
UnknownTenant,
|
||||
#[error("missing shard directory entry")]
|
||||
MissingShard,
|
||||
#[error("no endpoints for shard")]
|
||||
EmptyShard,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RouterState {
|
||||
table: Arc<tokio::sync::RwLock<Arc<RoutingTable>>>,
|
||||
source: Arc<dyn RoutingSource>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for RouterState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RouterState").finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl RouterState {
|
||||
pub async fn new(source: Arc<dyn RoutingSource>) -> Result<Self, RoutingSourceError> {
|
||||
let cfg = source.load().await?;
|
||||
Ok(Self {
|
||||
table: Arc::new(tokio::sync::RwLock::new(Arc::new(cfg.into()))),
|
||||
source,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn snapshot(&self) -> Arc<RoutingTable> {
|
||||
self.table.read().await.clone()
|
||||
}
|
||||
|
||||
pub async fn reload(&self) -> Result<(), RoutingSourceError> {
|
||||
let cfg = self.source.load().await?;
|
||||
let next = Arc::new(RoutingTable::from(cfg));
|
||||
*self.table.write().await = next;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn start_watcher(&self) -> tokio::task::JoinHandle<()> {
|
||||
let this = self.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut stream = match this.source.watch().await {
|
||||
Ok(s) => s,
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
while let Some(msg) = stream.next().await {
|
||||
if msg.is_err() {
|
||||
continue;
|
||||
}
|
||||
let _ = this.reload().await;
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn resolve(
|
||||
&self,
|
||||
tenant_id: &str,
|
||||
kind: ServiceKind,
|
||||
) -> Result<String, RoutingError> {
|
||||
let table = self.snapshot().await;
|
||||
let result = table.resolve(tenant_id, kind);
|
||||
metrics::counter!(
|
||||
"gateway_routing_resolutions_total",
|
||||
"kind" => kind_label(kind),
|
||||
"result" => if result.is_ok() { "ok" } else { "err" }
|
||||
)
|
||||
.increment(1);
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
fn kind_label(kind: ServiceKind) -> &'static str {
|
||||
match kind {
|
||||
ServiceKind::Aggregate => "aggregate",
|
||||
ServiceKind::Projection => "projection",
|
||||
ServiceKind::Runner => "runner",
|
||||
}
|
||||
}
|
||||
|
||||
impl RoutingTable {
|
||||
pub fn resolve(&self, tenant_id: &str, kind: ServiceKind) -> Result<String, RoutingError> {
|
||||
let shard_id = match kind {
|
||||
ServiceKind::Aggregate => self.aggregate_placement.get(tenant_id),
|
||||
ServiceKind::Projection => self.projection_placement.get(tenant_id),
|
||||
ServiceKind::Runner => self.runner_placement.get(tenant_id),
|
||||
}
|
||||
.ok_or(RoutingError::UnknownTenant)?;
|
||||
|
||||
let endpoints = match kind {
|
||||
ServiceKind::Aggregate => self.aggregate_shards.get(shard_id),
|
||||
ServiceKind::Projection => self.projection_shards.get(shard_id),
|
||||
ServiceKind::Runner => self.runner_shards.get(shard_id),
|
||||
}
|
||||
.ok_or(RoutingError::MissingShard)?;
|
||||
|
||||
endpoints.first().cloned().ok_or(RoutingError::EmptyShard)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RoutingSourceError {
|
||||
#[error("source error: {0}")]
|
||||
Source(String),
|
||||
#[error("decode error: {0}")]
|
||||
Decode(String),
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait RoutingSource: Send + Sync {
|
||||
async fn load(&self) -> Result<RoutingConfig, RoutingSourceError>;
|
||||
async fn watch(
|
||||
&self,
|
||||
) -> Result<
|
||||
std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), RoutingSourceError>> + Send>>,
|
||||
RoutingSourceError,
|
||||
>;
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FixedSource {
|
||||
cfg: RoutingConfig,
|
||||
}
|
||||
|
||||
impl FixedSource {
|
||||
pub fn new(cfg: RoutingConfig) -> Self {
|
||||
Self { cfg }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RoutingSource for FixedSource {
|
||||
async fn load(&self) -> Result<RoutingConfig, RoutingSourceError> {
|
||||
Ok(self.cfg.clone())
|
||||
}
|
||||
|
||||
async fn watch(
|
||||
&self,
|
||||
) -> Result<
|
||||
std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), RoutingSourceError>> + Send>>,
|
||||
RoutingSourceError,
|
||||
> {
|
||||
Ok(Box::pin(futures::stream::empty()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StaticFileSource {
|
||||
path: String,
|
||||
}
|
||||
|
||||
impl StaticFileSource {
|
||||
pub fn new(path: impl Into<String>) -> Self {
|
||||
Self { path: path.into() }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RoutingSource for StaticFileSource {
|
||||
async fn load(&self) -> Result<RoutingConfig, RoutingSourceError> {
|
||||
let raw = tokio::fs::read_to_string(&self.path)
|
||||
.await
|
||||
.map_err(|e| RoutingSourceError::Source(e.to_string()))?;
|
||||
|
||||
if self.path.ends_with(".json") {
|
||||
serde_json::from_str::<RoutingConfig>(&raw)
|
||||
.map_err(|e| RoutingSourceError::Decode(e.to_string()))
|
||||
} else {
|
||||
let yaml: serde_yaml::Value = serde_yaml::from_str(&raw)
|
||||
.map_err(|e| RoutingSourceError::Decode(e.to_string()))?;
|
||||
let json = serde_json::to_value(yaml)
|
||||
.map_err(|e| RoutingSourceError::Decode(e.to_string()))?;
|
||||
serde_json::from_value::<RoutingConfig>(json)
|
||||
.map_err(|e| RoutingSourceError::Decode(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn watch(
|
||||
&self,
|
||||
) -> Result<
|
||||
std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), RoutingSourceError>> + Send>>,
|
||||
RoutingSourceError,
|
||||
> {
|
||||
Ok(Box::pin(futures::stream::empty()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct NatsKvSource {
|
||||
kv: async_nats::jetstream::kv::Store,
|
||||
key: String,
|
||||
}
|
||||
|
||||
impl NatsKvSource {
|
||||
pub async fn connect(
|
||||
nats_url: impl Into<String>,
|
||||
bucket: impl Into<String>,
|
||||
key: impl Into<String>,
|
||||
) -> Result<Self, RoutingSourceError> {
|
||||
let nats_url = nats_url.into();
|
||||
let bucket = bucket.into();
|
||||
let key = key.into();
|
||||
|
||||
let client = async_nats::connect(nats_url)
|
||||
.await
|
||||
.map_err(|e| RoutingSourceError::Source(e.to_string()))?;
|
||||
let jetstream = async_nats::jetstream::new(client);
|
||||
|
||||
let kv = match jetstream.get_key_value(&bucket).await {
|
||||
Ok(kv) => kv,
|
||||
Err(_) => jetstream
|
||||
.create_key_value(async_nats::jetstream::kv::Config {
|
||||
bucket: bucket.clone(),
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.map_err(|e| RoutingSourceError::Source(e.to_string()))?,
|
||||
};
|
||||
|
||||
Ok(Self { kv, key })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RoutingSource for NatsKvSource {
|
||||
async fn load(&self) -> Result<RoutingConfig, RoutingSourceError> {
|
||||
let entry = self
|
||||
.kv
|
||||
.entry(&self.key)
|
||||
.await
|
||||
.map_err(|e| RoutingSourceError::Source(e.to_string()))?;
|
||||
|
||||
let Some(entry) = entry else {
|
||||
return Ok(RoutingConfig::empty());
|
||||
};
|
||||
|
||||
serde_json::from_slice::<RoutingConfig>(&entry.value)
|
||||
.map_err(|e| RoutingSourceError::Decode(e.to_string()))
|
||||
}
|
||||
|
||||
async fn watch(
|
||||
&self,
|
||||
) -> Result<
|
||||
std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), RoutingSourceError>> + Send>>,
|
||||
RoutingSourceError,
|
||||
> {
|
||||
let key = self.key.clone();
|
||||
let watch = self
|
||||
.kv
|
||||
.watch(&key)
|
||||
.await
|
||||
.map_err(|e| RoutingSourceError::Source(e.to_string()))?;
|
||||
|
||||
Ok(Box::pin(watch.filter_map(|entry| async move {
|
||||
match entry {
|
||||
Ok(entry) => match entry.operation {
|
||||
async_nats::jetstream::kv::Operation::Put => Some(Ok(())),
|
||||
async_nats::jetstream::kv::Operation::Delete
|
||||
| async_nats::jetstream::kv::Operation::Purge => None,
|
||||
},
|
||||
Err(e) => Some(Err(RoutingSourceError::Source(e.to_string()))),
|
||||
}
|
||||
})))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn assert_send_sync<T: Send + Sync>() {}
|
||||
|
||||
#[test]
|
||||
fn router_state_is_send_sync() {
|
||||
assert_send_sync::<RouterState>();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn resolves_endpoints_for_tenant_service_kind() {
|
||||
let cfg = RoutingConfig {
|
||||
revision: 1,
|
||||
aggregate_placement: HashMap::from([("t1".to_string(), "a".to_string())]),
|
||||
projection_placement: HashMap::from([("t1".to_string(), "p".to_string())]),
|
||||
runner_placement: HashMap::from([("t1".to_string(), "r".to_string())]),
|
||||
aggregate_shards: HashMap::from([("a".to_string(), vec!["http://a".to_string()])]),
|
||||
projection_shards: HashMap::from([("p".to_string(), vec!["http://p".to_string()])]),
|
||||
runner_shards: HashMap::from([("r".to_string(), vec!["http://r".to_string()])]),
|
||||
};
|
||||
|
||||
let source: Arc<dyn RoutingSource> = Arc::new(TestSource::new(cfg));
|
||||
let router = RouterState::new(source).await.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
router.resolve("t1", ServiceKind::Aggregate).await.unwrap(),
|
||||
"http://a"
|
||||
);
|
||||
assert_eq!(
|
||||
router.resolve("t1", ServiceKind::Projection).await.unwrap(),
|
||||
"http://p"
|
||||
);
|
||||
assert_eq!(
|
||||
router.resolve("t1", ServiceKind::Runner).await.unwrap(),
|
||||
"http://r"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unknown_tenant_is_typed_error() {
|
||||
let source: Arc<dyn RoutingSource> = Arc::new(TestSource::new(RoutingConfig::empty()));
|
||||
let router = RouterState::new(source).await.unwrap();
|
||||
let err = router
|
||||
.resolve("missing", ServiceKind::Aggregate)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err, RoutingError::UnknownTenant);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn hot_reload_swaps_table_atomically() {
|
||||
let cfg1 = RoutingConfig {
|
||||
revision: 1,
|
||||
aggregate_placement: HashMap::from([("t1".to_string(), "a".to_string())]),
|
||||
projection_placement: HashMap::new(),
|
||||
runner_placement: HashMap::new(),
|
||||
aggregate_shards: HashMap::from([("a".to_string(), vec!["http://a1".to_string()])]),
|
||||
projection_shards: HashMap::new(),
|
||||
runner_shards: HashMap::new(),
|
||||
};
|
||||
let cfg2 = RoutingConfig {
|
||||
revision: 2,
|
||||
aggregate_placement: HashMap::from([("t1".to_string(), "a".to_string())]),
|
||||
projection_placement: HashMap::new(),
|
||||
runner_placement: HashMap::new(),
|
||||
aggregate_shards: HashMap::from([("a".to_string(), vec!["http://a2".to_string()])]),
|
||||
projection_shards: HashMap::new(),
|
||||
runner_shards: HashMap::new(),
|
||||
};
|
||||
|
||||
let test_source = Arc::new(TestSource::new(cfg1));
|
||||
let router = RouterState::new(test_source.clone()).await.unwrap();
|
||||
|
||||
let before = router.resolve("t1", ServiceKind::Aggregate).await.unwrap();
|
||||
assert_eq!(before, "http://a1");
|
||||
|
||||
test_source.set(cfg2).await;
|
||||
router.reload().await.unwrap();
|
||||
|
||||
let after = router.resolve("t1", ServiceKind::Aggregate).await.unwrap();
|
||||
assert_eq!(after, "http://a2");
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TestSource {
|
||||
cfg: Arc<tokio::sync::RwLock<RoutingConfig>>,
|
||||
}
|
||||
|
||||
impl TestSource {
|
||||
fn new(cfg: RoutingConfig) -> Self {
|
||||
Self {
|
||||
cfg: Arc::new(tokio::sync::RwLock::new(cfg)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn set(&self, cfg: RoutingConfig) {
|
||||
*self.cfg.write().await = cfg;
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RoutingSource for TestSource {
|
||||
async fn load(&self) -> Result<RoutingConfig, RoutingSourceError> {
|
||||
Ok(self.cfg.read().await.clone())
|
||||
}
|
||||
|
||||
async fn watch(
|
||||
&self,
|
||||
) -> Result<
|
||||
std::pin::Pin<Box<dyn futures::Stream<Item = Result<(), RoutingSourceError>> + Send>>,
|
||||
RoutingSourceError,
|
||||
> {
|
||||
Ok(Box::pin(futures::stream::empty()))
|
||||
}
|
||||
}
|
||||
}
|
||||
1015
gateway/src/storage.rs
Normal file
1015
gateway/src/storage.rs
Normal file
File diff suppressed because it is too large
Load Diff
99
gateway/src/upstream.rs
Normal file
99
gateway/src/upstream.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
|
||||
pub fn http_client() -> &'static reqwest::Client {
|
||||
static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
|
||||
CLIENT.get_or_init(|| {
|
||||
let mut builder = reqwest::Client::builder().timeout(Duration::from_secs(10));
|
||||
|
||||
if let Some(ca_pem) = env_or_file(
|
||||
"GATEWAY_INTERNAL_CA_CERT_PEM",
|
||||
"GATEWAY_INTERNAL_CA_CERT_PEM_FILE",
|
||||
) {
|
||||
if let Ok(cert) = reqwest::Certificate::from_pem(ca_pem.as_bytes()) {
|
||||
builder = builder.add_root_certificate(cert);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(identity_pem) = env_or_file(
|
||||
"GATEWAY_INTERNAL_IDENTITY_PEM",
|
||||
"GATEWAY_INTERNAL_IDENTITY_PEM_FILE",
|
||||
) {
|
||||
if let Ok(identity) = reqwest::Identity::from_pem(identity_pem.as_bytes()) {
|
||||
builder = builder.identity(identity);
|
||||
}
|
||||
}
|
||||
|
||||
builder.build().expect("failed to build reqwest client")
|
||||
})
|
||||
}
|
||||
|
||||
pub fn grpc_endpoint(url: &str) -> Result<tonic::transport::Endpoint, tonic::transport::Error> {
|
||||
let mut endpoint =
|
||||
tonic::transport::Endpoint::from_shared(url.to_string())?.timeout(Duration::from_secs(10));
|
||||
|
||||
let wants_tls = url.starts_with("https://")
|
||||
|| std::env::var("GATEWAY_INTERNAL_GRPC_TLS")
|
||||
.ok()
|
||||
.map(|v| matches!(v.trim().to_ascii_lowercase().as_str(), "1" | "true" | "yes"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if wants_tls {
|
||||
if let Some(tls) = grpc_tls_config() {
|
||||
endpoint = endpoint.tls_config(tls)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
fn grpc_tls_config() -> Option<tonic::transport::ClientTlsConfig> {
|
||||
let mut tls = tonic::transport::ClientTlsConfig::new();
|
||||
let mut configured = false;
|
||||
|
||||
if let Some(ca_pem) = env_or_file(
|
||||
"GATEWAY_INTERNAL_GRPC_CA_CERT_PEM",
|
||||
"GATEWAY_INTERNAL_GRPC_CA_CERT_PEM_FILE",
|
||||
) {
|
||||
tls = tls.ca_certificate(tonic::transport::Certificate::from_pem(ca_pem));
|
||||
configured = true;
|
||||
}
|
||||
|
||||
let cert_pem = env_or_file(
|
||||
"GATEWAY_INTERNAL_GRPC_CLIENT_CERT_PEM",
|
||||
"GATEWAY_INTERNAL_GRPC_CLIENT_CERT_PEM_FILE",
|
||||
);
|
||||
let key_pem = env_or_file(
|
||||
"GATEWAY_INTERNAL_GRPC_CLIENT_KEY_PEM",
|
||||
"GATEWAY_INTERNAL_GRPC_CLIENT_KEY_PEM_FILE",
|
||||
);
|
||||
if let (Some(cert_pem), Some(key_pem)) = (cert_pem, key_pem) {
|
||||
tls = tls.identity(tonic::transport::Identity::from_pem(cert_pem, key_pem));
|
||||
configured = true;
|
||||
}
|
||||
|
||||
if configured {
|
||||
Some(tls)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn env_or_file(env_key: &str, file_env_key: &str) -> Option<String> {
|
||||
if let Ok(path) = std::env::var(file_env_key) {
|
||||
if let Ok(raw) = std::fs::read_to_string(path) {
|
||||
let trimmed = raw.trim().to_string();
|
||||
if !trimmed.is_empty() {
|
||||
return Some(trimmed);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::env::var(env_key).ok().and_then(|v| {
|
||||
let trimmed = v.trim().to_string();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user