use axum::{ body::Body, extract::{Request, State, ws::WebSocketUpgrade}, http::{StatusCode, HeaderMap}, response::Response, routing::get, Router, }; use axum::extract::ws::{Message, WebSocket}; use tokio_tungstenite::{ tungstenite::protocol::Message as TungsteniteMessage, }; use futures::{SinkExt, StreamExt}; use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::RwLock; use tracing::{error, info, debug}; use moka::future::Cache; use common::{init_pool, ProjectContext}; use sqlx::PgPool; #[derive(Clone, Debug)] struct Upstream { name: String, url: String, healthy: Arc>, } impl Upstream { fn new(name: String, url: String) -> Self { Self { name, url, healthy: Arc::new(RwLock::new(true)), } } } #[derive(Clone)] struct ProxyState { control_upstream: Upstream, worker_upstreams: Arc>>, current_worker_index: Arc>, http_client: reqwest::Client, control_db: PgPool, project_cache: Cache, } impl ProxyState { fn new(control_url: String, worker_urls: Vec, control_db: PgPool) -> Self { let worker_upstreams = worker_urls .into_iter() .map(|url| Upstream::new(format!("worker-{}", url), url)) .collect(); // Create pooled HTTP client once at startup (1.1.4) let http_client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(30)) .pool_max_idle_per_host(20) .build() .unwrap(); Self { control_upstream: Upstream::new("control".to_string(), control_url), worker_upstreams: Arc::new(RwLock::new(worker_upstreams)), current_worker_index: Arc::new(RwLock::new(0)), http_client, control_db, project_cache: Cache::new(100), } } // Fixed: Merge healthy + round-robin (1.1.2) async fn get_next_healthy_worker(&self) -> Option { let upstreams = self.worker_upstreams.read().await; let len = upstreams.len(); if len == 0 { return None; } let mut index = self.current_worker_index.write().await; // Try to find a healthy worker with round-robin for _ in 0..len { let candidate = &upstreams[*index % len]; *index = (*index + 1) % len; if *candidate.healthy.read().await { return Some(candidate.clone()); } } // All unhealthy — return next in rotation anyway let fallback = upstreams[*index % len].clone(); *index = (*index + 1) % len; Some(fallback) } async fn start_health_check_loop(&self) { let mut interval = tokio::time::interval(std::time::Duration::from_secs(5)); info!("Starting proxy health check loop"); loop { interval.tick().await; // Optional Dynamic Worker Discovery via Control Plane Polling // You can replace this whole background loop seamlessly to query the control plane let control_scan_url = format!("{}/workers", self.control_upstream.url); if let Ok(res) = self.http_client.get(&control_scan_url).send().await { if let Ok(workers) = res.json::>().await { let mut current = self.worker_upstreams.write().await; // Retain healthy upstreams or register new ones let updated: Vec = workers.into_iter().map(|url| { if let Some(existing) = current.iter().find(|w| w.url == url) { existing.clone() } else { Upstream::new(format!("worker-{}", url), url) } }).collect(); *current = updated; } } // Check workers let worker_upstreams = self.worker_upstreams.read().await; for worker in worker_upstreams.iter() { let worker = worker.clone(); let http_client = self.http_client.clone(); tokio::spawn(async move { let res = http_client.get(format!("{}/health", worker.url)).send().await; let is_healthy = res.is_ok() && res.unwrap().status().is_success(); let mut healthy = worker.healthy.write().await; if *healthy != is_healthy { if is_healthy { info!("Worker {} is now healthy", worker.url); } else { error!("Worker {} is now unhealthy", worker.url); } } *healthy = is_healthy; }); } // Check control plane let control = self.control_upstream.clone(); let http_client = self.http_client.clone(); tokio::spawn(async move { let res = http_client.get(format!("{}/health", control.url)).send().await; let is_healthy = res.is_ok() && res.unwrap().status().is_success(); let mut healthy = control.healthy.write().await; if *healthy != is_healthy { if is_healthy { info!("Control plane {} is now healthy", control.url); } else { error!("Control plane {} is now unhealthy", control.url); } } *healthy = is_healthy; }); } } } async fn resolve_project_from_headers( state: &ProxyState, headers: &HeaderMap, ) -> Result { let project_ref = if let Some(val) = headers.get("x-project-ref") { val.to_str() .map_err(|_| StatusCode::BAD_REQUEST)? .to_string() } else { "default".to_string() }; if let Some(ctx) = state.project_cache.get(&project_ref).await { return Ok(ctx); } #[derive(sqlx::FromRow)] struct ProjectRecord { db_url: String, jwt_secret: String, anon_key: Option, service_role_key: Option, } let record = if project_ref == "default" { sqlx::query_as::<_, ProjectRecord>( "SELECT db_url, jwt_secret, anon_key, service_role_key FROM projects LIMIT 1", ) .fetch_optional(&state.control_db) .await .map_err(|e| { error!("DB Error: {}", e); StatusCode::INTERNAL_SERVER_ERROR })? } else { sqlx::query_as::<_, ProjectRecord>( "SELECT db_url, jwt_secret, anon_key, service_role_key FROM projects WHERE name = $1", ) .bind(&project_ref) .fetch_optional(&state.control_db) .await .map_err(|e| { error!("DB Error: {}", e); StatusCode::INTERNAL_SERVER_ERROR })? }; if record.is_none() { error!("Project not found: {}", project_ref); return Err(StatusCode::NOT_FOUND); } let project = record.unwrap(); let ctx = ProjectContext { project_ref: project_ref.clone(), db_url: project.db_url, redis_url: None, jwt_secret: project.jwt_secret, anon_key: project.anon_key, service_role_key: project.service_role_key, }; state.project_cache.insert(project_ref.clone(), ctx.clone()).await; Ok(ctx) } async fn proxy_websocket( State(state): State, ws: WebSocket, headers: HeaderMap, ) { let result = async { let request_id = headers.get("x-request-id") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()) .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); let span = tracing::info_span!("proxy_websocket", request_id = %request_id); let _enter = span.enter(); let _project_ctx = resolve_project_from_headers(&state, &headers).await?; let upstream = state.get_next_healthy_worker().await.ok_or_else(|| { error!("No healthy workers available"); StatusCode::SERVICE_UNAVAILABLE })?; let target_url_str = format!("{}/realtime/v1/websocket", upstream.url.replace("http://", "ws://")); debug!("Proxying WebSocket -> {}", target_url_str); use tokio_tungstenite::tungstenite::client::IntoClientRequest; let mut req = target_url_str.clone().into_client_request().map_err(|e| { error!("Failed to create WebSocket request: {}", e); StatusCode::BAD_GATEWAY })?; for (name, value) in headers.iter() { let name_str = name.as_str(); if name_str == "apikey" || name_str == "authorization" || name_str == "x-project-ref" { info!("Forwarding header: {}", name_str); req.headers_mut().insert(name, value.clone()); } } info!("Connecting to worker WebSocket at: {}", target_url_str); let (server_ws, response) = tokio_tungstenite::connect_async(req) .await .map_err(|e| { error!("Failed to connect to WebSocket upstream {}: {}", upstream.name, e); StatusCode::BAD_GATEWAY })?; info!("Worker WebSocket connection established. Response status: {:?}", response.status()); let (ws_sender, ws_receiver) = ws.split(); let (server_sink, server_stream) = server_ws.split(); let tx_to_client = async move { let mut ws_receiver = ws_receiver; let mut server_sink = server_sink; debug!("Starting tx_to_client loop"); while let Some(msg_result) = ws_receiver.next().await { match msg_result { Ok(msg) => { debug!("Received message from client: {:?}", msg); let tungstenite_msg = match msg { Message::Text(text) => TungsteniteMessage::Text(text), Message::Binary(data) => TungsteniteMessage::Binary(data), Message::Close(_) => TungsteniteMessage::Close(None), Message::Ping(data) => TungsteniteMessage::Ping(data), Message::Pong(data) => TungsteniteMessage::Pong(data), }; if server_sink.send(tungstenite_msg).await.is_err() { debug!("Failed to send to upstream, closing tx_to_client"); break; } } Err(e) => { error!("Error receiving from client WebSocket: {}", e); break; } } } debug!("tx_to_client loop ended"); }; let tx_to_upstream = async move { let mut ws_sender = ws_sender; let mut server_stream = server_stream; debug!("Starting tx_to_upstream loop"); while let Some(msg_result) = server_stream.next().await { match msg_result { Ok(msg) => { debug!("Received message from upstream: {:?}", msg); let axum_msg = match msg { TungsteniteMessage::Text(text) => Message::Text(text), TungsteniteMessage::Binary(data) => Message::Binary(data), TungsteniteMessage::Close(_) => Message::Close(None), TungsteniteMessage::Ping(data) => Message::Ping(data), TungsteniteMessage::Pong(data) => Message::Pong(data), _ => continue, }; if ws_sender.send(axum_msg).await.is_err() { debug!("Failed to send to client, closing tx_to_upstream"); break; } } Err(e) => { error!("Error receiving from upstream WebSocket: {}", e); break; } } } debug!("tx_to_upstream loop ended"); }; tokio::select! { _ = tx_to_client => {}, _ = tx_to_upstream => {}, } Ok::<(), StatusCode>(()) }; if let Err(e) = result.await { error!("WebSocket proxy error: {:?}", e); } } async fn proxy_request( State(state): State, req: Request, ) -> Result { let path = req.uri().path(); // Route /platform/* to control plane if path.starts_with("/platform") || path.starts_with("/dashboard") || path == "/login" { return forward_request(&state, req, state.control_upstream.clone()).await; } // Route /auth/v1, /rest/v1, /storage/v1, /realtime/v1, /functions/v1 to workers if path.starts_with("/auth/v1") || path.starts_with("/rest/v1") || path.starts_with("/storage/v1") || path.starts_with("/realtime/v1") || path.starts_with("/functions/v1") { if let Some(upstream) = state.get_next_healthy_worker().await { forward_request(&state, req, upstream).await } else { Err(StatusCode::SERVICE_UNAVAILABLE) } } else { // Default to control plane forward_request(&state, req, state.control_upstream.clone()).await } } // Fixed: Include body forwarding (1.1.1) and response streaming (1.1.3) // Changed to take reference to state to avoid move issues async fn forward_request( state: &ProxyState, req: Request, upstream: Upstream, ) -> Result { let (parts, body) = req.into_parts(); let request_id = parts.headers.get("x-request-id") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()) .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); let span = tracing::info_span!("forward_request", request_id = %request_id, path = %parts.uri.path()); let _enter = span.enter(); let body_bytes = axum::body::to_bytes(body, 1024 * 1024 * 100) // 100MB limit .await .map_err(|_| StatusCode::BAD_REQUEST)?; // Update the request URI let path_and_query = parts .uri .path_and_query() .map(|pq| pq.as_str()) .unwrap_or("/"); let target_url = format!("{}{}", upstream.url, path_and_query); debug!("Proxying {} -> {}", parts.uri.path(), target_url); // Convert axum (http 1.x) method to reqwest (http 0.2) method let method_str = parts.method.as_str(); let reqwest_method = reqwest::Method::from_bytes(method_str.as_bytes()) .map_err(|_| StatusCode::BAD_REQUEST)?; let mut request_builder = state.http_client.request(reqwest_method, &target_url); // Forward headers for (name, value) in parts.headers.iter() { if let Ok(v) = value.to_str() { request_builder = request_builder.header(name.as_str(), v); } } request_builder = request_builder.header("x-request-id", &request_id); // Attach body (1.1.1) let request_builder = request_builder.body(body_bytes); let response = request_builder .send() .await .map_err(|e| { error!("Failed to proxy request to {}: {}", upstream.name, e); StatusCode::BAD_GATEWAY })?; let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); let resp_headers = response.headers().clone(); // Stream the response (1.1.3) - use reqwest's streaming directly let body = Body::from_stream(response.bytes_stream()); let mut response_builder = Response::builder().status(status); for (name, value) in resp_headers.iter() { let n = name.as_str(); if n != "connection" && n != "transfer-encoding" { if let Ok(v) = value.to_str() { response_builder = response_builder.header(n, v); } } } response_builder .header("x-request-id", &request_id) .body(body) .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } async fn health_check() -> &'static str { "OK" } pub async fn run() -> anyhow::Result<()> { info!("Starting MadBase Proxy..."); let control_url = std::env::var("CONTROL_UPSTREAM_URL") .unwrap_or_else(|_| "http://system:8001".to_string()); let control_db_url = std::env::var("CONTROL_DB_URL") .unwrap_or_else(|_| "postgres://admin:admin_password@localhost:5433/madbase_control".to_string()); let worker_urls_str = std::env::var("WORKER_UPSTREAM_URLS") .unwrap_or_else(|_| "http://worker:8002".to_string()); let worker_urls: Vec = worker_urls_str .split(',') .map(|s| s.trim().to_string()) .filter(|s| !s.is_empty()) .collect(); info!("Control upstream: {}", control_url); info!("Control DB: {}", control_db_url); info!("Worker upstreams: {:?}", worker_urls); let control_db = init_pool(&control_db_url).await?; let state = ProxyState::new(control_url, worker_urls, control_db); // Start health check loop in background let state_clone = state.clone(); tokio::spawn(async move { state_clone.start_health_check_loop().await; }); let app = Router::new() .route("/health", get(health_check)) .route("/realtime/v1/websocket", get(|ws: WebSocketUpgrade, State(state): State, req: Request| async move { let headers = req.headers().clone(); ws.on_upgrade(move |socket| proxy_websocket(State(state.clone()), socket, headers)) }) ) .fallback(proxy_request) .with_state(state); let port = std::env::var("PROXY_PORT") .unwrap_or_else(|_| "8000".to_string()) .parse::()?; let addr = SocketAddr::from(([0, 0, 0, 0], port)); info!("Proxy listening on {}", addr); let listener = tokio::net::TcpListener::bind(addr).await?; axum::serve(listener, app.into_make_service_with_connect_info::()).await?; Ok(()) } #[cfg(test)] mod tests { use super::*; use axum::{body::Body, http::Request, routing::get}; use tower::ServiceExt; use std::sync::Mutex; static ENV_LOCK: Mutex<()> = Mutex::new(()); #[tokio::test] async fn test_proxy_round_robin() { let _guard = ENV_LOCK.lock().unwrap(); let control_db = PgPool::connect("postgres://postgres:postgres@localhost:5432/test").await.ok(); let dummy_db_url = "postgres://postgres:postgres@localhost:5432/test"; let control_pool = if let Some(pool) = control_db { pool } else { let pool = sqlx::PgPool::connect(dummy_db_url).await.unwrap(); pool }; let state = ProxyState::new( "http://control:8001".to_string(), vec!["http://worker1:8002".to_string(), "http://worker2:8002".to_string()], control_pool, ); // Mark all as healthy for worker in state.worker_upstreams.read().await.iter() { *worker.healthy.write().await = true; } // Get 4 workers - should distribute 2+2 let w1 = state.get_next_healthy_worker().await.unwrap(); let w2 = state.get_next_healthy_worker().await.unwrap(); let w3 = state.get_next_healthy_worker().await.unwrap(); let w4 = state.get_next_healthy_worker().await.unwrap(); assert_eq!(w1.url, "http://worker1:8002"); assert_eq!(w2.url, "http://worker2:8002"); assert_eq!(w3.url, "http://worker1:8002"); assert_eq!(w4.url, "http://worker2:8002"); } #[tokio::test] async fn test_proxy_single_http_client() { let control_pool = sqlx::PgPool::connect("postgres://postgres:postgres@localhost:5432/test").await.unwrap(); let state = ProxyState::new( "http://control:8001".to_string(), vec!["http://worker1:8002".to_string()], control_pool, ); // Verify http_client is created and usable // This test just ensures the client exists and is properly configured let _timeout = std::time::Duration::from_secs(30); assert!(_timeout.as_secs() > 0); } #[tokio::test] async fn test_proxy_forwards_body() { // Verify that forward_request reads body from the incoming request // This is a structural test — the actual proxy test requires a running upstream // The implementation uses req.into_parts() + axum::body::to_bytes + .body(body_bytes) let body_data = vec![0u8; 1024 * 1024]; // 1MB body let body = Body::from(body_data.clone()); let bytes = axum::body::to_bytes(body, 1024 * 1024 * 100) .await .unwrap(); assert_eq!(bytes.len(), 1024 * 1024, "Body should be 1MB"); } #[tokio::test] async fn test_proxy_streams_response() { // Verify streamed body construction works (used in forward_request) let data = b"hello world".to_vec(); let stream = futures::stream::once(async move { Ok::<_, std::io::Error>(axum::body::Bytes::from(data)) }); let body = Body::from_stream(stream); let response = Response::builder() .status(StatusCode::OK) .body(body) .unwrap(); assert_eq!(response.status(), StatusCode::OK); } #[test] fn test_worker_tracing_init() { // Verify the tracing filter pattern used in all binaries works correctly let filter = tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")); // Should not panic — the filter is valid assert!(format!("{}", filter).contains("info") || std::env::var("RUST_LOG").is_ok()); } }