diff --git a/Cargo.lock b/Cargo.lock index e872ba8..d40fdf6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -159,6 +159,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "shared", "thiserror 2.0.18", "tokio", "tower 0.5.3", @@ -3385,6 +3386,8 @@ dependencies = [ "edge_storage", "futures", "libmdbx", + "prost 0.13.5", + "protoc-bin-vendored", "query_engine", "runtime-function", "serde", @@ -3395,6 +3398,8 @@ dependencies = [ "thiserror 2.0.18", "tokio", "toml", + "tonic", + "tonic-build", "tower 0.5.3", "tracing", "tracing-subscriber", diff --git a/TRANSPORT_DEVELOPMENT_PLAN.md b/TRANSPORT_DEVELOPMENT_PLAN.md index b44c857..29cc495 100644 --- a/TRANSPORT_DEVELOPMENT_PLAN.md +++ b/TRANSPORT_DEVELOPMENT_PLAN.md @@ -11,15 +11,17 @@ This plan merges and supersedes: ## Current Status (Codebase Reality) - Monorepo workspace exists; `shared` crate exists and is used by Aggregate/Projection/Runner/Gateway. -- Request context pieces are partially standardized: +- Request context pieces are standardized: - `shared` provides `TenantId`, `CorrelationId`, `TraceId` - `shared` provides `trace_id_from_traceparent(...)` and `traceparent_from_trace_id(...)` - - Some header names are centralized in `shared` but not all call sites use constants yet. + - `shared` provides canonical header constants (HTTP + NATS) and trace/correlation normalization helpers + - Most call sites now use `shared` constants/helpers; remaining gaps should be treated as Milestone-gated - Gateway → Aggregate is already HTTP(edge) → gRPC(internal) and propagates `x-tenant-id`, `x-correlation-id`, and `traceparent`. - Gateway → Projection remains HTTP proxy (`/v1/query/...`) and Gateway → Runner remains HTTP admin proxy (`/admin/runner/...`). - Node → NATS header propagation is improved and closer to consistent: - - Runner publishes `x-correlation-id` and `correlation-id`, and ensures `traceparent`/`trace-id` are present/derived when possible. 
- - Aggregate publishes `trace-id` when `traceparent` is present. + - Runner publishes required headers for effect commands/results (`tenant-id`, `Nats-Msg-Id`, correlation, traceparent/trace-id), generating when missing. + - Aggregate publishes required headers for events (`tenant-id`, `Nats-Msg-Id`, correlation, traceparent/trace-id), generating when missing. + - Projection hydrates correlation/trace context from NATS headers when the JSON envelope omits them. - Many “hard” NATS tests already exist but are gated/ignored by default; they should be treated as milestone gates when enabling changes. ## Principles @@ -84,19 +86,20 @@ Make propagation rules consistent and enforceable across HTTP, gRPC, and NATS so - [x] `TenantId` - [x] `CorrelationId` - [x] `TraceId` -- [~] Consolidate header constants in `shared`: +- [x] Consolidate header constants in `shared`: - [x] HTTP: `x-correlation-id`, `traceparent`, `trace-id` (for NATS/interop) - - [ ] HTTP: `x-tenant-id`, `x-request-id` (missing constants) + - [x] HTTP: `x-tenant-id`, `x-request-id` - [x] NATS: `correlation-id` (used in Runner), `trace-id` (now emitted where possible) - - [ ] NATS: `tenant-id` constant, `Nats-Msg-Id` constant (missing constants) + - [x] NATS: `tenant-id`, `Nats-Msg-Id` - [x] Add shared helpers: - [x] derive `trace-id` from `traceparent` - [x] derive `traceparent` from `trace-id` when valid - - [ ] normalize/generate correlation id when missing across all transports (helper exists for `CorrelationId::generate()`; adoption incomplete) + - [x] normalize/generate correlation id when missing (`normalize_correlation_id(...)`) + - [x] normalize/generate traceparent when missing/invalid (`normalize_traceparent(...)`) - [x] Add unit tests in `shared` for: - [x] traceparent parsing validity - [x] serialization shape for correlation/trace id newtypes - - [ ] additional validation cases (invalid traceparents, invalid trace-id lengths) if needed for stricter enforcement + - [x] additional validation 
cases (invalid traceparents, all-zero ids) ### Required Tests - `cargo fmt --check` @@ -118,24 +121,26 @@ Make the JetStream/NATS “wire protocol” explicit and uniform so interop is s - “Contract tests” exist per service to verify produced headers and subject formats. ### Tasks -- [ ] Create/standardize subject builder helpers (prefer `shared`): - - [ ] Aggregate event subject builder (`tenant..aggregate..`) - - [ ] Runner effect/effect_result/workflow subject builders -- [~] Aggregate publishes: - - [ ] `tenant-id` header always present (still needs enforcement everywhere) - - [ ] correlation + trace headers always present when available, generated when required +- [x] Create/standardize subject builder helpers (prefer `shared`): + - [x] Aggregate event subject builder (`tenant..aggregate..`) + - [x] Runner effect/effect_result subject builders + - [x] Runner workflow/workflow_event subject builders (helpers exist; concrete publishers/consumers are future work) +- [x] Aggregate publishes: + - [x] `tenant-id` header always present + - [x] correlation + trace headers always present; generated when missing/invalid - [x] `trace-id` is derived when `traceparent` is present (now emitted in publish path) - - [ ] `Nats-Msg-Id` strategy explicitly defined and tested -- [~] Runner publishes (commands/results): - - [x] correlation headers emitted consistently (`x-correlation-id` + `correlation-id`) - - [x] trace headers derived consistently when possible (`traceparent` from `trace-id`, `trace-id` from `traceparent`) - - [ ] outbox metadata → NATS headers mapping standardized via shared helpers (adoption incomplete) -- [~] Projection consumption: + - [x] `Nats-Msg-Id` strategy explicitly defined and tested (Aggregate events use `event_id`) +- [x] Runner publishes (commands/results): + - [x] correlation headers emitted consistently (`x-correlation-id` + `correlation-id`) and generated when missing + - [x] trace headers always present/derived when possible; generated when 
missing/invalid + - [x] `Nats-Msg-Id` strategy explicitly defined and tested (Runner commands/results use `command_id`) + - [x] outbox metadata → NATS headers mapping standardized via shared helpers +- [x] Projection consumption: - [x] envelope decoding remains tolerant (unknown fields ignored) - - [~] correlation/trace context flows into spans/metrics consistently (types are shared; header extraction remains best-effort and should be unified) -- [ ] Add unit tests: - - [ ] subject formatting tests per service (once builders exist) - - [ ] required header presence tests per publisher (enforce required keys) + - [x] correlation/trace context flows into spans/metrics consistently (envelope + NATS header fallback) +- [x] Add unit tests: + - [x] subject formatting tests (shared builders) + - [x] required header presence tests per publisher (Aggregate + Runner) ### Required Tests - Workspace verification commands @@ -155,15 +160,16 @@ Make stream definitions explicit, validated, and safe in all environments, preve - Config-only tests validate stream config builders without requiring NATS. ### Tasks -- [ ] Define stream policies: - - [ ] `AGGREGATE_EVENTS` (subjects, retention, duplicate window) - - [ ] `WORKFLOW_COMMANDS` - - [ ] `WORKFLOW_EVENTS` -- [ ] Implement compatibility validation rules: - - [ ] required subjects are present (superset allowed) - - [ ] retention/limits are within allowed ranges - - [ ] dedupe assumptions align with producer `Nats-Msg-Id` usage -- [ ] Add unit tests for stream config builders + validators. 
+- [x] Define stream policies: + - [x] `AGGREGATE_EVENTS` (subjects, limits, duplicate window) is defined and validated on startup + - [x] `WORKFLOW_COMMANDS` is defined and validated on startup + - [x] `WORKFLOW_EVENTS` is defined and validated on startup + - [x] Centralize stream policy builders/validators in `shared` +- [x] Implement compatibility validation rules: + - [x] required subjects are present (superset allowed) + - [x] limits/max_age/duplicate window validated against minimums + - [x] dedupe assumptions align with producer `Nats-Msg-Id` usage (duplicate window + msg-id strategy) +- [x] Add unit tests for stream config builders + validators. ### Required Tests - Workspace verification commands @@ -186,25 +192,25 @@ Standardize consumer configs and runtime behavior to guarantee bounded in-flight - scale-out behavior (deliver group) where applicable ### Tasks -- [ ] Standardize consumer defaults: - - [ ] `AckPolicy::Explicit` - - [ ] `ack_wait` default + env override - - [ ] `max_deliver` default + env override - - [ ] `max_ack_pending` tied to worker concurrency -- [ ] Projection: - - [ ] durable naming collision-free for Single/PerView modes - - [ ] checkpoint gate semantics: “skip still acks” - - [ ] poison handling persists durable records and terminates reliably -- [ ] Runner: - - [ ] durable naming collision-free and stable across replicas - - [ ] deliver group rules defined and tested - - [ ] outbox relay exactly-once behavior verified under redelivery -- [ ] Aggregate: - - [ ] ad-hoc fetch consumer always unique and bounded - - [ ] best-effort deletion never targets unrelated consumers -- [ ] Add gated NATS integration tests and document env flags: - - [ ] Runner ignored tests - - [ ] Projection ignored tests +- [x] Standardize consumer defaults: + - [x] `AckPolicy::Explicit` + - [x] `ack_wait` default + env override (Runner/Projection: `*_ACK_TIMEOUT_MS`) + - [x] `max_deliver` default + env override (Runner/Projection: `*_MAX_DELIVER`) + - [x] 
`max_ack_pending` tied to worker concurrency (Runner/Projection: `max_in_flight`) +- [x] Projection: + - [x] durable naming collision-free for Single/PerView modes + - [x] checkpoint gate semantics: “skip still acks” + - [x] poison handling persists durable records and terminates reliably (poison record + term) +- [x] Runner: + - [x] durable naming collision-free and stable across replicas + - [x] deliver group rules defined (pull consumers; `deliver_group` is rejected if configured) + - [x] outbox relay exactly-once behavior verified under redelivery (unit tests exist; gated NATS e2e tests remain ignored-by-default) +- [x] Aggregate: + - [x] ad-hoc fetch consumer always unique and bounded + - [x] best-effort deletion never targets unrelated consumers +- [x] Add gated NATS integration tests and document env flags: + - [x] Runner ignored tests + - [x] Projection ignored tests ### Required Tests - Workspace verification commands @@ -227,17 +233,17 @@ Replace Gateway → Projection HTTP proxy as the default path with a gRPC Query - New gRPC query tests pass (unit + integration). 
### Tasks -- [ ] Define protobuf API: `projection.gateway.v1.QueryService` -- [ ] Implement Projection gRPC server for query execution -- [ ] Implement Gateway gRPC client routing to Projection - - [ ] deadlines - - [ ] bounded retries (idempotent only) - - [ ] context propagation -- [ ] Preserve HTTP `/v1/query/*` as compatibility/debug: - - [ ] route internally to gRPC or keep as legacy endpoint -- [ ] Add tests: - - [ ] authz + forwarding via gRPC - - [ ] tenant isolation enforcement in Projection QueryService +- [x] Define protobuf API: `projection.gateway.v1.QueryService` +- [x] Implement Projection gRPC server for query execution +- [x] Implement Gateway gRPC client routing to Projection + - [x] deadlines + - [x] bounded retries (idempotent only) + - [x] context propagation +- [x] Preserve HTTP `/v1/query/*` as compatibility/debug: + - [x] route internally to gRPC +- [x] Add tests: + - [x] authz + forwarding via gRPC + - [x] tenant isolation enforcement in Projection QueryService ### Required Tests - Workspace verification commands @@ -257,14 +263,14 @@ Replace Gateway’s `/admin/runner/*` HTTP proxy usage with a first-class gRPC a - Runner drain/readiness semantics validated and tested. 
### Tasks -- [ ] Define protobuf API: `runner.admin.v1.RunnerAdmin` -- [ ] Implement Runner gRPC admin server -- [ ] Implement Gateway gRPC client integration for admin operations -- [ ] Keep Runner HTTP admin endpoints optional for direct debugging, not required by Gateway -- [ ] Add tests: - - [ ] Gateway: rejects without rights - - [ ] Gateway: rejects tenant spoof attempts - - [ ] Runner: idempotency and drain semantics +- [x] Define protobuf API: `runner.admin.v1.RunnerAdmin` +- [x] Implement Runner gRPC admin server +- [x] Implement Gateway gRPC client integration for admin operations +- [x] Keep Runner HTTP admin endpoints optional for direct debugging, not required by Gateway +- [x] Add tests: + - [x] Gateway: rejects without rights + - [x] Gateway: rejects tenant spoof attempts + - [x] Runner: idempotency and drain semantics ### Required Tests - Workspace verification commands @@ -284,20 +290,20 @@ Make Gateway upstream connection handling, retry behavior, and probe/fanout oper - Gated load/soak tests exist and are runnable. ### Tasks -- [ ] Implement upstream channel pool - - [ ] bounded LRU - - [ ] TTL/eviction - - [ ] fast-path reuse under load -- [ ] Standardize retry profiles - - [ ] read-only: limited retry with jitter - - [ ] mutations: no retry unless idempotency key is present and semantics are safe -- [ ] Standardize timeouts/deadlines: - - [ ] edge timeout limits - - [ ] internal per-service deadlines -- [ ] Fanout controls: - - [ ] concurrency limiters for probes/snapshots - - [ ] short TTL caching where safe -- [ ] Ensure probes carry context (correlation/trace) for observability. 
+- [x] Implement upstream channel pool + - [x] bounded LRU + - [x] TTL/eviction + - [x] fast-path reuse under load (cached gRPC channels) +- [x] Standardize retry profiles + - [x] read-only: limited retry with jitter (Gateway gRPC calls) + - [x] mutations: no retry unless idempotency key is present and semantics are safe (Gateway does not retry mutations) +- [x] Standardize timeouts/deadlines: + - [x] edge timeout limits + - [x] internal per-service deadlines +- [x] Fanout controls: + - [x] concurrency limiters for probes/snapshots + - [x] short TTL caching where safe +- [x] Ensure probes carry context (correlation/trace) for observability. ### Required Tests - Workspace verification commands @@ -317,10 +323,10 @@ Ensure the “happy path” is: HTTP edge → Gateway → gRPC internal → NATS - End-to-end smoke tests pass (gated). ### Tasks -- [ ] Remove Gateway HTTP query proxy usage (or keep only as explicit compatibility shim) -- [ ] Remove Gateway runner admin HTTP proxy usage (or keep only as explicit compatibility shim) -- [ ] Ensure Control UI + Control API rely only on standardized surfaces -- [ ] Harden metrics and readiness probes to match the standard contract everywhere +- [x] Remove Gateway HTTP query proxy usage (kept HTTP edge; Gateway routes internally to Projection gRPC) +- [x] Remove Gateway runner admin HTTP proxy usage (kept HTTP edge; Gateway routes internally to RunnerAdmin gRPC) +- [x] Ensure Control UI + Control API rely only on standardized surfaces +- [x] Harden metrics and readiness probes to match the standard contract everywhere ### Required Tests - Workspace verification commands diff --git a/aggregate/src/gateway/mod.rs b/aggregate/src/gateway/mod.rs index 4dbd515..e787e70 100644 --- a/aggregate/src/gateway/mod.rs +++ b/aggregate/src/gateway/mod.rs @@ -1,4 +1,4 @@ -pub const TENANT_ID_METADATA_KEY: &str = "x-tenant-id"; +pub const TENANT_ID_METADATA_KEY: &str = shared::HEADER_X_TENANT_ID; pub mod proto { 
tonic::include_proto!("aggregate.gateway.v1"); diff --git a/aggregate/src/gateway/server.rs b/aggregate/src/gateway/server.rs index 97e143c..dd92717 100644 --- a/aggregate/src/gateway/server.rs +++ b/aggregate/src/gateway/server.rs @@ -48,14 +48,14 @@ impl CommandService for GrpcCommandServer { ) -> Result, Status> { let correlation_id = request .metadata() - .get("x-correlation-id") + .get(shared::HEADER_X_CORRELATION_ID) .and_then(|v| v.to_str().ok()) .map(|s| s.trim()) .filter(|s| !s.is_empty()) .map(|s| s.to_string()); let traceparent = request .metadata() - .get("traceparent") + .get(shared::HEADER_TRACEPARENT) .and_then(|v| v.to_str().ok()) .map(|s| s.trim()) .filter(|s| !s.is_empty()) @@ -172,12 +172,16 @@ impl CommandService for GrpcCommandServer { }); if let Some(correlation_id) = correlation_id.as_deref() { if let Ok(v) = tonic::metadata::MetadataValue::try_from(correlation_id) { - response.metadata_mut().insert("x-correlation-id", v); + response + .metadata_mut() + .insert(shared::HEADER_X_CORRELATION_ID, v); } } if let Some(traceparent) = traceparent.as_deref() { if let Ok(v) = tonic::metadata::MetadataValue::try_from(traceparent) { - response.metadata_mut().insert("traceparent", v); + response + .metadata_mut() + .insert(shared::HEADER_TRACEPARENT, v); } } Ok(response) diff --git a/aggregate/src/server/mod.rs b/aggregate/src/server/mod.rs index 3f5039c..8933273 100644 --- a/aggregate/src/server/mod.rs +++ b/aggregate/src/server/mod.rs @@ -54,7 +54,7 @@ impl CommandRequest { ); if let Some(correlation_id) = self .headers - .get("x-correlation-id") + .get(shared::HEADER_X_CORRELATION_ID) .map(|s| s.trim()) .filter(|s| !s.is_empty()) { @@ -65,7 +65,7 @@ impl CommandRequest { } if let Some(traceparent) = self .headers - .get("traceparent") + .get(shared::HEADER_TRACEPARENT) .map(|s| s.trim()) .filter(|s| !s.is_empty()) { @@ -124,7 +124,7 @@ impl CommandServer { pub fn extract_tenant_id(&self, headers: &HashMap) -> TenantId { headers - .get("x-tenant-id") + 
.get(shared::HEADER_X_TENANT_ID) .map(TenantId::new) .unwrap_or_default() } @@ -163,13 +163,13 @@ impl CommandServer { let correlation_id = request .headers - .get("x-correlation-id") + .get(shared::HEADER_X_CORRELATION_ID) .map(|s| s.trim()) .filter(|s| !s.is_empty()) .map(|s| s.to_string()); let trace_id = request .headers - .get("traceparent") + .get(shared::HEADER_TRACEPARENT) .map(|s| s.trim()) .filter(|s| !s.is_empty()) .and_then(trace_id_from_traceparent); diff --git a/aggregate/src/stream/mod.rs b/aggregate/src/stream/mod.rs index d0e7559..b9037c3 100644 --- a/aggregate/src/stream/mod.rs +++ b/aggregate/src/stream/mod.rs @@ -18,6 +18,12 @@ use tokio::sync::RwLock; use tokio::time::Instant; const AGGREGATE_STREAM_NAME: &str = "AGGREGATE_EVENTS"; +const FETCH_CONSUMER_MAX_ACK_PENDING: i64 = 256; +const FETCH_CONSUMER_MAX_DELIVER: i64 = 1; +const FETCH_CONSUMER_ACK_WAIT: Duration = Duration::from_secs(3); +const SUBSCRIBE_CONSUMER_MAX_ACK_PENDING: i64 = 256; +const SUBSCRIBE_CONSUMER_MAX_DELIVER: i64 = 10; +const SUBSCRIBE_CONSUMER_ACK_WAIT: Duration = Duration::from_secs(30); #[derive(Debug)] pub struct StreamConfigSettings { @@ -107,21 +113,18 @@ impl StreamClient { } }; - let config = StreamConfig { - name: AGGREGATE_STREAM_NAME.to_string(), - subjects: vec!["tenant.*.aggregate.*.*".to_string()], - max_messages: settings.max_messages, - max_bytes: settings.max_bytes, - max_age: settings.max_age, - duplicate_window: settings.duplicate_window, - ..Default::default() - }; + let expected = stream_policy_config(settings); - let stream = jetstream - .get_or_create_stream(config) + let mut stream = jetstream + .get_or_create_stream(expected.clone()) .await .map_err(|e| AggregateError::StreamError(format!("Failed to create stream: {}", e)))?; + let info = stream.info().await.map_err(|e| { + AggregateError::StreamError(format!("Failed to load stream info: {}", e)) + })?; + validate_stream_config(&expected, &info.config)?; + Ok(stream) } @@ -139,28 +142,16 @@ impl 
StreamClient { match &self.backend { StreamBackend::JetStream(jetstream) => { for event in &events { - let subject = - build_subject(&event.tenant_id, &event.aggregate_type, &event.aggregate_id); + let subject = shared::nats_subject_aggregate_event( + event.tenant_id.as_str(), + event.aggregate_type.as_str(), + &event.aggregate_id.to_string(), + ); let payload = serde_json::to_vec(event).map_err(|e| { AggregateError::StreamError(format!("Serialization error: {}", e)) })?; - let mut headers = async_nats::HeaderMap::new(); - headers.insert("Nats-Msg-Id", event.event_id.to_string().as_str()); - headers.insert("aggregate-version", event.version.to_string().as_str()); - headers.insert("tenant-id", event.tenant_id.as_str()); - headers.insert("aggregate-type", event.aggregate_type.as_str()); - headers.insert("event-type", event.event_type.as_str()); - if let Some(correlation_id) = event.correlation_id.as_deref() { - headers.insert("x-correlation-id", correlation_id); - headers.insert("correlation-id", correlation_id); - } - if let Some(traceparent) = event.traceparent.as_deref() { - headers.insert("traceparent", traceparent); - if let Some(trace_id) = shared::trace_id_from_traceparent(traceparent) { - headers.insert("trace-id", trace_id); - } - } + let headers = build_event_headers(event); let result = jetstream .publish_with_headers(subject.clone(), headers.clone(), payload.into()) @@ -248,6 +239,9 @@ impl StreamClient { filter_subject: subject.clone(), deliver_policy: DeliverPolicy::All, ack_policy: AckPolicy::Explicit, + ack_wait: FETCH_CONSUMER_ACK_WAIT, + max_ack_pending: FETCH_CONSUMER_MAX_ACK_PENDING, + max_deliver: FETCH_CONSUMER_MAX_DELIVER, replay_policy: ReplayPolicy::Instant, ..Default::default() }; @@ -348,8 +342,14 @@ impl StreamClient { let consumer_name = format!("sub_{}_{}", tenant_id.as_str(), aggregate_id); let consumer_config = PullConfig { + durable_name: Some(consumer_name.clone()), filter_subject: subject, deliver_policy: DeliverPolicy::New, + 
ack_policy: AckPolicy::Explicit, + ack_wait: SUBSCRIBE_CONSUMER_ACK_WAIT, + replay_policy: ReplayPolicy::Instant, + max_ack_pending: SUBSCRIBE_CONSUMER_MAX_ACK_PENDING, + max_deliver: SUBSCRIBE_CONSUMER_MAX_DELIVER, ..Default::default() }; @@ -487,16 +487,80 @@ impl StreamClient { } } +fn stream_policy_config(settings: StreamConfigSettings) -> StreamConfig { + let policy = shared::stream_policy_defaults( + AGGREGATE_STREAM_NAME.to_string(), + vec![shared::NATS_SUBJECT_AGGREGATE_EVENTS_ALL.to_string()], + ); + StreamConfig { + name: policy.name, + subjects: policy.subjects, + max_messages: settings.max_messages, + max_bytes: settings.max_bytes, + max_age: settings.max_age, + duplicate_window: settings.duplicate_window, + ..Default::default() + } +} + +fn validate_stream_config( + expected: &StreamConfig, + actual: &StreamConfig, +) -> Result<(), AggregateError> { + let expected = shared::stream_policy_from_parts( + expected.name.as_str(), + expected.subjects.clone(), + expected.max_messages, + expected.max_bytes, + expected.max_age, + expected.duplicate_window, + ); + let actual = shared::stream_policy_from_parts( + actual.name.as_str(), + actual.subjects.clone(), + actual.max_messages, + actual.max_bytes, + actual.max_age, + actual.duplicate_window, + ); + shared::validate_stream_policy(&expected, &actual) + .map_err(|e| AggregateError::StreamError(e.to_string())) +} + +fn build_event_headers(event: &Event) -> async_nats::HeaderMap { + let mut headers = async_nats::HeaderMap::new(); + + let aggregate_version = event.version.to_string(); + let aggregate_type = event.aggregate_type.as_str().to_string(); + let event_type = event.event_type.to_string(); + + headers.insert("aggregate-version", aggregate_version); + headers.insert("aggregate-type", aggregate_type); + headers.insert("event-type", event_type); + + let ctx = shared::nats_context_headers_required( + event.tenant_id.as_str(), + Some(&event.event_id.to_string()), + event.correlation_id.as_deref(), + 
event.traceparent.as_deref(), + None, + ); + for (k, v) in ctx { + headers.insert(k, v); + } + + headers +} + pub fn build_subject( tenant_id: &TenantId, aggregate_type: &AggregateType, aggregate_id: &AggregateId, ) -> String { - format!( - "tenant.{}.aggregate.{}.{}", + shared::nats_subject_aggregate_event( tenant_id.as_str(), aggregate_type.as_str(), - aggregate_id + &aggregate_id.to_string(), ) } @@ -521,6 +585,49 @@ mod tests { assert!(subject.starts_with("tenant.acme-corp.aggregate.")); } + #[test] + fn event_headers_include_required_context() { + let tenant_id = TenantId::new("tenant-a"); + let aggregate_id = AggregateId::new_v7(); + let aggregate_type = AggregateType::from("Account"); + + let event = Event::new( + tenant_id, + aggregate_id, + aggregate_type, + Version::from(1), + "created", + json!({"ok": true}), + uuid::Uuid::now_v7(), + ); + + let headers = build_event_headers(&event); + assert!(headers.get(shared::NATS_HEADER_TENANT_ID).is_some()); + assert!(headers.get(shared::NATS_HEADER_NATS_MSG_ID).is_some()); + assert!(headers.get(shared::HEADER_X_CORRELATION_ID).is_some()); + assert!(headers.get(shared::NATS_HEADER_CORRELATION_ID).is_some()); + assert!(headers.get(shared::HEADER_TRACEPARENT).is_some()); + assert!(headers.get(shared::HEADER_TRACE_ID).is_some()); + } + + #[test] + fn stream_config_validation_allows_subject_superset() { + let expected = stream_policy_config(StreamConfigSettings::default()); + let mut actual = expected.clone(); + actual + .subjects + .push("tenant.*.aggregate.extra.*".to_string()); + validate_stream_config(&expected, &actual).unwrap(); + } + + #[test] + fn stream_config_validation_rejects_missing_subject() { + let expected = stream_policy_config(StreamConfigSettings::default()); + let mut actual = expected.clone(); + actual.subjects.clear(); + assert!(validate_stream_config(&expected, &actual).is_err()); + } + #[test] fn stream_config_settings_defaults() { let settings = StreamConfigSettings::default(); diff --git 
a/control/api/Cargo.toml b/control/api/Cargo.toml index 1019c2f..15f5bb3 100644 --- a/control/api/Cargo.toml +++ b/control/api/Cargo.toml @@ -13,6 +13,7 @@ metrics-exporter-prometheus = "0.16.0" reqwest = { version = "0.12.23", default-features = false, features = ["json", "rustls-tls"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.149" +shared = { path = "../../shared" } thiserror = "2.0.16" tokio = { version = "1.45.0", features = ["macros", "net", "process", "rt-multi-thread", "signal", "time"] } tower-http = { version = "0.6.6", features = ["trace"] } diff --git a/control/api/src/admin.rs b/control/api/src/admin.rs index 55ff13c..d2ba4d0 100644 --- a/control/api/src/admin.rs +++ b/control/api/src/admin.rs @@ -19,7 +19,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use uuid::Uuid; const HEADER_IDEMPOTENCY_KEY: &str = "idempotency-key"; -const HEADER_TENANT_ID: &str = "x-tenant-id"; +const HEADER_TENANT_ID: &str = shared::HEADER_X_TENANT_ID; pub fn admin_router() -> Router { Router::new() diff --git a/control/api/src/fleet.rs b/control/api/src/fleet.rs index cec43c0..2989663 100644 --- a/control/api/src/fleet.rs +++ b/control/api/src/fleet.rs @@ -50,12 +50,12 @@ pub async fn snapshot_with_context( async fn get_ok(client: &reqwest::Client, url: &str, ctx: Option<&RequestIds>) -> bool { let mut req = client.get(url).timeout(Duration::from_secs(2)); if let Some(ctx) = ctx { - req = req.header("x-request-id", &ctx.request_id); + req = req.header(shared::HEADER_X_REQUEST_ID, &ctx.request_id); if let Some(cid) = &ctx.correlation_id { - req = req.header("x-correlation-id", cid); + req = req.header(shared::HEADER_X_CORRELATION_ID, cid); } if let Some(tp) = &ctx.traceparent { - req = req.header("traceparent", tp); + req = req.header(shared::HEADER_TRACEPARENT, tp); } } diff --git a/control/api/src/lib.rs b/control/api/src/lib.rs index f558c4f..830fa94 100644 --- a/control/api/src/lib.rs +++ b/control/api/src/lib.rs @@ -53,9 +53,9 @@ pub struct 
RequestIds { pub traceparent: Option, } -const HEADER_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); -const HEADER_CORRELATION_ID: HeaderName = HeaderName::from_static("x-correlation-id"); -const HEADER_TRACEPARENT: HeaderName = HeaderName::from_static("traceparent"); +const HEADER_REQUEST_ID: HeaderName = HeaderName::from_static(shared::HEADER_X_REQUEST_ID); +const HEADER_CORRELATION_ID: HeaderName = HeaderName::from_static(shared::HEADER_X_CORRELATION_ID); +const HEADER_TRACEPARENT: HeaderName = HeaderName::from_static(shared::HEADER_TRACEPARENT); pub fn build_app(state: AppState) -> Router { let trace = TraceLayer::new_for_http() diff --git a/gateway/build.rs b/gateway/build.rs index bee9375..8de2753 100644 --- a/gateway/build.rs +++ b/gateway/build.rs @@ -1,12 +1,25 @@ fn main() -> Result<(), Box> { - let proto_path = "../aggregate/proto/aggregate.proto"; - let proto_dir = "../aggregate/proto"; + let protoc = protoc_bin_vendored::protoc_bin_path()?; + std::env::set_var("PROTOC", protoc); tonic_build::configure() .build_server(true) .build_client(true) - .compile_protos(&[proto_path], &[proto_dir])?; + .compile_protos( + &[ + "../aggregate/proto/aggregate.proto", + "../projection/proto/query.proto", + "../runner/proto/admin.proto", + ], + &[ + "../aggregate/proto", + "../projection/proto", + "../runner/proto", + ], + )?; - println!("cargo:rerun-if-changed={}", proto_path); + println!("cargo:rerun-if-changed=../aggregate/proto/aggregate.proto"); + println!("cargo:rerun-if-changed=../projection/proto/query.proto"); + println!("cargo:rerun-if-changed=../runner/proto/admin.proto"); Ok(()) } diff --git a/gateway/src/admin_iam.rs b/gateway/src/admin_iam.rs index cefdf16..07acf05 100644 --- a/gateway/src/admin_iam.rs +++ b/gateway/src/admin_iam.rs @@ -1422,14 +1422,38 @@ mod tests { #[tokio::test] async fn tenant_admin_can_create_service_account_and_service_can_query() { - let projection_app = axum::Router::new().route( - "/v1/query/TestView", - 
axum::routing::post(|| async { (StatusCode::OK, r#"{"ok":true}"#) }), - ); + use crate::grpc::projection_proto::query_service_server::QueryService; + + #[derive(Default)] + struct Upstream; + + #[tonic::async_trait] + impl QueryService for Upstream { + async fn execute_query( + &self, + _request: tonic::Request, + ) -> Result, tonic::Status> + { + Ok(tonic::Response::new( + crate::grpc::projection_proto::QueryResponse { + json: r#"{"ok":true}"#.to_string(), + }, + )) + } + } + let projection_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); let projection_addr = projection_listener.local_addr().unwrap(); + drop(projection_listener); + let projection_url = format!("http://{}", projection_addr); tokio::spawn(async move { - axum::serve(projection_listener, projection_app) + tonic::transport::Server::builder() + .add_service( + crate::grpc::projection_proto::query_service_server::QueryServiceServer::new( + Upstream, + ), + ) + .serve(projection_addr) .await .unwrap(); }); @@ -1446,7 +1470,7 @@ mod tests { aggregate_shards: std::collections::HashMap::new(), projection_shards: std::collections::HashMap::from([( "p".to_string(), - vec![format!("http://{}", projection_addr)], + vec![projection_url], )]), runner_shards: std::collections::HashMap::new(), }; diff --git a/gateway/src/admin_rebalance.rs b/gateway/src/admin_rebalance.rs index afd9c92..cccc9ba 100644 --- a/gateway/src/admin_rebalance.rs +++ b/gateway/src/admin_rebalance.rs @@ -144,6 +144,7 @@ async fn status( async fn gates( State(state): State, + ctx: crate::RequestContext, principal: Principal, Query(q): Query, ) -> Result, AuthzRejection> { @@ -165,24 +166,33 @@ async fn gates( .await .ok(); - let projection_ready = if let Some(ep) = projection_endpoint { - projection_gate_ready(&ep, &q.tenant_id) - .await - .unwrap_or(false) - } else { - false + let projection_fut = async { + if let Some(ep) = projection_endpoint { + projection_gate_ready(&ep, &q.tenant_id, &ctx) + .await + 
.unwrap_or(false) + } else { + false + } }; - let runner_ready = if let Some(ep) = runner_endpoint { - http_ready(&ep).await.unwrap_or(false) - } else { - false + let runner_fut = async { + if let Some(ep) = runner_endpoint { + http_ready(&ep, &ctx).await.unwrap_or(false) + } else { + false + } }; - let aggregate_ready = if let Some(ep) = aggregate_endpoint { - aggregate_ready(&ep).await.unwrap_or(false) - } else { - false + let aggregate_fut = async { + if let Some(ep) = aggregate_endpoint { + aggregate_ready(&ep, &ctx).await.unwrap_or(false) + } else { + false + } }; + let (projection_ready, runner_ready, aggregate_ready) = + tokio::join!(projection_fut, runner_fut, aggregate_fut); + Ok(Json(GatesResponse { tenant_id: q.tenant_id, aggregate_ready, @@ -191,35 +201,49 @@ async fn gates( })) } -async fn http_ready(endpoint: &str) -> Result { +async fn http_ready(endpoint: &str, ctx: &crate::RequestContext) -> Result { let url = format!("{}/ready", endpoint.trim_end_matches('/')); - let client = crate::upstream::http_client(); - let resp = tokio::time::timeout(Duration::from_secs(2), client.get(url).send()) - .await - .map_err(|_| AuthzRejection::Internal)? 
- .map_err(|_| AuthzRejection::Internal)?; - Ok(resp.status().is_success()) + crate::upstream::probe_status_ok( + &url, + &[ + (shared::HEADER_X_CORRELATION_ID, ctx.correlation_id.as_str()), + (shared::HEADER_TRACEPARENT, ctx.traceparent.as_str()), + ], + Duration::from_secs(2), + Duration::from_millis(500), + ) + .await + .map_err(|_| AuthzRejection::Internal) } -async fn aggregate_ready(endpoint: &str) -> Result { +async fn aggregate_ready( + endpoint: &str, + ctx: &crate::RequestContext, +) -> Result { if endpoint.contains(":50051") { let http_ep = endpoint.replace(":50051", ":8080"); - return http_ready(&http_ep).await; + return http_ready(&http_ep, ctx).await; } - http_ready(endpoint).await + http_ready(endpoint, ctx).await } -async fn projection_gate_ready(endpoint: &str, tenant_id: &str) -> Result { +async fn projection_gate_ready( + endpoint: &str, + tenant_id: &str, + ctx: &crate::RequestContext, +) -> Result { let url = format!("{}/metrics", endpoint.trim_end_matches('/')); - let client = crate::upstream::http_client(); - let resp = tokio::time::timeout(Duration::from_secs(2), client.get(url).send()) - .await - .map_err(|_| AuthzRejection::Internal)? 
- .map_err(|_| AuthzRejection::Internal)?; - if !resp.status().is_success() { - return Ok(false); - } - let text = resp.text().await.map_err(|_| AuthzRejection::Internal)?; + let text = crate::upstream::probe_text( + &url, + &[ + (shared::HEADER_X_CORRELATION_ID, ctx.correlation_id.as_str()), + (shared::HEADER_TRACEPARENT, ctx.traceparent.as_str()), + ], + Duration::from_secs(2), + Duration::from_millis(250), + ) + .await + .map_err(|_| AuthzRejection::Internal)?; let ready = parse_prom_gauge(&text, "projection_ready").unwrap_or(0.0) >= 1.0; if !ready { diff --git a/gateway/src/authz.rs b/gateway/src/authz.rs index 4f31df3..98fd2b4 100644 --- a/gateway/src/authz.rs +++ b/gateway/src/authz.rs @@ -81,7 +81,7 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let raw = parts .headers - .get("x-tenant-id") + .get(shared::HEADER_X_TENANT_ID) .and_then(|v| v.to_str().ok()) .ok_or(AuthzRejection::MissingTenant)?; @@ -239,33 +239,45 @@ async fn query_stub( ) .await?; - let upstream = state - .routing - .resolve(&tenant_id, crate::routing::ServiceKind::Projection) + let uqf = payload + .get("uqf") + .and_then(|v| v.as_str()) + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| payload.to_string()); + if uqf.trim().is_empty() { + return Err(AuthzRejection::BadRequest); + } + + let request = crate::grpc::projection_proto::QueryRequest { + tenant_id: tenant_id.clone(), + view_type, + uqf, + }; + + let resp = crate::grpc::execute_query_via_routing(&state.routing, request, &ctx) .await - .map_err(|_| AuthzRejection::Internal)?; - tracing::Span::current().record("upstream", upstream.as_str()); + .map_err(map_query_error)?; - let url = format!("{}/v1/query/{}", upstream.trim_end_matches('/'), view_type); - - let client = crate::upstream::http_client(); - let resp = client - .post(url) - .header("x-tenant-id", tenant_id) - .header("x-correlation-id", ctx.correlation_id) - .header("traceparent", ctx.traceparent) - 
.json(&payload) - .send() - .await - .map_err(|_| AuthzRejection::Internal)?; - - let status = StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); - let bytes = resp.bytes().await.map_err(|_| AuthzRejection::Internal)?; - let mut out = Response::new(axum::body::Body::from(bytes)); - *out.status_mut() = status; + let mut out = Response::new(axum::body::Body::from(resp.json)); + out.headers_mut().insert( + header::CONTENT_TYPE, + axum::http::HeaderValue::from_static("application/json"), + ); Ok(out) } +fn map_query_error(status: tonic::Status) -> AuthzRejection { + match status.code() { + tonic::Code::InvalidArgument => AuthzRejection::BadRequest, + tonic::Code::NotFound => AuthzRejection::NotFound, + tonic::Code::PermissionDenied => AuthzRejection::Forbidden, + tonic::Code::Unauthenticated => AuthzRejection::Unauthorized, + tonic::Code::Unavailable => AuthzRejection::Internal, + _ => AuthzRejection::Internal, + } +} + pub async fn runner_admin_proxy( State(state): State, ctx: crate::RequestContext, @@ -282,51 +294,73 @@ pub async fn runner_admin_proxy( ) .await?; - let upstream = state - .routing - .resolve(&tenant_id, crate::routing::ServiceKind::Runner) - .await - .map_err(|_| AuthzRejection::Internal)?; - tracing::Span::current().record("upstream", upstream.as_str()); + let path = path.trim_start_matches('/').to_string(); + match (request.method().as_str(), path.as_str()) { + ("POST", "drain") => { + let wait_ms = request.uri().query().and_then(|q| { + q.split('&').find_map(|pair| { + let (k, v) = pair.split_once('=')?; + if k == "wait_ms" { + v.parse::().ok() + } else { + None + } + }) + }); - let mut url = format!( - "{}/admin/{}", - upstream.trim_end_matches('/'), - path.trim_start_matches('/') - ); - if let Some(q) = request.uri().query() { - url.push('?'); - url.push_str(q); - } - - let method = request.method().clone(); - let headers = request.headers().clone(); - let body = axum::body::to_bytes(request.into_body(), usize::MAX) 
- .await - .map_err(|_| AuthzRejection::Internal)?; - - let client = crate::upstream::http_client(); - let mut req = client - .request(method, url) - .header("x-tenant-id", tenant_id) - .header("x-correlation-id", ctx.correlation_id) - .header("traceparent", ctx.traceparent) - .body(body); - - for (k, v) in headers.iter() { - if k == header::HOST { - continue; + let resp = crate::grpc::runner_admin_drain_via_routing( + &state.routing, + &tenant_id, + wait_ms, + &ctx, + ) + .await + .map_err(map_query_error)?; + let status = + StatusCode::from_u16(resp.http_status as u16).unwrap_or(StatusCode::BAD_GATEWAY); + let mut out = Response::new(axum::body::Body::from(resp.json)); + *out.status_mut() = status; + out.headers_mut().insert( + header::CONTENT_TYPE, + axum::http::HeaderValue::from_static("application/json"), + ); + Ok(out) } - req = req.header(k, v); + ("GET", "drain/status") => { + let resp = crate::grpc::runner_admin_drain_status_via_routing( + &state.routing, + &tenant_id, + &ctx, + ) + .await + .map_err(map_query_error)?; + let status = + StatusCode::from_u16(resp.http_status as u16).unwrap_or(StatusCode::BAD_GATEWAY); + let mut out = Response::new(axum::body::Body::from(resp.json)); + *out.status_mut() = status; + out.headers_mut().insert( + header::CONTENT_TYPE, + axum::http::HeaderValue::from_static("application/json"), + ); + Ok(out) + } + ("POST", "reload") => { + let resp = + crate::grpc::runner_admin_reload_via_routing(&state.routing, &tenant_id, &ctx) + .await + .map_err(map_query_error)?; + let status = + StatusCode::from_u16(resp.http_status as u16).unwrap_or(StatusCode::BAD_GATEWAY); + let mut out = Response::new(axum::body::Body::from(resp.json)); + *out.status_mut() = status; + out.headers_mut().insert( + header::CONTENT_TYPE, + axum::http::HeaderValue::from_static("application/json"), + ); + Ok(out) + } + _ => Err(AuthzRejection::NotFound), } - - let resp = req.send().await.map_err(|_| AuthzRejection::Internal)?; - let status = 
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY); - let bytes = resp.bytes().await.map_err(|_| AuthzRejection::Internal)?; - - let mut out = Response::new(axum::body::Body::from(bytes)); - *out.status_mut() = status; - Ok(out) } pub async fn ensure_allowed( @@ -739,34 +773,60 @@ mod tests { use crate::routing::RoutingConfig; use std::collections::HashMap; - let projection_app = axum::Router::new().route( - "/v1/query/TestView", - post(|headers: axum::http::HeaderMap| async move { - let correlation = headers - .get("x-correlation-id") + use crate::grpc::projection_proto::query_service_server::QueryService; + + #[derive(Default)] + struct Upstream; + + #[async_trait::async_trait] + impl QueryService for Upstream { + async fn execute_query( + &self, + request: tonic::Request, + ) -> Result, tonic::Status> + { + let correlation = request + .metadata() + .get(shared::HEADER_X_CORRELATION_ID) .and_then(|v| v.to_str().ok()) .unwrap_or(""); - let traceparent = headers - .get("traceparent") + let traceparent = request + .metadata() + .get(shared::HEADER_TRACEPARENT) .and_then(|v| v.to_str().ok()) .unwrap_or(""); if correlation.trim().is_empty() - || crate::trace_id_from_traceparent(traceparent).is_none() + || shared::trace_id_from_traceparent(traceparent).is_none() { - return (StatusCode::BAD_REQUEST, "missing correlation"); + return Err(tonic::Status::failed_precondition( + "missing correlation metadata", + )); } - (StatusCode::OK, r#"{"mode":"count"}"#) - }), - ); + + Ok(tonic::Response::new( + crate::grpc::projection_proto::QueryResponse { + json: r#"{"mode":"count"}"#.to_string(), + }, + )) + } + } + let projection_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); let projection_addr = projection_listener.local_addr().unwrap(); + drop(projection_listener); + let projection_url = format!("http://{}", projection_addr); tokio::spawn(async move { - axum::serve(projection_listener, projection_app) + 
tonic::transport::Server::builder() + .add_service( + crate::grpc::projection_proto::query_service_server::QueryServiceServer::new( + Upstream, + ), + ) + .serve(projection_addr) .await .unwrap(); }); tokio::time::sleep(std::time::Duration::from_millis(50)).await; - let projection_url = format!("http://{}", projection_addr); let cfg = RoutingConfig { revision: 1, @@ -788,7 +848,7 @@ mod tests { .method("POST") .uri("/v1/query/TestView") .header("authorization", format!("Bearer {token}")) - .header("x-tenant-id", "tenant-a") + .header(shared::HEADER_X_TENANT_ID, "tenant-a") .header("content-type", "application/json") .body(axum::body::Body::from(r#"{"uqf":"{}"}"#)) .unwrap(), @@ -814,7 +874,7 @@ mod tests { .method("POST") .uri("/v1/query/TestView") .header("authorization", format!("Bearer {token}")) - .header("x-tenant-id", "tenant-a") + .header(shared::HEADER_X_TENANT_ID, "tenant-a") .header("content-type", "application/json") .body(axum::body::Body::from(r#"{"uqf":"{}"}"#)) .unwrap(), @@ -824,16 +884,175 @@ mod tests { assert_eq!(ok.status(), StatusCode::OK); assert!(!ok .headers() - .get("x-correlation-id") + .get(shared::HEADER_X_CORRELATION_ID) .and_then(|v| v.to_str().ok()) .unwrap_or("") .is_empty()); - assert!(crate::trace_id_from_traceparent( + assert!(shared::trace_id_from_traceparent( ok.headers() - .get("traceparent") + .get(shared::HEADER_TRACEPARENT) .and_then(|v| v.to_str().ok()) .unwrap_or("") ) .is_some()); } + + #[tokio::test] + async fn runner_admin_proxy_denies_unauthorized_and_forwards_when_authorized() { + use crate::grpc::runner_admin_proto::runner_admin_server::RunnerAdmin; + use std::collections::HashMap; + + #[derive(Default)] + struct Upstream; + + #[tonic::async_trait] + impl RunnerAdmin for Upstream { + async fn drain( + &self, + _request: tonic::Request, + ) -> Result< + tonic::Response, + tonic::Status, + > { + Ok(tonic::Response::new( + crate::grpc::runner_admin_proto::AdminResponse { + http_status: 200, + json: 
r#"{"ok":true}"#.to_string(), + }, + )) + } + + async fn drain_status( + &self, + _request: tonic::Request, + ) -> Result< + tonic::Response, + tonic::Status, + > { + Ok(tonic::Response::new( + crate::grpc::runner_admin_proto::AdminResponse { + http_status: 202, + json: r#"{"ok":true,"drained":false}"#.to_string(), + }, + )) + } + + async fn reload( + &self, + _request: tonic::Request, + ) -> Result< + tonic::Response, + tonic::Status, + > { + Ok(tonic::Response::new( + crate::grpc::runner_admin_proto::AdminResponse { + http_status: 200, + json: r#"{"ok":true}"#.to_string(), + }, + )) + } + } + + let runner_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let runner_addr = runner_listener.local_addr().unwrap(); + drop(runner_listener); + let runner_url = format!("http://{}", runner_addr); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service( + crate::grpc::runner_admin_proto::runner_admin_server::RunnerAdminServer::new( + Upstream, + ), + ) + .serve(runner_addr) + .await + .unwrap(); + }); + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + let cfg = crate::routing::RoutingConfig { + revision: 1, + aggregate_placement: HashMap::new(), + projection_placement: HashMap::new(), + runner_placement: HashMap::from([("tenant-a".to_string(), "r".to_string())]), + aggregate_shards: HashMap::new(), + projection_shards: HashMap::new(), + runner_shards: HashMap::from([("r".to_string(), vec![runner_url])]), + }; + + let (app, state) = test_app_with_routing(cfg).await; + let (token, claims) = signup_and_get_claims(&app, &state.authn).await; + + let forbidden = app + .clone() + .oneshot( + axum::http::Request::builder() + .method("POST") + .uri("/admin/runner/drain") + .header("authorization", format!("Bearer {token}")) + .header("x-tenant-id", "tenant-a") + .body(axum::body::Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(forbidden.status(), StatusCode::FORBIDDEN); + + put_role( + 
&state.storage, + "role-runner-admin", + vec!["runner.admin".to_string()], + ) + .await + .unwrap(); + assign_role(&state.storage, "tenant-a", &claims.sub, "role-runner-admin") + .await + .unwrap(); + + let ok = app + .oneshot( + axum::http::Request::builder() + .method("POST") + .uri("/admin/runner/drain?wait_ms=0") + .header("authorization", format!("Bearer {token}")) + .header("x-tenant-id", "tenant-a") + .body(axum::body::Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(ok.status(), StatusCode::OK); + } + + #[tokio::test] + async fn runner_admin_proxy_rejects_tenant_spoofing() { + let cfg = crate::routing::RoutingConfig::empty(); + let (app, state) = test_app_with_routing(cfg).await; + let (token, claims) = signup_and_get_claims(&app, &state.authn).await; + + put_role( + &state.storage, + "role-runner-admin", + vec!["runner.admin".to_string()], + ) + .await + .unwrap(); + assign_role(&state.storage, "tenant-a", &claims.sub, "role-runner-admin") + .await + .unwrap(); + + let forbidden = app + .oneshot( + axum::http::Request::builder() + .method("POST") + .uri("/admin/runner/reload") + .header("authorization", format!("Bearer {token}")) + .header("x-tenant-id", "tenant-b") + .body(axum::body::Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(forbidden.status(), StatusCode::FORBIDDEN); + } } diff --git a/gateway/src/grpc.rs b/gateway/src/grpc.rs index dd3a01d..1531147 100644 --- a/gateway/src/grpc.rs +++ b/gateway/src/grpc.rs @@ -1,11 +1,47 @@ use crate::routing::RouterState; use crate::routing::RoutingError; use crate::routing::ServiceKind; +use std::future::Future; pub mod proto { tonic::include_proto!("aggregate.gateway.v1"); } +pub mod projection_proto { + tonic::include_proto!("projection.gateway.v1"); +} + +pub mod runner_admin_proto { + tonic::include_proto!("runner.admin.v1"); +} + +async fn retry_read_only(mut f: F) -> Result +where + F: FnMut() -> Fut, + Fut: Future>, +{ + let mut last = None; + for attempt in 0..3 { + 
match f().await { + Ok(v) => return Ok(v), + Err(status) => { + let retryable = matches!( + status.code(), + tonic::Code::Unavailable | tonic::Code::DeadlineExceeded + ); + if retryable && attempt < 2 { + let backoff_ms = 25_u64.saturating_mul(2_u64.pow(attempt as u32)); + tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await; + last = Some(status); + continue; + } + return Err(status); + } + } + } + Err(last.unwrap_or_else(|| tonic::Status::internal("retry exhausted"))) +} + #[derive(Clone)] pub struct GatewayCommandService { routing: RouterState, @@ -23,33 +59,20 @@ impl proto::command_service_server::CommandService for GatewayCommandService { &self, request: tonic::Request, ) -> Result, tonic::Status> { - let correlation_id = request - .metadata() - .get("x-correlation-id") - .and_then(|v| v.to_str().ok()) - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - .map(|s| s.to_string()) - .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + let correlation_id = shared::normalize_correlation_id( + request + .metadata() + .get(shared::HEADER_X_CORRELATION_ID) + .and_then(|v| v.to_str().ok()), + ) + .to_string(); - let traceparent = request - .metadata() - .get("traceparent") - .and_then(|v| v.to_str().ok()) - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - .and_then(|s| { - if crate::trace_id_from_traceparent(s).is_some() { - Some(s.to_string()) - } else { - None - } - }) - .unwrap_or_else(|| { - let trace_id = uuid::Uuid::new_v4().simple().to_string(); - let span_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string(); - format!("00-{trace_id}-{span_id}-01") - }); + let traceparent = shared::normalize_traceparent( + request + .metadata() + .get(shared::HEADER_TRACEPARENT) + .and_then(|v| v.to_str().ok()), + ); let mut req = request.into_inner(); @@ -66,30 +89,35 @@ impl proto::command_service_server::CommandService for GatewayCommandService { .map_err(map_routing_error)?; tracing::Span::current().record("upstream", upstream.as_str()); - 
let channel = crate::upstream::grpc_endpoint(&upstream) - .map_err(|e| tonic::Status::unavailable(e.to_string()))? - .connect() - .await + let channel = crate::upstream::grpc_channel(&upstream) .map_err(|e| tonic::Status::unavailable(e.to_string()))?; let mut client = proto::command_service_client::CommandServiceClient::new(channel); let mut upstream_req = tonic::Request::new(req); + upstream_req.set_timeout(std::time::Duration::from_secs(5)); if let Ok(v) = tonic::metadata::MetadataValue::try_from(tenant_id.as_str()) { - upstream_req.metadata_mut().insert("x-tenant-id", v); + upstream_req + .metadata_mut() + .insert(shared::HEADER_X_TENANT_ID, v); } if let Ok(v) = tonic::metadata::MetadataValue::try_from(correlation_id.as_str()) { - upstream_req.metadata_mut().insert("x-correlation-id", v); + upstream_req + .metadata_mut() + .insert(shared::HEADER_X_CORRELATION_ID, v); } if let Ok(v) = tonic::metadata::MetadataValue::try_from(traceparent.as_str()) { - upstream_req.metadata_mut().insert("traceparent", v); + upstream_req + .metadata_mut() + .insert(shared::HEADER_TRACEPARENT, v); } let mut resp = client.submit_command(upstream_req).await?; if let Ok(v) = tonic::metadata::MetadataValue::try_from(correlation_id.as_str()) { - resp.metadata_mut().insert("x-correlation-id", v); + resp.metadata_mut() + .insert(shared::HEADER_X_CORRELATION_ID, v); } if let Ok(v) = tonic::metadata::MetadataValue::try_from(traceparent.as_str()) { - resp.metadata_mut().insert("traceparent", v); + resp.metadata_mut().insert(shared::HEADER_TRACEPARENT, v); } Ok(resp) } @@ -111,28 +139,176 @@ pub async fn submit_command_via_routing( .map_err(map_routing_error)?; tracing::Span::current().record("upstream", upstream.as_str()); - let channel = crate::upstream::grpc_endpoint(&upstream) - .map_err(|e| tonic::Status::unavailable(e.to_string()))? 
- .connect() - .await + let channel = crate::upstream::grpc_channel(&upstream) .map_err(|e| tonic::Status::unavailable(e.to_string()))?; let mut client = proto::command_service_client::CommandServiceClient::new(channel); let mut upstream_req = tonic::Request::new(request); + upstream_req.set_timeout(std::time::Duration::from_secs(5)); if let Ok(v) = tonic::metadata::MetadataValue::try_from(tenant_id.as_str()) { - upstream_req.metadata_mut().insert("x-tenant-id", v); + upstream_req + .metadata_mut() + .insert(shared::HEADER_X_TENANT_ID, v); } if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.correlation_id.as_str()) { - upstream_req.metadata_mut().insert("x-correlation-id", v); + upstream_req + .metadata_mut() + .insert(shared::HEADER_X_CORRELATION_ID, v); } if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.traceparent.as_str()) { - upstream_req.metadata_mut().insert("traceparent", v); + upstream_req + .metadata_mut() + .insert(shared::HEADER_TRACEPARENT, v); } let resp = client.submit_command(upstream_req).await?; Ok(resp.into_inner()) } +pub async fn execute_query_via_routing( + routing: &RouterState, + request: projection_proto::QueryRequest, + ctx: &crate::RequestContext, +) -> Result { + let tenant_id = request.tenant_id.trim().to_string(); + if tenant_id.is_empty() { + return Err(tonic::Status::invalid_argument("tenant_id is required")); + } + + let upstream = routing + .resolve(&tenant_id, ServiceKind::Projection) + .await + .map_err(map_routing_error)?; + tracing::Span::current().record("upstream", upstream.as_str()); + + let channel = crate::upstream::grpc_channel(&upstream) + .map_err(|e| tonic::Status::unavailable(e.to_string()))?; + + retry_read_only(|| { + let mut client = + projection_proto::query_service_client::QueryServiceClient::new(channel.clone()); + let mut upstream_req = tonic::Request::new(request.clone()); + upstream_req.set_timeout(std::time::Duration::from_secs(2)); + if let Ok(v) = 
tonic::metadata::MetadataValue::try_from(tenant_id.as_str()) { + upstream_req + .metadata_mut() + .insert(shared::HEADER_X_TENANT_ID, v); + } + if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.correlation_id.as_str()) { + upstream_req + .metadata_mut() + .insert(shared::HEADER_X_CORRELATION_ID, v); + } + if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.traceparent.as_str()) { + upstream_req + .metadata_mut() + .insert(shared::HEADER_TRACEPARENT, v); + } + async move { Ok(client.execute_query(upstream_req).await?.into_inner()) } + }) + .await +} + +pub async fn runner_admin_drain_via_routing( + routing: &RouterState, + tenant_id: &str, + wait_ms: Option, + ctx: &crate::RequestContext, +) -> Result { + let upstream = routing + .resolve(tenant_id, ServiceKind::Runner) + .await + .map_err(map_routing_error)?; + tracing::Span::current().record("upstream", upstream.as_str()); + + let channel = crate::upstream::grpc_channel(&upstream) + .map_err(|e| tonic::Status::unavailable(e.to_string()))?; + + let mut client = runner_admin_proto::runner_admin_client::RunnerAdminClient::new(channel); + let mut req = tonic::Request::new(runner_admin_proto::DrainRequest { + tenant_id: tenant_id.to_string(), + wait_ms: wait_ms.unwrap_or(0), + }); + req.set_timeout(std::time::Duration::from_secs(5)); + if let Ok(v) = tonic::metadata::MetadataValue::try_from(tenant_id) { + req.metadata_mut().insert(shared::HEADER_X_TENANT_ID, v); + } + if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.correlation_id.as_str()) { + req.metadata_mut() + .insert(shared::HEADER_X_CORRELATION_ID, v); + } + if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.traceparent.as_str()) { + req.metadata_mut().insert(shared::HEADER_TRACEPARENT, v); + } + Ok(client.drain(req).await?.into_inner()) +} + +pub async fn runner_admin_drain_status_via_routing( + routing: &RouterState, + tenant_id: &str, + ctx: &crate::RequestContext, +) -> Result { + let upstream = routing + 
.resolve(tenant_id, ServiceKind::Runner) + .await + .map_err(map_routing_error)?; + tracing::Span::current().record("upstream", upstream.as_str()); + + let channel = crate::upstream::grpc_channel(&upstream) + .map_err(|e| tonic::Status::unavailable(e.to_string()))?; + + retry_read_only(|| { + let mut client = + runner_admin_proto::runner_admin_client::RunnerAdminClient::new(channel.clone()); + let mut req = tonic::Request::new(runner_admin_proto::DrainStatusRequest { + tenant_id: tenant_id.to_string(), + }); + req.set_timeout(std::time::Duration::from_secs(2)); + if let Ok(v) = tonic::metadata::MetadataValue::try_from(tenant_id) { + req.metadata_mut().insert(shared::HEADER_X_TENANT_ID, v); + } + if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.correlation_id.as_str()) { + req.metadata_mut() + .insert(shared::HEADER_X_CORRELATION_ID, v); + } + if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.traceparent.as_str()) { + req.metadata_mut().insert(shared::HEADER_TRACEPARENT, v); + } + async move { Ok(client.drain_status(req).await?.into_inner()) } + }) + .await +} + +pub async fn runner_admin_reload_via_routing( + routing: &RouterState, + tenant_id: &str, + ctx: &crate::RequestContext, +) -> Result { + let upstream = routing + .resolve(tenant_id, ServiceKind::Runner) + .await + .map_err(map_routing_error)?; + tracing::Span::current().record("upstream", upstream.as_str()); + + let channel = crate::upstream::grpc_channel(&upstream) + .map_err(|e| tonic::Status::unavailable(e.to_string()))?; + + let mut client = runner_admin_proto::runner_admin_client::RunnerAdminClient::new(channel); + let mut req = tonic::Request::new(runner_admin_proto::ReloadRequest {}); + req.set_timeout(std::time::Duration::from_secs(2)); + if let Ok(v) = tonic::metadata::MetadataValue::try_from(tenant_id) { + req.metadata_mut().insert(shared::HEADER_X_TENANT_ID, v); + } + if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.correlation_id.as_str()) { + req.metadata_mut() 
+ .insert(shared::HEADER_X_CORRELATION_ID, v); + } + if let Ok(v) = tonic::metadata::MetadataValue::try_from(ctx.traceparent.as_str()) { + req.metadata_mut().insert(shared::HEADER_TRACEPARENT, v); + } + Ok(client.reload(req).await?.into_inner()) +} + fn map_routing_error(err: RoutingError) -> tonic::Status { match err { RoutingError::UnknownTenant => tonic::Status::not_found("unknown tenant"), @@ -187,7 +363,7 @@ mod tests { .get("traceparent") .and_then(|v| v.to_str().ok()) .unwrap_or(""); - if crate::trace_id_from_traceparent(traceparent).is_none() { + if shared::trace_id_from_traceparent(traceparent).is_none() { return Err(tonic::Status::failed_precondition("missing traceparent")); } @@ -258,9 +434,9 @@ mod tests { .and_then(|v| v.to_str().ok()) .unwrap_or("") .is_empty()); - assert!(crate::trace_id_from_traceparent( + assert!(shared::trace_id_from_traceparent( resp.metadata() - .get("traceparent") + .get(shared::HEADER_TRACEPARENT) .and_then(|v| v.to_str().ok()) .unwrap_or("") ) diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index d8e37f7..2ec7254 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -49,23 +49,23 @@ where async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { let request_id = parts .headers - .get("x-request-id") + .get(shared::HEADER_X_REQUEST_ID) .and_then(|v| v.to_str().ok()) .unwrap_or("") .to_string(); let correlation_id = parts .headers - .get("x-correlation-id") + .get(shared::HEADER_X_CORRELATION_ID) .and_then(|v| v.to_str().ok()) .unwrap_or("") .to_string(); let traceparent = parts .headers - .get("traceparent") + .get(shared::HEADER_TRACEPARENT) .and_then(|v| v.to_str().ok()) .unwrap_or("") .to_string(); - let trace_id = trace_id_from_traceparent(&traceparent) + let trace_id = shared::trace_id_from_traceparent(&traceparent) .map(|s| s.to_string()) .unwrap_or_default(); @@ -92,7 +92,7 @@ struct StatusResponse { } pub fn app(state: AppState) -> Router { - let request_id_header = 
HeaderName::from_static("x-request-id"); + let request_id_header = HeaderName::from_static(shared::HEADER_X_REQUEST_ID); Router::new() .route("/health", get(health)) @@ -133,20 +133,20 @@ pub fn app(state: AppState) -> Router { |request: &axum::http::Request<_>| { let request_id = request .headers() - .get("x-request-id") + .get(shared::HEADER_X_REQUEST_ID) .and_then(|v| v.to_str().ok()) .unwrap_or(""); let correlation_id = request .headers() - .get("x-correlation-id") + .get(shared::HEADER_X_CORRELATION_ID) .and_then(|v| v.to_str().ok()) .unwrap_or(""); let traceparent = request .headers() - .get("traceparent") + .get(shared::HEADER_TRACEPARENT) .and_then(|v| v.to_str().ok()) .unwrap_or(""); - let trace_id = trace_id_from_traceparent(traceparent).unwrap_or(""); + let trace_id = shared::trace_id_from_traceparent(traceparent).unwrap_or(""); let path = request_path_for_logging(request); tracing::span!( @@ -205,48 +205,42 @@ where } fn call(&mut self, mut req: axum::http::Request) -> Self::Future { - let correlation_id = req - .headers() - .get("x-correlation-id") - .and_then(|v| v.to_str().ok()) - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - .map(|s| s.to_string()) - .unwrap_or_else(generate_correlation_id); + let correlation_id = shared::normalize_correlation_id( + req.headers() + .get(shared::HEADER_X_CORRELATION_ID) + .and_then(|v| v.to_str().ok()), + ) + .to_string(); - let traceparent = req - .headers() - .get("traceparent") - .and_then(|v| v.to_str().ok()) - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - .and_then(|s| { - if trace_id_from_traceparent(s).is_some() { - Some(s.to_string()) - } else { - None - } - }) - .unwrap_or_else(generate_traceparent); + let traceparent = shared::normalize_traceparent( + req.headers() + .get(shared::HEADER_TRACEPARENT) + .and_then(|v| v.to_str().ok()), + ); if let Ok(v) = HeaderValue::from_str(&correlation_id) { - req.headers_mut().insert("x-correlation-id", v); + 
req.headers_mut().insert(shared::HEADER_X_CORRELATION_ID, v); } if let Ok(v) = HeaderValue::from_str(&traceparent) { - req.headers_mut().insert("traceparent", v); + req.headers_mut().insert(shared::HEADER_TRACEPARENT, v); } let mut inner = self.inner.clone(); Box::pin(async move { let mut resp = inner.call(req).await?; - if resp.headers().get("x-correlation-id").is_none() { + if resp + .headers() + .get(shared::HEADER_X_CORRELATION_ID) + .is_none() + { if let Ok(v) = HeaderValue::from_str(&correlation_id) { - resp.headers_mut().insert("x-correlation-id", v); + resp.headers_mut() + .insert(shared::HEADER_X_CORRELATION_ID, v); } } - if resp.headers().get("traceparent").is_none() { + if resp.headers().get(shared::HEADER_TRACEPARENT).is_none() { if let Ok(v) = HeaderValue::from_str(&traceparent) { - resp.headers_mut().insert("traceparent", v); + resp.headers_mut().insert(shared::HEADER_TRACEPARENT, v); } } Ok(resp) @@ -254,20 +248,6 @@ where } } -fn generate_correlation_id() -> String { - uuid::Uuid::new_v4().to_string() -} - -fn generate_traceparent() -> String { - let trace_id = uuid::Uuid::new_v4().simple().to_string(); - let span_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string(); - format!("00-{trace_id}-{span_id}-01") -} - -pub(crate) fn trace_id_from_traceparent(traceparent: &str) -> Option<&str> { - shared::trace_id_from_traceparent(traceparent) -} - async fn track_http_metrics( req: axum::http::Request, next: Next, diff --git a/gateway/src/upstream.rs b/gateway/src/upstream.rs index 0657ec1..f3a4034 100644 --- a/gateway/src/upstream.rs +++ b/gateway/src/upstream.rs @@ -1,5 +1,6 @@ -use std::sync::OnceLock; -use std::time::Duration; +use std::collections::HashMap; +use std::sync::{Mutex, OnceLock}; +use std::time::{Duration, Instant}; pub fn http_client() -> &'static reqwest::Client { static CLIENT: OnceLock = OnceLock::new(); @@ -47,6 +48,175 @@ pub fn grpc_endpoint(url: &str) -> Result Result { + const MAX_CHANNELS: usize = 64; + const TTL: 
Duration = Duration::from_secs(300); + + static CACHE: OnceLock>> = + OnceLock::new(); + let cache = CACHE.get_or_init(|| Mutex::new(HashMap::new())); + + if let Ok(mut guard) = cache.lock() { + if let Some((channel, last_used)) = guard.get_mut(url) { + if last_used.elapsed() < TTL { + *last_used = Instant::now(); + return Ok(channel.clone()); + } + } + + let endpoint = grpc_endpoint(url)?; + let channel = endpoint.connect_lazy(); + + if guard.len() >= MAX_CHANNELS { + let mut oldest_key = None; + let mut oldest_at = Instant::now(); + for (k, (_, last_used)) in guard.iter() { + if oldest_key.is_none() || *last_used < oldest_at { + oldest_key = Some(k.clone()); + oldest_at = *last_used; + } + } + if let Some(key) = oldest_key { + guard.remove(&key); + } + } + + guard.insert(url.to_string(), (channel.clone(), Instant::now())); + Ok(channel) + } else { + let endpoint = grpc_endpoint(url)?; + Ok(endpoint.connect_lazy()) + } +} + +pub async fn probe_status_ok( + url: &str, + headers: &[(&str, &str)], + timeout: Duration, + cache_ttl: Duration, +) -> Result { + const MAX_ENTRIES: usize = 256; + + static SEM: OnceLock = OnceLock::new(); + static CACHE: OnceLock>> = OnceLock::new(); + + let sem = SEM.get_or_init(|| tokio::sync::Semaphore::new(32)); + let cache = CACHE.get_or_init(|| Mutex::new(HashMap::new())); + + if cache_ttl > Duration::ZERO { + if let Ok(mut guard) = cache.lock() { + if let Some((value, last_used)) = guard.get_mut(url) { + if last_used.elapsed() < cache_ttl { + *last_used = Instant::now(); + return Ok(*value); + } + } + } + } + + let _permit = sem.acquire().await.expect("probe semaphore closed"); + + if cache_ttl > Duration::ZERO { + if let Ok(mut guard) = cache.lock() { + if let Some((value, last_used)) = guard.get_mut(url) { + if last_used.elapsed() < cache_ttl { + *last_used = Instant::now(); + return Ok(*value); + } + } + } + } + + let client = http_client(); + let mut req = client.get(url).timeout(timeout); + for (k, v) in headers { + req = 
req.header(*k, *v); + } + let ok = req.send().await.map(|r| r.status().is_success())?; + + if cache_ttl > Duration::ZERO { + if let Ok(mut guard) = cache.lock() { + if guard.len() >= MAX_ENTRIES { + evict_oldest(&mut guard); + } + guard.insert(url.to_string(), (ok, Instant::now())); + } + } + + Ok(ok) +} + +pub async fn probe_text( + url: &str, + headers: &[(&str, &str)], + timeout: Duration, + cache_ttl: Duration, +) -> Result { + const MAX_ENTRIES: usize = 128; + + static SEM: OnceLock = OnceLock::new(); + static CACHE: OnceLock>> = OnceLock::new(); + + let sem = SEM.get_or_init(|| tokio::sync::Semaphore::new(16)); + let cache = CACHE.get_or_init(|| Mutex::new(HashMap::new())); + + if cache_ttl > Duration::ZERO { + if let Ok(mut guard) = cache.lock() { + if let Some((value, last_used)) = guard.get_mut(url) { + if last_used.elapsed() < cache_ttl { + *last_used = Instant::now(); + return Ok(value.clone()); + } + } + } + } + + let _permit = sem.acquire().await.expect("probe semaphore closed"); + + if cache_ttl > Duration::ZERO { + if let Ok(mut guard) = cache.lock() { + if let Some((value, last_used)) = guard.get_mut(url) { + if last_used.elapsed() < cache_ttl { + *last_used = Instant::now(); + return Ok(value.clone()); + } + } + } + } + + let client = http_client(); + let mut req = client.get(url).timeout(timeout); + for (k, v) in headers { + req = req.header(*k, *v); + } + let text = req.send().await?.text().await?; + + if cache_ttl > Duration::ZERO { + if let Ok(mut guard) = cache.lock() { + if guard.len() >= MAX_ENTRIES { + evict_oldest(&mut guard); + } + guard.insert(url.to_string(), (text.clone(), Instant::now())); + } + } + + Ok(text) +} + +fn evict_oldest(map: &mut HashMap) { + let mut oldest_key = None; + let mut oldest_at = Instant::now(); + for (k, (_, last_used)) in map.iter() { + if oldest_key.is_none() || *last_used < oldest_at { + oldest_key = Some(k.clone()); + oldest_at = *last_used; + } + } + if let Some(key) = oldest_key { + map.remove(&key); + } 
+} + fn grpc_tls_config() -> Option { let mut tls = tonic::transport::ClientTlsConfig::new(); let mut configured = false; diff --git a/projection/Cargo.toml b/projection/Cargo.toml index cc2e4ff..7b58d02 100644 --- a/projection/Cargo.toml +++ b/projection/Cargo.toml @@ -29,8 +29,14 @@ uuid = { version = "1", features = ["v7", "serde"] } chrono = { version = "0.4", features = ["serde"] } futures = "0.3" axum = "0.7" +prost = "0.13" +tonic = { version = "0.12", default-features = false, features = ["codegen", "prost", "transport"] } v8 = { version = "0.106", optional = true } [dev-dependencies] tempfile = "3" tower = "0.5" + +[build-dependencies] +tonic-build = { version = "0.12", default-features = false, features = ["prost"] } +protoc-bin-vendored = "3" diff --git a/projection/build.rs b/projection/build.rs new file mode 100644 index 0000000..86f34c0 --- /dev/null +++ b/projection/build.rs @@ -0,0 +1,8 @@ +fn main() -> Result<(), Box> { + let protoc = protoc_bin_vendored::protoc_bin_path()?; + std::env::set_var("PROTOC", protoc); + + tonic_build::configure().compile_protos(&["proto/query.proto"], &["proto"])?; + + Ok(()) +} diff --git a/projection/proto/query.proto b/projection/proto/query.proto new file mode 100644 index 0000000..8a258d6 --- /dev/null +++ b/projection/proto/query.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package projection.gateway.v1; + +service QueryService { + rpc ExecuteQuery(QueryRequest) returns (QueryResponse); +} + +message QueryRequest { + string tenant_id = 1; + string view_type = 2; + string uqf = 3; +} + +message QueryResponse { + string json = 1; +} diff --git a/projection/src/config/settings.rs b/projection/src/config/settings.rs index ebe1f71..6b919c7 100644 --- a/projection/src/config/settings.rs +++ b/projection/src/config/settings.rs @@ -18,6 +18,7 @@ pub struct Settings { pub max_deliver: i64, pub consumer_mode: ConsumerMode, pub http_addr: String, + pub grpc_addr: String, pub storage_backoff_ms: u64, pub 
storage_backoff_max_ms: u64, } @@ -47,6 +48,7 @@ impl Default for Settings { max_deliver: 10, consumer_mode: ConsumerMode::Single, http_addr: "0.0.0.0:8080".to_string(), + grpc_addr: "0.0.0.0:9090".to_string(), storage_backoff_ms: 50, storage_backoff_max_ms: 2_000, } @@ -181,6 +183,12 @@ impl Settings { } } + if let Ok(addr) = std::env::var("PROJECTION_GRPC_ADDR") { + if !addr.trim().is_empty() { + self.grpc_addr = addr; + } + } + if let Ok(ms) = std::env::var("PROJECTION_STORAGE_BACKOFF_MS") { if let Ok(value) = ms.parse() { self.storage_backoff_ms = value; @@ -210,6 +218,12 @@ impl Settings { if self.durable_name.is_empty() { return Err("Durable name is required".to_string()); } + if self.http_addr.trim().is_empty() { + return Err("HTTP addr is required".to_string()); + } + if self.grpc_addr.trim().is_empty() { + return Err("gRPC addr is required".to_string()); + } Ok(()) } } diff --git a/projection/src/grpc.rs b/projection/src/grpc.rs new file mode 100644 index 0000000..52913e0 --- /dev/null +++ b/projection/src/grpc.rs @@ -0,0 +1,181 @@ +use crate::config::Settings; +use crate::query::{QueryError, QueryRequest, QueryService}; +use crate::tenant_placement::TenantPlacement; +use crate::types::{ProjectionError, TenantId, ViewType}; +use crate::ProjectionManifest; + +pub mod proto { + tonic::include_proto!("projection.gateway.v1"); +} + +#[derive(Clone)] +pub struct GrpcQueryService { + placement: TenantPlacement, + manifest: ProjectionManifest, + query: QueryService, +} + +impl GrpcQueryService { + pub fn new( + placement: TenantPlacement, + manifest: ProjectionManifest, + query: QueryService, + ) -> Self { + Self { + placement, + manifest, + query, + } + } +} + +#[tonic::async_trait] +impl proto::query_service_server::QueryService for GrpcQueryService { + async fn execute_query( + &self, + request: tonic::Request, + ) -> Result, tonic::Status> { + let md_tenant = request + .metadata() + .get(shared::HEADER_X_TENANT_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| 
s.trim()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()); + + let req = request.into_inner(); + let tenant_id = req.tenant_id.trim().to_string(); + if tenant_id.is_empty() { + return Err(tonic::Status::invalid_argument("tenant_id is required")); + } + if let Some(md_tenant) = md_tenant.as_deref() { + if md_tenant != tenant_id { + return Err(tonic::Status::permission_denied("tenant mismatch")); + } + } + + let tenant_id = TenantId::new(tenant_id); + + if self.placement.is_draining(&tenant_id) { + return Err(tonic::Status::unavailable("tenant is draining")); + } + if !self.placement.is_hosted(&tenant_id) { + return Err(tonic::Status::permission_denied("tenant not hosted")); + } + + let view_type_raw = req.view_type.trim().to_string(); + if view_type_raw.is_empty() { + return Err(tonic::Status::invalid_argument("view_type is required")); + } + let view_type = ViewType::new(view_type_raw.clone()); + if self.manifest.get(&view_type).is_none() { + return Err(tonic::Status::not_found("unknown view type")); + } + + let uqf = req.uqf; + if uqf.trim().is_empty() { + return Err(tonic::Status::invalid_argument("uqf is required")); + } + + let request = QueryRequest { + tenant_id, + view_type, + uqf, + }; + + let response = self.query.query(request).map_err(map_query_error)?; + let json = + serde_json::to_string(&response).map_err(|e| tonic::Status::internal(e.to_string()))?; + + Ok(tonic::Response::new(proto::QueryResponse { json })) + } +} + +fn map_query_error(err: QueryError) -> tonic::Status { + match err { + QueryError::InvalidQuery(e) => tonic::Status::invalid_argument(e), + QueryError::Execution(e) => tonic::Status::internal(e), + } +} + +pub async fn serve( + settings: Settings, + placement: TenantPlacement, + manifest: ProjectionManifest, + query: QueryService, + shutdown: std::sync::Arc, +) -> Result<(), ProjectionError> { + let addr: std::net::SocketAddr = settings + .grpc_addr + .parse::() + .map_err(|e| ProjectionError::StreamError(e.to_string()))?; + + 
tonic::transport::Server::builder() + .add_service(proto::query_service_server::QueryServiceServer::new( + GrpcQueryService::new(placement, manifest, query), + )) + .serve_with_shutdown(addr, async move { shutdown.notified().await }) + .await + .map_err(|e| ProjectionError::StreamError(e.to_string()))?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::proto::query_service_server::QueryService as QueryServiceGrpc; + use super::*; + use crate::storage::KvClient; + use crate::types::{CheckpointKey, StreamSequence, ViewId, ViewKey}; + use serde_json::json; + + #[tokio::test] + async fn rejects_tenant_not_hosted() { + let storage = KvClient::in_memory(); + let tenant_a = TenantId::new("tenant-a"); + let view_type = ViewType::new("User"); + let cp = CheckpointKey::new(&tenant_a, &view_type); + let key = ViewKey::new(&tenant_a, &view_type, &ViewId::new("u1")); + storage + .commit_view_and_checkpoint(&key, &json!({"id":"u1"}), &cp, 1 as StreamSequence) + .unwrap(); + + let query = QueryService::new(storage); + let mut manifest = ProjectionManifest::new(); + manifest.register(crate::project::ProjectionDefinition { + view_type: view_type.clone(), + project_program: "unused".to_string(), + }); + + let placement = TenantPlacement::with_hosted(Some(vec!["tenant-a".to_string()])); + let svc = GrpcQueryService::new(placement, manifest, query); + + let request = proto::QueryRequest { + tenant_id: "tenant-b".to_string(), + view_type: "User".to_string(), + uqf: "{}".to_string(), + }; + let err = QueryServiceGrpc::execute_query(&svc, tonic::Request::new(request)) + .await + .unwrap_err(); + assert_eq!(err.code(), tonic::Code::PermissionDenied); + } + + #[tokio::test] + async fn rejects_unknown_view_type() { + let query = QueryService::new(KvClient::in_memory()); + let placement = TenantPlacement::with_hosted(None); + let manifest = ProjectionManifest::new(); + let svc = GrpcQueryService::new(placement, manifest, query); + + let request = proto::QueryRequest { + tenant_id: 
"tenant-a".to_string(), + view_type: "Missing".to_string(), + uqf: "{}".to_string(), + }; + let err = QueryServiceGrpc::execute_query(&svc, tonic::Request::new(request)) + .await + .unwrap_err(); + assert_eq!(err.code(), tonic::Code::NotFound); + } +} diff --git a/projection/src/http/mod.rs b/projection/src/http/mod.rs index bce22fb..17c1c35 100644 --- a/projection/src/http/mod.rs +++ b/projection/src/http/mod.rs @@ -289,7 +289,7 @@ fn tenant_from_headers( headers: &HeaderMap, ) -> Result { let header_value = headers - .get("x-tenant-id") + .get(shared::HEADER_X_TENANT_ID) .and_then(|v| v.to_str().ok()) .map(|s| s.trim()) .unwrap_or(""); diff --git a/projection/src/lib.rs b/projection/src/lib.rs index b5fb1da..efce903 100644 --- a/projection/src/lib.rs +++ b/projection/src/lib.rs @@ -1,4 +1,5 @@ pub mod config; +pub mod grpc; pub mod http; pub mod observability; pub mod project; diff --git a/projection/src/main.rs b/projection/src/main.rs index 3d80891..8c6e5db 100644 --- a/projection/src/main.rs +++ b/projection/src/main.rs @@ -54,10 +54,27 @@ async fn serve() { } }; + let grpc_manifest = http_state.manifest.clone(); + let grpc_query = http_state.query.clone(); + let http_shutdown = shutdown.clone(); let http_task = tokio::spawn(async move { projection::http::serve(http_state, http_shutdown).await }); + let grpc_shutdown = shutdown.clone(); + let grpc_settings = settings.clone(); + let grpc_placement = tenant_placement.clone(); + let grpc_task = tokio::spawn(async move { + projection::grpc::serve( + grpc_settings, + grpc_placement, + grpc_manifest, + grpc_query, + grpc_shutdown, + ) + .await + }); + let signal_shutdown = shutdown.clone(); let signal_ready = ready.clone(); let signal_draining = draining.clone(); @@ -103,6 +120,7 @@ async fn serve() { shutdown.notify_waiters(); let _ = http_task.await; + let _ = grpc_task.await; match worker_result { Ok(Ok(())) => {} diff --git a/projection/src/stream/jetstream.rs b/projection/src/stream/jetstream.rs index 
dbccb43..630a23f 100644 --- a/projection/src/stream/jetstream.rs +++ b/projection/src/stream/jetstream.rs @@ -2,7 +2,7 @@ use crate::config::Settings; use crate::types::ProjectionError; use async_nats::jetstream::{ self, consumer::pull::Config as PullConfig, consumer::AckPolicy, consumer::DeliverPolicy, - consumer::ReplayPolicy, + consumer::ReplayPolicy, stream::Config as StreamConfig, }; #[derive(Debug, Clone)] @@ -24,7 +24,7 @@ impl JetStreamClient { .subject_filters .first() .cloned() - .unwrap_or_else(|| "tenant.*.aggregate.*.*".to_string()); + .unwrap_or_else(|| shared::NATS_SUBJECT_AGGREGATE_EVENTS_ALL.to_string()); let options = ConsumerOptions { durable_name: settings.durable_name.clone(), @@ -45,20 +45,32 @@ impl JetStreamClient { let jetstream = jetstream::new(client); - let stream = jetstream - .get_stream(&settings.stream_name) + let expected = stream_policy_config(&settings.stream_name); + let mut stream = jetstream + .get_or_create_stream(expected.clone()) .await - .map_err(|e| ProjectionError::StreamError(format!("Stream not found: {}", e)))?; + .map_err(|e| ProjectionError::StreamError(format!("Stream error: {}", e)))?; + let info = stream + .info() + .await + .map_err(|e| ProjectionError::StreamError(format!("Stream info error: {}", e)))?; + validate_stream_config(&expected, &info.config)?; + + let policy = shared::consumer_policy_from_parts( + settings.ack_timeout_ms, + settings.max_in_flight, + settings.max_deliver, + ); let consumer_config = PullConfig { durable_name: Some(options.durable_name.clone()), deliver_policy: options.deliver_policy, ack_policy: AckPolicy::Explicit, - ack_wait: std::time::Duration::from_millis(settings.ack_timeout_ms), + ack_wait: policy.ack_wait, filter_subject: options.filter_subject, replay_policy: ReplayPolicy::Instant, - max_ack_pending: settings.max_in_flight as i64, - max_deliver: settings.max_deliver, + max_ack_pending: policy.max_ack_pending, + max_deliver: policy.max_deliver, ..Default::default() }; @@ -88,3 
+100,43 @@ impl JetStreamClient { Ok(info.state.last_sequence) } } + +fn stream_policy_config(name: &str) -> StreamConfig { + let policy = shared::stream_policy_defaults( + name.to_string(), + vec![shared::NATS_SUBJECT_AGGREGATE_EVENTS_ALL.to_string()], + ); + StreamConfig { + name: policy.name, + subjects: policy.subjects, + max_messages: policy.max_messages, + max_bytes: policy.max_bytes, + max_age: policy.max_age, + duplicate_window: policy.duplicate_window, + ..Default::default() + } +} + +fn validate_stream_config( + expected: &StreamConfig, + actual: &StreamConfig, +) -> Result<(), ProjectionError> { + let expected = shared::stream_policy_from_parts( + expected.name.as_str(), + expected.subjects.clone(), + expected.max_messages, + expected.max_bytes, + expected.max_age, + expected.duplicate_window, + ); + let actual = shared::stream_policy_from_parts( + actual.name.as_str(), + actual.subjects.clone(), + actual.max_messages, + actual.max_bytes, + actual.max_age, + actual.duplicate_window, + ); + shared::validate_stream_policy(&expected, &actual) + .map_err(|e| ProjectionError::StreamError(e.to_string())) +} diff --git a/projection/src/stream/mod.rs b/projection/src/stream/mod.rs index 4371179..393fbe4 100644 --- a/projection/src/stream/mod.rs +++ b/projection/src/stream/mod.rs @@ -101,7 +101,7 @@ async fn run_projection_per_view_with_options( .subject_filters .first() .cloned() - .unwrap_or_else(|| "tenant.*.aggregate.*.*".to_string()); + .unwrap_or_else(|| shared::NATS_SUBJECT_AGGREGATE_EVENTS_ALL.to_string()); let shutdown = options.shutdown.clone(); let ready = options.ready.clone(); @@ -220,7 +220,7 @@ pub async fn run_projection_with_options( .consumer_filter_subject .clone() .or_else(|| settings.subject_filters.first().cloned()) - .unwrap_or_else(|| "tenant.*.aggregate.*.*".to_string()); + .unwrap_or_else(|| shared::NATS_SUBJECT_AGGREGATE_EVENTS_ALL.to_string()); let deliver_policy = options .consumer_deliver_policy .unwrap_or(DeliverPolicy::All); @@ 
-301,7 +301,7 @@ pub async fn run_projection_with_options( let sequence = info.stream_sequence; let delivered = info.delivered; - let envelope: EventEnvelope = match serde_json::from_slice(&msg.payload) { + let mut envelope: EventEnvelope = match serde_json::from_slice(&msg.payload) { Ok(e) => e, Err(e) => { tracing::error!(error = %e, "Failed to decode event envelope"); @@ -310,6 +310,53 @@ pub async fn run_projection_with_options( } }; + if let Some(headers) = msg.headers.as_ref() { + if envelope.correlation_id.is_none() { + let correlation_id = headers + .get(shared::NATS_HEADER_CORRELATION_ID) + .or_else(|| headers.get(shared::HEADER_X_CORRELATION_ID)) + .map(|v| v.to_string()) + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()); + if let Some(correlation_id) = correlation_id { + envelope.correlation_id = Some(shared::CorrelationId::new(correlation_id)); + } + } + + if envelope.traceparent.is_none() { + let traceparent = headers + .get(shared::HEADER_TRACEPARENT) + .map(|v| v.to_string()) + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()); + if let Some(traceparent) = traceparent { + let normalized = shared::normalize_traceparent(Some(&traceparent)); + envelope.traceparent = Some(normalized); + } + } + + if envelope.trace_id.is_none() { + let trace_id = headers + .get(shared::HEADER_TRACE_ID) + .map(|v| v.to_string()) + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .map(shared::TraceId::new) + .filter(|t| t.is_valid_hex_32()); + if let Some(trace_id) = trace_id { + envelope.trace_id = Some(trace_id); + } + } + } + + if envelope.trace_id.is_none() { + if let Some(traceparent) = envelope.traceparent.as_deref() { + if let Some(trace_id) = shared::trace_id_from_traceparent(traceparent) { + envelope.trace_id = Some(shared::TraceId::new(trace_id.to_string())); + } + } + } + let tenant_id = resolve_tenant_id(&settings, &envelope); if let Some(filter) = &options.tenant_filter { @@ -460,7 +507,7 @@ pub async fn rebuild_view( 
Uuid::now_v7() ); let filter_subject = if tenant_id.is_empty() { - "tenant.*.aggregate.*.*".to_string() + shared::NATS_SUBJECT_AGGREGATE_EVENTS_ALL.to_string() } else { format!("tenant.{}.aggregate.*.*", tenant_id.as_str()) }; @@ -499,7 +546,7 @@ pub async fn backfill_to_tail( ) -> Result<(), ProjectionError> { let durable_name = format!("{}_backfill_{}", settings.durable_name, Uuid::now_v7()); let filter_subject = if tenant_id.is_empty() { - "tenant.*.aggregate.*.*".to_string() + shared::NATS_SUBJECT_AGGREGATE_EVENTS_ALL.to_string() } else { format!("tenant.{}.aggregate.*.*", tenant_id.as_str()) }; diff --git a/projection/src/tenant_placement.rs b/projection/src/tenant_placement.rs index b4d4f22..8476505 100644 --- a/projection/src/tenant_placement.rs +++ b/projection/src/tenant_placement.rs @@ -23,6 +23,22 @@ pub struct TenantPlacementSnapshot { } impl TenantPlacement { + pub fn with_hosted(hosted: Option>) -> Self { + let hosted = hosted.map(|items| { + items + .into_iter() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect::>() + }); + Self { + inner: Arc::new(RwLock::new(Inner { + hosted, + draining: HashSet::new(), + })), + } + } + pub fn load(settings: &Settings) -> Result { let hosted = hosted_tenants_from_settings(settings)?; Ok(Self { diff --git a/runner/build.rs b/runner/build.rs index 1c9dbee..1d67486 100644 --- a/runner/build.rs +++ b/runner/build.rs @@ -2,7 +2,8 @@ fn main() -> Result<(), Box> { let protoc = protoc_bin_vendored::protoc_bin_path()?; std::env::set_var("PROTOC", protoc); - tonic_build::configure().compile_protos(&["proto/aggregate.proto"], &["proto"])?; + tonic_build::configure() + .compile_protos(&["proto/aggregate.proto", "proto/admin.proto"], &["proto"])?; Ok(()) } diff --git a/runner/proto/admin.proto b/runner/proto/admin.proto new file mode 100644 index 0000000..8896e11 --- /dev/null +++ b/runner/proto/admin.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package runner.admin.v1; + +service RunnerAdmin { + rpc 
Drain(DrainRequest) returns (AdminResponse); + rpc DrainStatus(DrainStatusRequest) returns (AdminResponse); + rpc Reload(ReloadRequest) returns (AdminResponse); +} + +message DrainRequest { + string tenant_id = 1; + uint64 wait_ms = 2; +} + +message DrainStatusRequest { + string tenant_id = 1; +} + +message ReloadRequest {} + +message AdminResponse { + uint32 http_status = 1; + string json = 2; +} diff --git a/runner/src/config/settings.rs b/runner/src/config/settings.rs index 153ae70..5ebd1e5 100644 --- a/runner/src/config/settings.rs +++ b/runner/src/config/settings.rs @@ -45,6 +45,7 @@ pub struct Settings { pub effect_retry_backoff_ms: u64, pub http_addr: String, + pub grpc_addr: String, pub test_saga_crash_after_commit: bool, pub test_effect_crash_after_dedupe_before_ack: bool, @@ -79,8 +80,12 @@ impl Default for Settings { workflow_commands_stream: "WORKFLOW_COMMANDS".to_string(), workflow_events_stream: "WORKFLOW_EVENTS".to_string(), - saga_trigger_subject_filters: vec!["tenant.*.aggregate.*.*".to_string()], - effect_command_subject_filters: vec!["tenant.*.effect.*.*".to_string()], + saga_trigger_subject_filters: vec![ + shared::NATS_SUBJECT_AGGREGATE_EVENTS_ALL.to_string() + ], + effect_command_subject_filters: vec![ + shared::NATS_SUBJECT_EFFECT_COMMANDS_ALL.to_string() + ], consumer_durable_prefix: "runner".to_string(), deliver_group: None, @@ -104,6 +109,7 @@ impl Default for Settings { effect_retry_backoff_ms: 250, http_addr: "0.0.0.0:8080".to_string(), + grpc_addr: "0.0.0.0:9091".to_string(), test_saga_crash_after_commit: false, test_effect_crash_after_dedupe_before_ack: false, @@ -350,6 +356,11 @@ impl Settings { self.http_addr = addr; } } + if let Ok(addr) = std::env::var("RUNNER_GRPC_ADDR") { + if !addr.trim().is_empty() { + self.grpc_addr = addr; + } + } if let Ok(v) = std::env::var("RUNNER_TEST_SAGA_CRASH_AFTER_COMMIT") { self.test_saga_crash_after_commit = @@ -375,6 +386,12 @@ impl Settings { if self.aggregate_events_stream.is_empty() { return 
Err("Aggregate events stream name is required".to_string()); } + if self.http_addr.trim().is_empty() { + return Err("HTTP addr is required".to_string()); + } + if self.grpc_addr.trim().is_empty() { + return Err("gRPC addr is required".to_string()); + } if matches!(self.mode, RunnerMode::Saga | RunnerMode::Combined) && self.saga_trigger_subject_filters.is_empty() { @@ -388,6 +405,9 @@ impl Settings { if self.consumer_durable_prefix.trim().is_empty() { return Err("Consumer durable prefix is required".to_string()); } + if self.deliver_group.is_some() { + return Err("deliver_group is not supported with pull consumers".to_string()); + } if self.max_in_flight == 0 { return Err("Max in-flight must be > 0".to_string()); } @@ -479,6 +499,17 @@ mod tests { std::env::remove_var("RUNNER_TENANT_ALLOWLIST"); } + #[test] + fn deliver_group_is_rejected_with_pull_consumers() { + let settings = Settings { + saga_manifest_path: "runner/config/sagas.yaml".to_string(), + effects_manifest_path: "runner/config/effects.yaml".to_string(), + deliver_group: Some("g1".to_string()), + ..Default::default() + }; + assert!(settings.validate().is_err()); + } + #[test] fn settings_validation_catches_missing_required() { let settings = Settings { diff --git a/runner/src/effects/worker.rs b/runner/src/effects/worker.rs index 4b4e630..2201db2 100644 --- a/runner/src/effects/worker.rs +++ b/runner/src/effects/worker.rs @@ -221,7 +221,7 @@ async fn run_effect_worker_single( .effect_command_subject_filters .first() .cloned() - .unwrap_or_else(|| "tenant.*.effect.*.*".to_string()); + .unwrap_or_else(|| shared::NATS_SUBJECT_EFFECT_COMMANDS_ALL.to_string()); let consumer = jetstream .effect_command_consumer( @@ -326,7 +326,7 @@ async fn run_effect_worker_for_tenant( draining: Arc, ) -> Result<(), RunnerError> { let durable_name = format!("{}_effects_{}", settings.consumer_durable_prefix, tenant); - let filter_subject = format!("tenant.{}.effect.*.*", tenant); + let filter_subject = 
shared::nats_filter_subject_effect_for_tenant(&tenant); let consumer = jetstream .effect_command_consumer( @@ -467,11 +467,7 @@ enum ProcessDecision { } trait EffectResultPublisher: Send + Sync { - fn publish( - &self, - subject: String, - result: EffectResultEnvelope, - ) -> BoxFuture<'static, Result<(), RunnerError>>; + fn publish(&self, result: EffectResultEnvelope) -> BoxFuture<'static, Result<(), RunnerError>>; } #[derive(Clone)] @@ -486,13 +482,9 @@ impl JetStreamPublisher { } impl EffectResultPublisher for JetStreamPublisher { - fn publish( - &self, - subject: String, - result: EffectResultEnvelope, - ) -> BoxFuture<'static, Result<(), RunnerError>> { + fn publish(&self, result: EffectResultEnvelope) -> BoxFuture<'static, Result<(), RunnerError>> { let jetstream = self.jetstream.clone(); - Box::pin(async move { jetstream.publish_effect_result(subject, &result).await }) + Box::pin(async move { jetstream.publish_effect_result(&result).await }) } } @@ -564,14 +556,7 @@ async fn publish_and_mark( result.metadata.trace_id = cmd.metadata.trace_id.clone(); } - let subject = format!( - "tenant.{}.effect_result.{}.{}", - cmd.tenant_id.as_str(), - cmd.effect_name.as_str(), - cmd.command_id.as_str() - ); - - if let Err(e) = publisher.publish(subject, result).await { + if let Err(e) = publisher.publish(result).await { metrics.inc_effect_publish_failed(); return Err(e); } @@ -690,7 +675,6 @@ mod tests { impl EffectResultPublisher for FakePublisher { fn publish( &self, - _subject: String, _result: EffectResultEnvelope, ) -> BoxFuture<'static, Result<(), RunnerError>> { let fail = self.fail; diff --git a/runner/src/gateway/mod.rs b/runner/src/gateway/mod.rs index 4cbd24c..1007ad9 100644 --- a/runner/src/gateway/mod.rs +++ b/runner/src/gateway/mod.rs @@ -1,6 +1,6 @@ -pub const TENANT_ID_METADATA_KEY: &str = "x-tenant-id"; -pub const CORRELATION_ID_METADATA_KEY: &str = "x-correlation-id"; -pub const TRACEPARENT_METADATA_KEY: &str = "traceparent"; +pub const 
TENANT_ID_METADATA_KEY: &str = shared::HEADER_X_TENANT_ID; +pub const CORRELATION_ID_METADATA_KEY: &str = shared::HEADER_X_CORRELATION_ID; +pub const TRACEPARENT_METADATA_KEY: &str = shared::HEADER_TRACEPARENT; pub mod proto { tonic::include_proto!("aggregate.gateway.v1"); @@ -47,7 +47,7 @@ impl GatewayClient { let correlation_id = grpc_request .get_ref() .metadata - .get("x-correlation-id") + .get(shared::HEADER_X_CORRELATION_ID) .or_else(|| grpc_request.get_ref().metadata.get("correlation_id")) .map(|s| s.trim()) .filter(|s| !s.is_empty()) @@ -68,7 +68,7 @@ impl GatewayClient { let traceparent = grpc_request .get_ref() .metadata - .get("traceparent") + .get(shared::HEADER_TRACEPARENT) .map(|s| s.trim()) .filter(|s| !s.is_empty()) .map(|s| s.to_string()) @@ -78,10 +78,8 @@ impl GatewayClient { .metadata .get("trace_id") .map(|s| s.trim()) - .filter(|s| s.len() == 32 && s.chars().all(|c| c.is_ascii_hexdigit())) - .map(|trace_id| { - let span_id = uuid::Uuid::new_v4().simple().to_string()[..16].to_string(); - format!("00-{trace_id}-{span_id}-01") + .and_then(|trace_id| { + shared::traceparent_from_trace_id(&shared::TraceId::new(trace_id)) }) }); if let Some(traceparent) = traceparent { diff --git a/runner/src/grpc_admin.rs b/runner/src/grpc_admin.rs new file mode 100644 index 0000000..c1ff603 --- /dev/null +++ b/runner/src/grpc_admin.rs @@ -0,0 +1,292 @@ +use crate::http::AppState; +use axum::http::StatusCode; +use serde_json::json; +use std::sync::Arc; +use std::time::Duration; + +pub mod proto { + tonic::include_proto!("runner.admin.v1"); +} + +#[derive(Clone)] +pub struct RunnerAdminService { + state: Arc, +} + +impl RunnerAdminService { + pub fn new(state: Arc) -> Self { + Self { state } + } +} + +#[tonic::async_trait] +impl proto::runner_admin_server::RunnerAdmin for RunnerAdminService { + async fn drain( + &self, + request: tonic::Request, + ) -> Result, tonic::Status> { + let md_tenant = request + .metadata() + .get(shared::HEADER_X_TENANT_ID) + .and_then(|v| 
v.to_str().ok()) + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()); + + let req = request.into_inner(); + let tenant_id = req.tenant_id.trim().to_string(); + if tenant_id.is_empty() { + self.state.start_draining(); + return Ok(tonic::Response::new(proto::AdminResponse { + http_status: StatusCode::OK.as_u16() as u32, + json: json!({ "ok": true, "draining": true }).to_string(), + })); + } + + if let Some(md_tenant) = md_tenant.as_deref() { + if md_tenant != tenant_id { + return Err(tonic::Status::permission_denied("tenant mismatch")); + } + } + + self.state.tenant_gate.start_draining(&tenant_id); + let wait_ms = req.wait_ms; + if wait_ms > 0 { + let deadline = tokio::time::Instant::now() + Duration::from_millis(wait_ms); + loop { + let status = tenant_drain_state(&self.state, &tenant_id); + if status.drained { + break; + } + if tokio::time::Instant::now() >= deadline { + break; + } + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + let _ = self + .state + .tenant_gate + .wait_inflight_zero(&tenant_id, remaining.min(Duration::from_millis(250))) + .await; + tokio::time::sleep(Duration::from_millis(25)).await; + } + } + + let resp = tenant_drain_status(&self.state, &tenant_id); + Ok(tonic::Response::new(resp)) + } + + async fn drain_status( + &self, + request: tonic::Request, + ) -> Result, tonic::Status> { + let md_tenant = request + .metadata() + .get(shared::HEADER_X_TENANT_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()); + + let tenant_id = request.into_inner().tenant_id.trim().to_string(); + if tenant_id.is_empty() { + return Ok(tonic::Response::new(proto::AdminResponse { + http_status: StatusCode::BAD_REQUEST.as_u16() as u32, + json: json!({ "ok": false, "error": "tenant_id required" }).to_string(), + })); + } + if let Some(md_tenant) = md_tenant.as_deref() { + if md_tenant != tenant_id { + return Err(tonic::Status::permission_denied("tenant 
mismatch")); + } + } + + Ok(tonic::Response::new(tenant_drain_status( + &self.state, + &tenant_id, + ))) + } + + async fn reload( + &self, + _request: tonic::Request, + ) -> Result, tonic::Status> { + self.state.notify_reload(); + Ok(tonic::Response::new(proto::AdminResponse { + http_status: StatusCode::OK.as_u16() as u32, + json: json!({ "ok": true }).to_string(), + })) + } +} + +fn tenant_drain_status(state: &AppState, tenant_id: &str) -> proto::AdminResponse { + let status = tenant_drain_state(state, tenant_id); + let code = if status.drained { + StatusCode::OK + } else { + StatusCode::ACCEPTED + }; + + proto::AdminResponse { + http_status: code.as_u16() as u32, + json: json!({ + "ok": true, + "tenant_id": tenant_id, + "draining_tenant": state.tenant_gate.is_draining(tenant_id), + "assigned": state.tenant_gate.is_assigned(tenant_id), + "in_flight": status.in_flight, + "outbox_items": status.outbox_items, + "drained": status.drained + }) + .to_string(), + } +} + +struct TenantDrainState { + in_flight: usize, + outbox_items: usize, + drained: bool, +} + +fn tenant_drain_state(state: &AppState, tenant_id: &str) -> TenantDrainState { + let in_flight = state.tenant_gate.inflight_count(tenant_id); + let outbox_items = state + .storage + .list_outbox_prefix(&crate::types::TenantId::new(tenant_id.to_string()), 50_000) + .map(|v| v.len()) + .unwrap_or(0); + TenantDrainState { + in_flight, + outbox_items, + drained: in_flight == 0 && outbox_items == 0, + } +} + +pub async fn serve( + addr: std::net::SocketAddr, + state: Arc, + shutdown: impl std::future::Future + Send + 'static, +) -> Result<(), crate::types::RunnerError> { + tonic::transport::Server::builder() + .add_service(proto::runner_admin_server::RunnerAdminServer::new( + RunnerAdminService::new(state), + )) + .serve_with_shutdown(addr, shutdown) + .await + .map_err(|e| crate::types::RunnerError::StreamError(e.to_string()))?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use 
use crate::tenant_placement::TenantGate;
    use crate::types::{
        CommandId, EffectCommandEnvelope, EffectName, MessageMetadata, TenantId, WorkId, WorkItem,
    };
    use std::sync::atomic::AtomicBool;

    /// Builds an `AppState` wired with default settings and the given storage.
    /// Deduplicates the identical six-argument setup previously repeated in
    /// every test below.
    fn test_state(storage: crate::storage::KvClient) -> Arc<crate::http::AppState> {
        Arc::new(crate::http::AppState::new(
            crate::Settings::default(),
            Arc::new(AtomicBool::new(false)),
            Arc::new(TenantGate::new(None)),
            Arc::new(crate::observability::Metrics::default()),
            storage,
            Arc::new(tokio::sync::Notify::new()),
        ))
    }

    /// A request whose `x-tenant-id` metadata disagrees with the request body
    /// must be rejected with PermissionDenied.
    #[tokio::test]
    async fn rejects_tenant_mismatch() {
        let state = test_state(crate::storage::KvClient::in_memory());
        let svc = RunnerAdminService::new(state);

        let mut req = tonic::Request::new(proto::DrainStatusRequest {
            tenant_id: "tenant-a".to_string(),
        });
        req.metadata_mut().insert(
            shared::HEADER_X_TENANT_ID,
            tonic::metadata::MetadataValue::try_from("tenant-b").unwrap(),
        );

        let err = proto::runner_admin_server::RunnerAdmin::drain_status(&svc, req)
            .await
            .unwrap_err();
        assert_eq!(err.code(), tonic::Code::PermissionDenied);
    }

    /// A pending outbox item keeps the tenant un-drained (HTTP 202 in body).
    #[tokio::test]
    async fn drain_status_reflects_outbox_backlog() {
        let storage = crate::storage::KvClient::in_memory();
        let state = test_state(storage.clone());

        let tenant = TenantId::new("tenant-a");
        let work_id = WorkId::new_v7();
        let item = WorkItem::EffectCommand(EffectCommandEnvelope {
            tenant_id: tenant.clone(),
            command_id: CommandId::new("c1"),
            effect_name: EffectName::new("noop"),
            payload: serde_json::json!({"ok": true}),
            metadata: MessageMetadata::default(),
        });
        let _key = storage
            .put_outbox_item(&tenant, "effect", &work_id, &item)
            .unwrap();

        let svc = RunnerAdminService::new(state);
        let req = tonic::Request::new(proto::DrainStatusRequest {
            tenant_id: tenant.as_str().to_string(),
        });
        let resp = proto::runner_admin_server::RunnerAdmin::drain_status(&svc, req)
            .await
            .unwrap()
            .into_inner();

        assert_eq!(resp.http_status, 202);
        let json: serde_json::Value = serde_json::from_str(&resp.json).unwrap();
        assert_eq!(json["tenant_id"], tenant.as_str());
        assert_eq!(json["drained"], false);
    }

    /// With no backlog, a zero-wait drain reports drained immediately.
    #[tokio::test]
    async fn drain_wait_zero_returns_drained_when_no_backlog() {
        let state = test_state(crate::storage::KvClient::in_memory());
        let svc = RunnerAdminService::new(state);

        let req = tonic::Request::new(proto::DrainRequest {
            tenant_id: "tenant-a".to_string(),
            wait_ms: 0,
        });
        let resp = proto::runner_admin_server::RunnerAdmin::drain(&svc, req)
            .await
            .unwrap()
            .into_inner();

        assert_eq!(resp.http_status, 200);
        let json: serde_json::Value = serde_json::from_str(&resp.json).unwrap();
        assert_eq!(json["tenant_id"], "tenant-a");
        assert_eq!(json["drained"], true);
    }
}
b/runner/src/lib.rs @@ -1,6 +1,7 @@ pub mod config; pub mod effects; pub mod gateway; +pub mod grpc_admin; pub mod http; pub mod observability; pub mod outbox; diff --git a/runner/src/main.rs b/runner/src/main.rs index 648f818..55cec80 100644 --- a/runner/src/main.rs +++ b/runner/src/main.rs @@ -83,6 +83,16 @@ async fn serve() { .await }); + let grpc_addr: std::net::SocketAddr = settings.grpc_addr.parse().unwrap(); + let grpc_shutdown = shutdown.clone(); + let grpc_state = state.clone(); + let grpc_task = tokio::spawn(async move { + runner::grpc_admin::serve(grpc_addr, grpc_state, async move { + grpc_shutdown.notified().await + }) + .await + }); + let signal_shutdown = shutdown.clone(); let signal_draining = draining.clone(); tokio::spawn(async move { @@ -278,6 +288,7 @@ async fn serve() { draining.store(true, Ordering::Relaxed); shutdown.notify_waiters(); let _ = http_task.await; + let _ = grpc_task.await; if let Some(e) = failed { tracing::error!(error = %e, "Runner terminated with error"); diff --git a/runner/src/saga/worker.rs b/runner/src/saga/worker.rs index 85c1ba3..62d4c08 100644 --- a/runner/src/saga/worker.rs +++ b/runner/src/saga/worker.rs @@ -225,7 +225,7 @@ async fn run_saga_worker_single( .saga_trigger_subject_filters .first() .cloned() - .unwrap_or_else(|| "tenant.*.aggregate.*.*".to_string()); + .unwrap_or_else(|| shared::NATS_SUBJECT_AGGREGATE_EVENTS_ALL.to_string()); let consumer = jetstream .saga_trigger_consumer( @@ -296,7 +296,7 @@ async fn run_saga_worker_for_tenant( draining: Arc, ) -> Result<(), RunnerError> { let durable_name = format!("{}_saga_{}", settings.consumer_durable_prefix, tenant); - let filter_subject = format!("tenant.{}.aggregate.*.*", tenant); + let filter_subject = shared::nats_filter_subject_aggregate_for_tenant(&tenant); let consumer = jetstream .saga_trigger_consumer( diff --git a/runner/src/stream/jetstream.rs b/runner/src/stream/jetstream.rs index a956312..34498d0 100644 --- a/runner/src/stream/jetstream.rs +++ 
b/runner/src/stream/jetstream.rs @@ -31,19 +31,15 @@ impl JetStreamClient { let jetstream = jetstream::new(client); - let aggregate_events_subjects = if settings.saga_trigger_subject_filters.is_empty() { - vec!["tenant.*.aggregate.*.*".to_string()] - } else { - settings.saga_trigger_subject_filters.clone() - }; + let aggregate_events_subjects = vec![shared::NATS_SUBJECT_AGGREGATE_EVENTS_ALL.to_string()]; let workflow_commands_subjects = vec![ - "tenant.*.effect.*.*".to_string(), - "tenant.*.workflow.*.*".to_string(), + shared::NATS_SUBJECT_EFFECT_COMMANDS_ALL.to_string(), + shared::NATS_SUBJECT_WORKFLOW_COMMANDS_ALL.to_string(), ]; let workflow_events_subjects = vec![ - "tenant.*.effect_result.*.*".to_string(), - "tenant.*.workflow_event.*.*".to_string(), + shared::NATS_SUBJECT_EFFECT_RESULTS_ALL.to_string(), + shared::NATS_SUBJECT_WORKFLOW_EVENTS_ALL.to_string(), ]; let mut last_err = None; @@ -83,15 +79,20 @@ impl JetStreamClient { settings: &Settings, options: ConsumerOptions, ) -> Result { + let policy = shared::consumer_policy_from_parts( + settings.ack_timeout_ms, + settings.max_in_flight, + settings.max_deliver, + ); let consumer_config = PullConfig { durable_name: Some(options.durable_name.clone()), deliver_policy: options.deliver_policy, ack_policy: AckPolicy::Explicit, - ack_wait: std::time::Duration::from_millis(settings.ack_timeout_ms), + ack_wait: policy.ack_wait, filter_subject: options.filter_subject, replay_policy: ReplayPolicy::Instant, - max_ack_pending: settings.max_in_flight as i64, - max_deliver: settings.max_deliver, + max_ack_pending: policy.max_ack_pending, + max_deliver: policy.max_deliver, ..Default::default() }; @@ -106,15 +107,20 @@ impl JetStreamClient { settings: &Settings, options: ConsumerOptions, ) -> Result { + let policy = shared::consumer_policy_from_parts( + settings.ack_timeout_ms, + settings.max_in_flight, + settings.max_deliver, + ); let consumer_config = PullConfig { durable_name: Some(options.durable_name.clone()), 
deliver_policy: options.deliver_policy, ack_policy: AckPolicy::Explicit, - ack_wait: std::time::Duration::from_millis(settings.ack_timeout_ms), + ack_wait: policy.ack_wait, filter_subject: options.filter_subject, replay_policy: ReplayPolicy::Instant, - max_ack_pending: settings.max_in_flight as i64, - max_deliver: settings.max_deliver, + max_ack_pending: policy.max_ack_pending, + max_deliver: policy.max_deliver, ..Default::default() }; @@ -126,38 +132,16 @@ impl JetStreamClient { pub async fn publish_effect_result( &self, - subject: String, result: &EffectResultEnvelope, ) -> Result<(), RunnerError> { + let subject = shared::nats_subject_effect_result( + result.tenant_id.as_str(), + result.effect_name.as_str(), + result.command_id.as_str(), + ); let payload = serde_json::to_vec(result).map_err(|e| RunnerError::DecodeError(e.to_string()))?; - let mut headers = async_nats::HeaderMap::new(); - headers.insert("tenant-id", result.tenant_id.as_str()); - headers.insert("command-id", result.command_id.as_str()); - headers.insert("effect-name", result.effect_name.as_str()); - if let Some(correlation_id) = result.metadata.correlation_id.as_ref() { - headers.insert("x-correlation-id", correlation_id.as_str()); - headers.insert("correlation-id", correlation_id.as_str()); - } - if let Some(trace_id) = result.metadata.trace_id.as_ref() { - headers.insert("trace-id", trace_id.as_str()); - if let Some(traceparent) = shared::traceparent_from_trace_id(trace_id) { - headers.insert("traceparent", traceparent.as_str()); - } - } - if let Some(traceparent) = result - .metadata - .extra - .get("traceparent") - .and_then(|v| v.as_str()) - { - headers.insert("traceparent", traceparent); - if result.metadata.trace_id.is_none() { - if let Some(trace_id) = shared::trace_id_from_traceparent(traceparent) { - headers.insert("trace-id", trace_id); - } - } - } + let headers = build_effect_result_headers(result); self.jetstream .publish_with_headers(subject, headers, payload.into()) @@ -170,43 
+154,15 @@ impl JetStreamClient { &self, cmd: &EffectCommandEnvelope, ) -> Result<(), RunnerError> { - let subject = format!( - "tenant.{}.effect.{}.{}", + let subject = shared::nats_subject_effect_command( cmd.tenant_id.as_str(), cmd.effect_name.as_str(), - cmd.command_id.as_str() + cmd.command_id.as_str(), ); let payload = serde_json::to_vec(cmd).map_err(|e| RunnerError::DecodeError(e.to_string()))?; - let mut headers = async_nats::HeaderMap::new(); - headers.insert("Nats-Msg-Id", cmd.command_id.as_str()); - headers.insert("tenant-id", cmd.tenant_id.as_str()); - headers.insert("command-id", cmd.command_id.as_str()); - headers.insert("effect-name", cmd.effect_name.as_str()); - if let Some(correlation_id) = cmd.metadata.correlation_id.as_ref() { - headers.insert("x-correlation-id", correlation_id.as_str()); - headers.insert("correlation-id", correlation_id.as_str()); - } - if let Some(trace_id) = cmd.metadata.trace_id.as_ref() { - headers.insert("trace-id", trace_id.as_str()); - if let Some(traceparent) = shared::traceparent_from_trace_id(trace_id) { - headers.insert("traceparent", traceparent.as_str()); - } - } - if let Some(traceparent) = cmd - .metadata - .extra - .get("traceparent") - .and_then(|v| v.as_str()) - { - headers.insert("traceparent", traceparent); - if cmd.metadata.trace_id.is_none() { - if let Some(trace_id) = shared::trace_id_from_traceparent(traceparent) { - headers.insert("trace-id", trace_id); - } - } - } + let headers = build_effect_command_headers(cmd); self.jetstream .publish_with_headers(subject, headers, payload.into()) @@ -216,6 +172,120 @@ impl JetStreamClient { } } +fn build_effect_command_headers(cmd: &EffectCommandEnvelope) -> async_nats::HeaderMap { + let mut headers = async_nats::HeaderMap::new(); + + let effect_name = cmd.effect_name.as_str().to_string(); + + let ctx = shared::nats_context_headers_required( + cmd.tenant_id.as_str(), + Some(cmd.command_id.as_str()), + cmd.metadata.correlation_id.as_ref().map(|v| v.as_str()), + 
cmd.metadata + .extra + .get(shared::HEADER_TRACEPARENT) + .and_then(|v| v.as_str()), + cmd.metadata.trace_id.as_ref().map(|v| v.as_str()), + ); + for (k, v) in ctx { + headers.insert(k, v); + } + + headers.insert("command-id", cmd.command_id.as_str().to_string()); + headers.insert("effect-name", effect_name); + + headers +} + +fn build_effect_result_headers(result: &EffectResultEnvelope) -> async_nats::HeaderMap { + let mut headers = async_nats::HeaderMap::new(); + + let effect_name = result.effect_name.as_str().to_string(); + + let ctx = shared::nats_context_headers_required( + result.tenant_id.as_str(), + Some(result.command_id.as_str()), + result.metadata.correlation_id.as_ref().map(|v| v.as_str()), + result + .metadata + .extra + .get(shared::HEADER_TRACEPARENT) + .and_then(|v| v.as_str()), + result.metadata.trace_id.as_ref().map(|v| v.as_str()), + ); + for (k, v) in ctx { + headers.insert(k, v); + } + + headers.insert("command-id", result.command_id.as_str().to_string()); + headers.insert("effect-name", effect_name); + + headers +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{CommandId, EffectName, MessageMetadata, TenantId}; + use chrono::Utc; + + #[test] + fn effect_command_headers_include_required_context() { + let cmd = EffectCommandEnvelope { + tenant_id: TenantId::new("t1"), + command_id: CommandId::new("c1"), + effect_name: EffectName::new("send_email"), + payload: serde_json::json!({"x": 1}), + metadata: MessageMetadata::default(), + }; + + let headers = build_effect_command_headers(&cmd); + assert!(headers.get(shared::NATS_HEADER_TENANT_ID).is_some()); + assert!(headers.get(shared::NATS_HEADER_NATS_MSG_ID).is_some()); + assert!(headers.get(shared::HEADER_X_CORRELATION_ID).is_some()); + assert!(headers.get(shared::NATS_HEADER_CORRELATION_ID).is_some()); + assert!(headers.get(shared::HEADER_TRACEPARENT).is_some()); + assert!(headers.get(shared::HEADER_TRACE_ID).is_some()); + } + + #[test] + fn 
effect_result_headers_include_required_context() { + let result = EffectResultEnvelope { + tenant_id: TenantId::new("t1"), + command_id: CommandId::new("c1"), + effect_name: EffectName::new("send_email"), + result_type: crate::types::EffectResultType::Succeeded, + payload: serde_json::json!({"ok": true}), + timestamp: Utc::now(), + metadata: MessageMetadata::default(), + }; + + let headers = build_effect_result_headers(&result); + assert!(headers.get(shared::NATS_HEADER_TENANT_ID).is_some()); + assert!(headers.get(shared::NATS_HEADER_NATS_MSG_ID).is_some()); + assert!(headers.get(shared::HEADER_X_CORRELATION_ID).is_some()); + assert!(headers.get(shared::NATS_HEADER_CORRELATION_ID).is_some()); + assert!(headers.get(shared::HEADER_TRACEPARENT).is_some()); + assert!(headers.get(shared::HEADER_TRACE_ID).is_some()); + } + + #[test] + fn stream_config_validation_allows_subject_superset() { + let expected = stream_policy_config("S", vec!["a".to_string(), "b".to_string()]); + let mut actual = expected.clone(); + actual.subjects.push("c".to_string()); + validate_stream_config(&expected, &actual).unwrap(); + } + + #[test] + fn stream_config_validation_rejects_missing_subject() { + let expected = stream_policy_config("S", vec!["a".to_string(), "b".to_string()]); + let mut actual = expected.clone(); + actual.subjects.retain(|s| s != "b"); + assert!(validate_stream_config(&expected, &actual).is_err()); + } +} + async fn try_init_streams( jetstream: &jetstream::Context, settings: &Settings, @@ -263,18 +333,54 @@ async fn ensure_stream( name: &str, subjects: Vec, ) -> Result { - let config = StreamConfig { - name: name.to_string(), - subjects, - max_messages: 10_000_000, - max_bytes: -1, - max_age: std::time::Duration::from_secs(365 * 24 * 60 * 60), - duplicate_window: std::time::Duration::from_secs(120), - ..Default::default() - }; - jetstream - .get_or_create_stream(config) + let expected = stream_policy_config(name, subjects); + let mut stream = jetstream + 
.get_or_create_stream(expected.clone()) .await + .map_err(|e| StreamInitError::Stream(e.to_string()))?; + + let info = stream + .info() + .await + .map_err(|e| StreamInitError::Stream(e.to_string()))?; + validate_stream_config(&expected, &info.config)?; + Ok(stream) +} + +fn stream_policy_config(name: &str, subjects: Vec) -> StreamConfig { + let policy = shared::stream_policy_defaults(name.to_string(), subjects); + StreamConfig { + name: policy.name, + subjects: policy.subjects, + max_messages: policy.max_messages, + max_bytes: policy.max_bytes, + max_age: policy.max_age, + duplicate_window: policy.duplicate_window, + ..Default::default() + } +} + +fn validate_stream_config( + expected: &StreamConfig, + actual: &StreamConfig, +) -> Result<(), StreamInitError> { + let expected = shared::stream_policy_from_parts( + expected.name.as_str(), + expected.subjects.clone(), + expected.max_messages, + expected.max_bytes, + expected.max_age, + expected.duplicate_window, + ); + let actual = shared::stream_policy_from_parts( + actual.name.as_str(), + actual.subjects.clone(), + actual.max_messages, + actual.max_bytes, + actual.max_age, + actual.duplicate_window, + ); + shared::validate_stream_policy(&expected, &actual) .map_err(|e| StreamInitError::Stream(e.to_string())) } diff --git a/runner/tests/jetstream_integration.rs b/runner/tests/jetstream_integration.rs index dad4fd1..ef997d1 100644 --- a/runner/tests/jetstream_integration.rs +++ b/runner/tests/jetstream_integration.rs @@ -44,9 +44,6 @@ fn jetstream_connects_and_can_publish_effect_result() { metadata: MessageMetadata::default(), }; - runner_js - .publish_effect_result("tenant.t1.effect_result.noop.c1".to_string(), &result) - .await - .unwrap(); + runner_js.publish_effect_result(&result).await.unwrap(); }); } diff --git a/shared/src/lib.rs b/shared/src/lib.rs index 0dea649..970b66c 100644 --- a/shared/src/lib.rs +++ b/shared/src/lib.rs @@ -1,12 +1,112 @@ use serde::{Deserialize, Serialize}; +use 
/// Canonical HTTP header names propagated across services.
pub const HEADER_X_CORRELATION_ID: &str = "x-correlation-id";
pub const HEADER_X_TENANT_ID: &str = "x-tenant-id";
pub const HEADER_X_REQUEST_ID: &str = "x-request-id";
pub const HEADER_TRACEPARENT: &str = "traceparent";
pub const HEADER_TRACE_ID: &str = "trace-id";

/// Canonical NATS header names (note: `Nats-Msg-Id` casing is significant to
/// JetStream deduplication).
pub const NATS_HEADER_CORRELATION_ID: &str = "correlation-id";
pub const NATS_HEADER_TENANT_ID: &str = "tenant-id";
pub const NATS_HEADER_NATS_MSG_ID: &str = "Nats-Msg-Id";

/// Wildcard subject filters covering all tenants, by message class.
pub const NATS_SUBJECT_AGGREGATE_EVENTS_ALL: &str = "tenant.*.aggregate.*.*";
pub const NATS_SUBJECT_EFFECT_COMMANDS_ALL: &str = "tenant.*.effect.*.*";
pub const NATS_SUBJECT_WORKFLOW_COMMANDS_ALL: &str = "tenant.*.workflow.*.*";
pub const NATS_SUBJECT_EFFECT_RESULTS_ALL: &str = "tenant.*.effect_result.*.*";
pub const NATS_SUBJECT_WORKFLOW_EVENTS_ALL: &str = "tenant.*.workflow_event.*.*";

/// Subject for a single aggregate event:
/// `tenant.<tenant>.aggregate.<type>.<id>`.
pub fn nats_subject_aggregate_event(
    tenant_id: &str,
    aggregate_type: &str,
    aggregate_id: &str,
) -> String {
    format!(
        "tenant.{}.aggregate.{}.{}",
        tenant_id, aggregate_type, aggregate_id
    )
}

/// Subject for a single effect command:
/// `tenant.<tenant>.effect.<effect>.<command>`.
pub fn nats_subject_effect_command(tenant_id: &str, effect_name: &str, command_id: &str) -> String {
    format!("tenant.{}.effect.{}.{}", tenant_id, effect_name, command_id)
}

/// Subject for a single effect result:
/// `tenant.<tenant>.effect_result.<effect>.<command>`.
pub fn nats_subject_effect_result(tenant_id: &str, effect_name: &str, command_id: &str) -> String {
    format!(
        "tenant.{}.effect_result.{}.{}",
        tenant_id, effect_name, command_id
    )
}

/// Subject for a single workflow command:
/// `tenant.<tenant>.workflow.<workflow>.<command>`.
pub fn nats_subject_workflow_command(
    tenant_id: &str,
    workflow_name: &str,
    command_id: &str,
) -> String {
    format!(
        "tenant.{}.workflow.{}.{}",
        tenant_id, workflow_name, command_id
    )
}

/// Subject for a single workflow event:
/// `tenant.<tenant>.workflow_event.<workflow>.<event>`.
pub fn nats_subject_workflow_event(tenant_id: &str, workflow_name: &str, event_id: &str) -> String {
    format!(
        "tenant.{}.workflow_event.{}.{}",
        tenant_id, workflow_name, event_id
    )
}

/// Per-tenant filter subject for aggregate events: `tenant.<t>.aggregate.*.*`.
pub fn nats_filter_subject_aggregate_for_tenant(tenant_id: &str) -> String {
    format!("tenant.{}.aggregate.*.*", tenant_id)
}
+ +pub fn nats_filter_subject_effect_for_tenant(tenant_id: &str) -> String { + format!("tenant.{tenant_id}.effect.*.*") +} + +pub fn nats_context_headers_required( + tenant_id: &str, + msg_id: Option<&str>, + correlation_id: Option<&str>, + traceparent: Option<&str>, + trace_id: Option<&str>, +) -> BTreeMap { + let mut out = BTreeMap::new(); + + out.insert(NATS_HEADER_TENANT_ID.to_string(), tenant_id.to_string()); + if let Some(msg_id) = msg_id { + let msg_id = msg_id.trim(); + if !msg_id.is_empty() { + out.insert(NATS_HEADER_NATS_MSG_ID.to_string(), msg_id.to_string()); + } + } + + let correlation_id = normalize_correlation_id(correlation_id).to_string(); + out.insert(HEADER_X_CORRELATION_ID.to_string(), correlation_id.clone()); + out.insert(NATS_HEADER_CORRELATION_ID.to_string(), correlation_id); + + let mut traceparent = traceparent + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(|tp| normalize_traceparent(Some(tp))) + .or_else(|| { + trace_id + .and_then(|tid| traceparent_from_trace_id(&TraceId::new(tid))) + .and_then(|tp| { + if trace_id_from_traceparent(&tp).is_some() { + Some(tp) + } else { + None + } + }) + }) + .unwrap_or_else(generate_traceparent); + + let trace_id = match trace_id_from_traceparent(&traceparent) { + Some(v) => v.to_string(), + None => { + traceparent = generate_traceparent(); + trace_id_from_traceparent(&traceparent).unwrap().to_string() + } + }; + + out.insert(HEADER_TRACEPARENT.to_string(), traceparent); + out.insert(HEADER_TRACE_ID.to_string(), trace_id); + + out +} #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)] pub struct TenantId(String); @@ -121,12 +221,164 @@ impl AsRef for TraceId { } } +pub fn normalize_correlation_id(value: Option<&str>) -> CorrelationId { + value + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(CorrelationId::new) + .unwrap_or_else(CorrelationId::generate) +} + +pub fn generate_traceparent() -> String { + let trace_id = Uuid::new_v4().simple().to_string(); + let 
span_id = Uuid::new_v4().simple().to_string()[..16].to_string(); + format!("00-{trace_id}-{span_id}-01") +} + +pub fn normalize_traceparent(value: Option<&str>) -> String { + value + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .and_then(|s| { + if trace_id_from_traceparent(s).is_some() { + Some(s.to_string()) + } else { + None + } + }) + .unwrap_or_else(generate_traceparent) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ConsumerPolicy { + pub ack_wait: Duration, + pub max_ack_pending: i64, + pub max_deliver: i64, +} + +pub fn consumer_policy_from_parts( + ack_timeout_ms: u64, + max_in_flight: usize, + max_deliver: i64, +) -> ConsumerPolicy { + ConsumerPolicy { + ack_wait: Duration::from_millis(ack_timeout_ms.max(1)), + max_ack_pending: max_in_flight.max(1) as i64, + max_deliver: max_deliver.max(1), + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamPolicy { + pub name: String, + pub subjects: Vec, + pub max_messages: i64, + pub max_bytes: i64, + pub max_age: Duration, + pub duplicate_window: Duration, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamPolicyMismatch(String); + +impl fmt::Display for StreamPolicyMismatch { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::error::Error for StreamPolicyMismatch {} + +pub fn stream_policy_defaults(name: impl Into, subjects: Vec) -> StreamPolicy { + StreamPolicy { + name: name.into(), + subjects, + max_messages: 10_000_000, + max_bytes: -1, + max_age: Duration::from_secs(365 * 24 * 60 * 60), + duplicate_window: Duration::from_secs(120), + } +} + +pub fn stream_policy_from_parts( + name: &str, + subjects: Vec, + max_messages: i64, + max_bytes: i64, + max_age: Duration, + duplicate_window: Duration, +) -> StreamPolicy { + StreamPolicy { + name: name.to_string(), + subjects, + max_messages, + max_bytes, + max_age, + duplicate_window, + } +} + +pub fn validate_stream_policy( + expected: &StreamPolicy, + actual: 
&StreamPolicy, +) -> Result<(), StreamPolicyMismatch> { + if expected.name != actual.name { + return Err(StreamPolicyMismatch(format!( + "stream config mismatch: name expected={} actual={}", + expected.name, actual.name + ))); + } + + for subject in expected.subjects.iter() { + if !actual.subjects.iter().any(|s| s == subject) { + return Err(StreamPolicyMismatch(format!( + "stream config mismatch: missing subject {}", + subject + ))); + } + } + + fn gte_or_unlimited(actual: i64, expected: i64) -> bool { + actual == -1 || actual >= expected + } + + if !gte_or_unlimited(actual.max_messages, expected.max_messages) { + return Err(StreamPolicyMismatch(format!( + "stream config mismatch: max_messages expected>={} actual={}", + expected.max_messages, actual.max_messages + ))); + } + if !gte_or_unlimited(actual.max_bytes, expected.max_bytes) { + return Err(StreamPolicyMismatch(format!( + "stream config mismatch: max_bytes expected>={} actual={}", + expected.max_bytes, actual.max_bytes + ))); + } + if actual.max_age < expected.max_age { + return Err(StreamPolicyMismatch(format!( + "stream config mismatch: max_age expected>={:?} actual={:?}", + expected.max_age, actual.max_age + ))); + } + if actual.duplicate_window < expected.duplicate_window { + return Err(StreamPolicyMismatch(format!( + "stream config mismatch: duplicate_window expected>={:?} actual={:?}", + expected.duplicate_window, actual.duplicate_window + ))); + } + + Ok(()) +} + pub fn trace_id_from_traceparent(traceparent: &str) -> Option<&str> { let mut parts = traceparent.split('-'); let version = parts.next()?; let trace_id = parts.next()?; let span_id = parts.next()?; let flags = parts.next()?; + if parts.next().is_some() { + return None; + } if version.len() != 2 || trace_id.len() != 32 || span_id.len() != 16 || flags.len() != 2 { return None; } @@ -137,6 +389,9 @@ pub fn trace_id_from_traceparent(traceparent: &str) -> Option<&str> { { return None; } + if is_all_zeros(trace_id) || is_all_zeros(span_id) { + 
return None; + } Some(trace_id) } @@ -152,6 +407,10 @@ fn is_valid_hex_32(s: &str) -> bool { s.len() == 32 && s.chars().all(|c| c.is_ascii_hexdigit()) } +fn is_all_zeros(s: &str) -> bool { + s.chars().all(|c| c == '0') +} + #[cfg(test)] mod tests { use super::*; @@ -193,4 +452,96 @@ mod tests { Some("0123456789abcdef0123456789abcdef") ); } + + #[test] + fn trace_id_from_traceparent_rejects_extra_parts() { + let tp = "00-0123456789abcdef0123456789abcdef-1111111111111111-01-extra"; + assert_eq!(trace_id_from_traceparent(tp), None); + } + + #[test] + fn trace_id_from_traceparent_rejects_all_zero_ids() { + let tp = "00-00000000000000000000000000000000-1111111111111111-01"; + assert_eq!(trace_id_from_traceparent(tp), None); + + let tp = "00-0123456789abcdef0123456789abcdef-0000000000000000-01"; + assert_eq!(trace_id_from_traceparent(tp), None); + } + + #[test] + fn normalize_correlation_id_generates_when_missing_or_empty() { + let a = normalize_correlation_id(None); + let b = normalize_correlation_id(Some("")); + assert!(!a.as_str().is_empty()); + assert!(!b.as_str().is_empty()); + assert_ne!(a.as_str(), b.as_str()); + } + + #[test] + fn normalize_traceparent_accepts_valid_else_generates() { + let valid = "00-0123456789abcdef0123456789abcdef-1111111111111111-01"; + assert_eq!(normalize_traceparent(Some(valid)), valid.to_string()); + + let generated = normalize_traceparent(Some("not-a-traceparent")); + assert!(trace_id_from_traceparent(&generated).is_some()); + } + + #[test] + fn nats_subject_builders_are_stable() { + assert_eq!( + nats_subject_aggregate_event("t1", "Account", "a1"), + "tenant.t1.aggregate.Account.a1" + ); + assert_eq!( + nats_subject_effect_command("t1", "send_email", "c1"), + "tenant.t1.effect.send_email.c1" + ); + assert_eq!( + nats_subject_effect_result("t1", "send_email", "c1"), + "tenant.t1.effect_result.send_email.c1" + ); + assert_eq!( + nats_subject_workflow_command("t1", "wf", "c1"), + "tenant.t1.workflow.wf.c1" + ); + assert_eq!( + 
nats_subject_workflow_event("t1", "wf", "e1"), + "tenant.t1.workflow_event.wf.e1" + ); + assert_eq!( + nats_filter_subject_aggregate_for_tenant("t1"), + "tenant.t1.aggregate.*.*" + ); + assert_eq!( + nats_filter_subject_effect_for_tenant("t1"), + "tenant.t1.effect.*.*" + ); + } + + #[test] + fn nats_context_headers_required_generates_missing_context() { + let headers = nats_context_headers_required("t1", Some("m1"), None, None, None); + assert_eq!(headers.get(NATS_HEADER_TENANT_ID).unwrap(), "t1"); + assert_eq!(headers.get(NATS_HEADER_NATS_MSG_ID).unwrap(), "m1"); + assert!(!headers.get(HEADER_X_CORRELATION_ID).unwrap().is_empty()); + assert!(!headers.get(NATS_HEADER_CORRELATION_ID).unwrap().is_empty()); + assert!(trace_id_from_traceparent(headers.get(HEADER_TRACEPARENT).unwrap()).is_some()); + assert!(headers.get(HEADER_TRACE_ID).unwrap().len() == 32); + } + + #[test] + fn validate_stream_policy_allows_subject_superset() { + let expected = stream_policy_defaults("S", vec!["a".to_string(), "b".to_string()]); + let mut actual = expected.clone(); + actual.subjects.push("c".to_string()); + validate_stream_policy(&expected, &actual).unwrap(); + } + + #[test] + fn validate_stream_policy_rejects_missing_subject() { + let expected = stream_policy_defaults("S", vec!["a".to_string(), "b".to_string()]); + let mut actual = expected.clone(); + actual.subjects.retain(|s| s != "b"); + assert!(validate_stream_policy(&expected, &actual).is_err()); + } }