use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; use std::fmt; use std::str::FromStr; use std::time::Duration; use uuid::Uuid; pub const HEADER_X_CORRELATION_ID: &str = "x-correlation-id"; pub const HEADER_X_TENANT_ID: &str = "x-tenant-id"; pub const HEADER_X_REQUEST_ID: &str = "x-request-id"; pub const HEADER_TRACEPARENT: &str = "traceparent"; pub const HEADER_TRACE_ID: &str = "trace-id"; pub const NATS_HEADER_CORRELATION_ID: &str = "correlation-id"; pub const NATS_HEADER_TENANT_ID: &str = "tenant-id"; pub const NATS_HEADER_NATS_MSG_ID: &str = "Nats-Msg-Id"; pub const NATS_SUBJECT_AGGREGATE_EVENTS_ALL: &str = "tenant.*.aggregate.*.*"; pub const NATS_SUBJECT_EFFECT_COMMANDS_ALL: &str = "tenant.*.effect.*.*"; pub const NATS_SUBJECT_WORKFLOW_COMMANDS_ALL: &str = "tenant.*.workflow.*.*"; pub const NATS_SUBJECT_EFFECT_RESULTS_ALL: &str = "tenant.*.effect_result.*.*"; pub const NATS_SUBJECT_WORKFLOW_EVENTS_ALL: &str = "tenant.*.workflow_event.*.*"; pub fn nats_subject_aggregate_event( tenant_id: &str, aggregate_type: &str, aggregate_id: &str, ) -> String { format!("tenant.{tenant_id}.aggregate.{aggregate_type}.{aggregate_id}") } pub fn nats_subject_effect_command(tenant_id: &str, effect_name: &str, command_id: &str) -> String { format!("tenant.{tenant_id}.effect.{effect_name}.{command_id}") } pub fn nats_subject_effect_result(tenant_id: &str, effect_name: &str, command_id: &str) -> String { format!("tenant.{tenant_id}.effect_result.{effect_name}.{command_id}") } pub fn nats_subject_workflow_command( tenant_id: &str, workflow_name: &str, command_id: &str, ) -> String { format!("tenant.{tenant_id}.workflow.{workflow_name}.{command_id}") } pub fn nats_subject_workflow_event(tenant_id: &str, workflow_name: &str, event_id: &str) -> String { format!("tenant.{tenant_id}.workflow_event.{workflow_name}.{event_id}") } pub fn nats_filter_subject_aggregate_for_tenant(tenant_id: &str) -> String { format!("tenant.{tenant_id}.aggregate.*.*") } pub fn nats_filter_subject_effect_for_tenant(tenant_id: &str) -> String { format!("tenant.{tenant_id}.effect.*.*") } pub fn nats_context_headers_required( tenant_id: &str, msg_id: Option<&str>, correlation_id: Option<&str>, traceparent: Option<&str>, trace_id: Option<&str>, ) -> BTreeMap { let mut out = BTreeMap::new(); out.insert(NATS_HEADER_TENANT_ID.to_string(), tenant_id.to_string()); if let Some(msg_id) = msg_id { let msg_id = msg_id.trim(); if !msg_id.is_empty() { out.insert(NATS_HEADER_NATS_MSG_ID.to_string(), msg_id.to_string()); } } let correlation_id = normalize_correlation_id(correlation_id).to_string(); out.insert(HEADER_X_CORRELATION_ID.to_string(), correlation_id.clone()); out.insert(NATS_HEADER_CORRELATION_ID.to_string(), correlation_id); let mut traceparent = traceparent .map(|s| s.trim()) .filter(|s| !s.is_empty()) .map(|tp| normalize_traceparent(Some(tp))) .or_else(|| { trace_id .and_then(|tid| traceparent_from_trace_id(&TraceId::new(tid))) .and_then(|tp| { if trace_id_from_traceparent(&tp).is_some() { Some(tp) } else { None } }) }) .unwrap_or_else(generate_traceparent); let trace_id = match trace_id_from_traceparent(&traceparent) { Some(v) => v.to_string(), None => { traceparent = generate_traceparent(); trace_id_from_traceparent(&traceparent).unwrap().to_string() } }; out.insert(HEADER_TRACEPARENT.to_string(), traceparent); out.insert(HEADER_TRACE_ID.to_string(), trace_id); out } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)] pub struct TenantId(String); impl TenantId { pub fn new(id: impl Into) -> Self { Self(id.into()) } pub fn is_empty(&self) -> bool { self.0.is_empty() } pub fn as_str(&self) -> &str { &self.0 } } impl fmt::Display for TenantId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } impl FromStr for TenantId { type Err = std::convert::Infallible; fn from_str(s: &str) -> Result { Ok(Self(s.to_string())) } } impl AsRef for TenantId { fn as_ref(&self) -> &str { &self.0 } } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(transparent)] pub struct CorrelationId(String); impl CorrelationId { pub fn new(id: impl Into) -> Self { Self(id.into()) } pub fn generate() -> Self { Self(Uuid::new_v4().to_string()) } pub fn as_str(&self) -> &str { &self.0 } } impl fmt::Display for CorrelationId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } impl FromStr for CorrelationId { type Err = std::convert::Infallible; fn from_str(s: &str) -> Result { Ok(Self(s.to_string())) } } impl AsRef for CorrelationId { fn as_ref(&self) -> &str { &self.0 } } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(transparent)] pub struct TraceId(String); impl TraceId { pub fn new(id: impl Into) -> Self { Self(id.into()) } pub fn as_str(&self) -> &str { &self.0 } pub fn is_valid_hex_32(&self) -> bool { is_valid_hex_32(self.as_str()) } } impl fmt::Display for TraceId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } impl FromStr for TraceId { type Err = std::convert::Infallible; fn from_str(s: &str) -> Result { Ok(Self(s.to_string())) } } impl AsRef for TraceId { fn as_ref(&self) -> &str { &self.0 } } pub fn normalize_correlation_id(value: Option<&str>) -> CorrelationId { value .map(|s| s.trim()) .filter(|s| !s.is_empty()) .map(CorrelationId::new) .unwrap_or_else(CorrelationId::generate) } pub fn generate_traceparent() -> String { let trace_id = Uuid::new_v4().simple().to_string(); let span_id = Uuid::new_v4().simple().to_string()[..16].to_string(); format!("00-{trace_id}-{span_id}-01") } pub fn normalize_traceparent(value: Option<&str>) -> String { value .map(|s| s.trim()) .filter(|s| !s.is_empty()) .and_then(|s| { if trace_id_from_traceparent(s).is_some() { Some(s.to_string()) } else { None } }) .unwrap_or_else(generate_traceparent) } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ConsumerPolicy { pub ack_wait: Duration, pub max_ack_pending: i64, pub max_deliver: i64, } pub fn consumer_policy_from_parts( ack_timeout_ms: u64, max_in_flight: usize, max_deliver: i64, ) -> ConsumerPolicy { ConsumerPolicy { ack_wait: Duration::from_millis(ack_timeout_ms.max(1)), max_ack_pending: max_in_flight.max(1) as i64, max_deliver: max_deliver.max(1), } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct StreamPolicy { pub name: String, pub subjects: Vec, pub max_messages: i64, pub max_bytes: i64, pub max_age: Duration, pub duplicate_window: Duration, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct StreamPolicyMismatch(String); impl fmt::Display for StreamPolicyMismatch { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } impl std::error::Error for StreamPolicyMismatch {} pub fn stream_policy_defaults(name: impl Into, subjects: Vec) -> StreamPolicy { StreamPolicy { name: name.into(), subjects, max_messages: 10_000_000, max_bytes: -1, max_age: Duration::from_secs(365 * 24 * 60 * 60), duplicate_window: Duration::from_secs(120), } } pub fn stream_policy_from_parts( name: &str, subjects: Vec, max_messages: i64, max_bytes: i64, max_age: Duration, duplicate_window: Duration, ) -> StreamPolicy { StreamPolicy { name: name.to_string(), subjects, max_messages, max_bytes, max_age, duplicate_window, } } pub fn validate_stream_policy( expected: &StreamPolicy, actual: &StreamPolicy, ) -> Result<(), StreamPolicyMismatch> { if expected.name != actual.name { return Err(StreamPolicyMismatch(format!( "stream config mismatch: name expected={} actual={}", expected.name, actual.name ))); } for subject in expected.subjects.iter() { if !actual.subjects.iter().any(|s| s == subject) { return Err(StreamPolicyMismatch(format!( "stream config mismatch: missing subject {}", subject ))); } } fn gte_or_unlimited(actual: i64, expected: i64) -> bool { actual == -1 || actual >= expected } if !gte_or_unlimited(actual.max_messages, expected.max_messages) { return Err(StreamPolicyMismatch(format!( "stream config mismatch: max_messages expected>={} actual={}", expected.max_messages, actual.max_messages ))); } if !gte_or_unlimited(actual.max_bytes, expected.max_bytes) { return Err(StreamPolicyMismatch(format!( "stream config mismatch: max_bytes expected>={} actual={}", expected.max_bytes, actual.max_bytes ))); } if actual.max_age < expected.max_age { return Err(StreamPolicyMismatch(format!( "stream config mismatch: max_age expected>={:?} actual={:?}", expected.max_age, actual.max_age ))); } if actual.duplicate_window < expected.duplicate_window { return Err(StreamPolicyMismatch(format!( "stream config mismatch: duplicate_window expected>={:?} actual={:?}", expected.duplicate_window, actual.duplicate_window ))); } Ok(()) } pub fn trace_id_from_traceparent(traceparent: &str) -> Option<&str> { let mut parts = traceparent.split('-'); let version = parts.next()?; let trace_id = parts.next()?; let span_id = parts.next()?; let flags = parts.next()?; if parts.next().is_some() { return None; } if version.len() != 2 || trace_id.len() != 32 || span_id.len() != 16 || flags.len() != 2 { return None; } if !trace_id.chars().all(|c| c.is_ascii_hexdigit()) || !span_id.chars().all(|c| c.is_ascii_hexdigit()) || !flags.chars().all(|c| c.is_ascii_hexdigit()) || !version.chars().all(|c| c.is_ascii_hexdigit()) { return None; } if is_all_zeros(trace_id) || is_all_zeros(span_id) { return None; } Some(trace_id) } pub fn traceparent_from_trace_id(trace_id: &TraceId) -> Option { if !trace_id.is_valid_hex_32() { return None; } let span_id = Uuid::new_v4().simple().to_string()[..16].to_string(); Some(format!("00-{}-{span_id}-01", trace_id.as_str())) } fn is_valid_hex_32(s: &str) -> bool { s.len() == 32 && s.chars().all(|c| c.is_ascii_hexdigit()) } fn is_all_zeros(s: &str) -> bool { s.chars().all(|c| c == '0') } #[cfg(test)] mod tests { use super::*; #[test] fn tenant_id_serialization_roundtrip() { let id = TenantId::new("acme-corp"); let json = serde_json::to_string(&id).unwrap(); let decoded: TenantId = serde_json::from_str(&json).unwrap(); assert_eq!(id, decoded); } #[test] fn tenant_id_default_is_empty() { let id = TenantId::default(); assert!(id.is_empty()); } #[test] fn tenant_id_is_send_sync() { fn assert_send_sync() {} assert_send_sync::(); } #[test] fn correlation_id_roundtrip_is_string() { let id = CorrelationId::new("corr-1"); let json = serde_json::to_string(&id).unwrap(); assert_eq!(json, "\"corr-1\""); let decoded: CorrelationId = serde_json::from_str(&json).unwrap(); assert_eq!(decoded.as_str(), "corr-1"); } #[test] fn trace_id_from_traceparent_parses() { let tp = "00-0123456789abcdef0123456789abcdef-1111111111111111-01"; assert_eq!( trace_id_from_traceparent(tp), Some("0123456789abcdef0123456789abcdef") ); } #[test] fn trace_id_from_traceparent_rejects_extra_parts() { let tp = "00-0123456789abcdef0123456789abcdef-1111111111111111-01-extra"; assert_eq!(trace_id_from_traceparent(tp), None); } #[test] fn trace_id_from_traceparent_rejects_all_zero_ids() { let tp = "00-00000000000000000000000000000000-1111111111111111-01"; assert_eq!(trace_id_from_traceparent(tp), None); let tp = "00-0123456789abcdef0123456789abcdef-0000000000000000-01"; assert_eq!(trace_id_from_traceparent(tp), None); } #[test] fn normalize_correlation_id_generates_when_missing_or_empty() { let a = normalize_correlation_id(None); let b = normalize_correlation_id(Some("")); assert!(!a.as_str().is_empty()); assert!(!b.as_str().is_empty()); assert_ne!(a.as_str(), b.as_str()); } #[test] fn normalize_traceparent_accepts_valid_else_generates() { let valid = "00-0123456789abcdef0123456789abcdef-1111111111111111-01"; assert_eq!(normalize_traceparent(Some(valid)), valid.to_string()); let generated = normalize_traceparent(Some("not-a-traceparent")); assert!(trace_id_from_traceparent(&generated).is_some()); } #[test] fn nats_subject_builders_are_stable() { assert_eq!( nats_subject_aggregate_event("t1", "Account", "a1"), "tenant.t1.aggregate.Account.a1" ); assert_eq!( nats_subject_effect_command("t1", "send_email", "c1"), "tenant.t1.effect.send_email.c1" ); assert_eq!( nats_subject_effect_result("t1", "send_email", "c1"), "tenant.t1.effect_result.send_email.c1" ); assert_eq!( nats_subject_workflow_command("t1", "wf", "c1"), "tenant.t1.workflow.wf.c1" ); assert_eq!( nats_subject_workflow_event("t1", "wf", "e1"), "tenant.t1.workflow_event.wf.e1" ); assert_eq!( nats_filter_subject_aggregate_for_tenant("t1"), "tenant.t1.aggregate.*.*" ); assert_eq!( nats_filter_subject_effect_for_tenant("t1"), "tenant.t1.effect.*.*" ); } #[test] fn nats_context_headers_required_generates_missing_context() { let headers = nats_context_headers_required("t1", Some("m1"), None, None, None); assert_eq!(headers.get(NATS_HEADER_TENANT_ID).unwrap(), "t1"); assert_eq!(headers.get(NATS_HEADER_NATS_MSG_ID).unwrap(), "m1"); assert!(!headers.get(HEADER_X_CORRELATION_ID).unwrap().is_empty()); assert!(!headers.get(NATS_HEADER_CORRELATION_ID).unwrap().is_empty()); assert!(trace_id_from_traceparent(headers.get(HEADER_TRACEPARENT).unwrap()).is_some()); assert!(headers.get(HEADER_TRACE_ID).unwrap().len() == 32); } #[test] fn validate_stream_policy_allows_subject_superset() { let expected = stream_policy_defaults("S", vec!["a".to_string(), "b".to_string()]); let mut actual = expected.clone(); actual.subjects.push("c".to_string()); validate_stream_policy(&expected, &actual).unwrap(); } #[test] fn validate_stream_policy_rejects_missing_subject() { let expected = stream_policy_defaults("S", vec!["a".to_string(), "b".to_string()]); let mut actual = expected.clone(); actual.subjects.retain(|s| s != "b"); assert!(validate_stream_policy(&expected, &actual).is_err()); } }