Files
madbase/gateway/src/proxy.rs
Vlad Durnea a66d908eff
Some checks failed
CI / podman-build (push) Has been cancelled
CI / rust (push) Has been cancelled
chore: full stack stability and migration fixes, plus react UI progress
2026-03-18 09:01:38 +02:00

620 lines
22 KiB
Rust

use axum::{
body::Body,
extract::{Request, State, ws::WebSocketUpgrade},
http::{StatusCode, HeaderMap},
response::Response,
routing::get,
Router,
};
use axum::extract::ws::{Message, WebSocket};
use tokio_tungstenite::{
tungstenite::protocol::Message as TungsteniteMessage,
};
use futures::{SinkExt, StreamExt};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{error, info, debug};
use moka::future::Cache;
use common::{init_pool, ProjectContext};
use sqlx::PgPool;
/// One proxied backend: a label for logs, the base URL requests are sent to, and a health flag.
///
/// `healthy` lives behind an `Arc`, so every clone of an `Upstream` reads and writes the same
/// flag: the background health loop updates it and worker selection reads it.
#[derive(Debug, Clone)]
struct Upstream {
    /// Label used in log lines (`"control"` or `"worker-<url>"`).
    name: String,
    /// Base URL that request paths are appended to.
    url: String,
    /// Last observed health; starts `true` so a newly registered upstream is usable at once.
    healthy: Arc<RwLock<bool>>,
}

impl Upstream {
    /// Registers an upstream that is presumed healthy until a probe reports otherwise.
    fn new(name: String, url: String) -> Self {
        let healthy = Arc::new(RwLock::new(true));
        Upstream { name, url, healthy }
    }
}
#[derive(Clone)]
struct ProxyState {
control_upstream: Upstream,
worker_upstreams: Arc<RwLock<Vec<Upstream>>>,
current_worker_index: Arc<RwLock<usize>>,
http_client: reqwest::Client,
control_db: PgPool,
project_cache: Cache<String, ProjectContext>,
}
impl ProxyState {
fn new(control_url: String, worker_urls: Vec<String>, control_db: PgPool) -> Self {
let worker_upstreams = worker_urls
.into_iter()
.map(|url| Upstream::new(format!("worker-{}", url), url))
.collect();
// Create pooled HTTP client once at startup (1.1.4)
let http_client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.pool_max_idle_per_host(20)
.build()
.unwrap();
Self {
control_upstream: Upstream::new("control".to_string(), control_url),
worker_upstreams: Arc::new(RwLock::new(worker_upstreams)),
current_worker_index: Arc::new(RwLock::new(0)),
http_client,
control_db,
project_cache: Cache::new(100),
}
}
// Fixed: Merge healthy + round-robin (1.1.2)
async fn get_next_healthy_worker(&self) -> Option<Upstream> {
let upstreams = self.worker_upstreams.read().await;
let len = upstreams.len();
if len == 0 { return None; }
let mut index = self.current_worker_index.write().await;
// Try to find a healthy worker with round-robin
for _ in 0..len {
let candidate = &upstreams[*index % len];
*index = (*index + 1) % len;
if *candidate.healthy.read().await {
return Some(candidate.clone());
}
}
// All unhealthy — return next in rotation anyway
let fallback = upstreams[*index % len].clone();
*index = (*index + 1) % len;
Some(fallback)
}
async fn start_health_check_loop(&self) {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
info!("Starting proxy health check loop");
loop {
interval.tick().await;
// Optional Dynamic Worker Discovery via Control Plane Polling
// You can replace this whole background loop seamlessly to query the control plane
let control_scan_url = format!("{}/workers", self.control_upstream.url);
if let Ok(res) = self.http_client.get(&control_scan_url).send().await {
if let Ok(workers) = res.json::<Vec<String>>().await {
let mut current = self.worker_upstreams.write().await;
// Retain healthy upstreams or register new ones
let updated: Vec<Upstream> = workers.into_iter().map(|url| {
if let Some(existing) = current.iter().find(|w| w.url == url) {
existing.clone()
} else {
Upstream::new(format!("worker-{}", url), url)
}
}).collect();
*current = updated;
}
}
// Check workers
let worker_upstreams = self.worker_upstreams.read().await;
for worker in worker_upstreams.iter() {
let worker = worker.clone();
let http_client = self.http_client.clone();
tokio::spawn(async move {
let res = http_client.get(format!("{}/health", worker.url)).send().await;
let is_healthy = res.is_ok() && res.unwrap().status().is_success();
let mut healthy = worker.healthy.write().await;
if *healthy != is_healthy {
if is_healthy {
info!("Worker {} is now healthy", worker.url);
} else {
error!("Worker {} is now unhealthy", worker.url);
}
}
*healthy = is_healthy;
});
}
// Check control plane
let control = self.control_upstream.clone();
let http_client = self.http_client.clone();
tokio::spawn(async move {
let res = http_client.get(format!("{}/health", control.url)).send().await;
let is_healthy = res.is_ok() && res.unwrap().status().is_success();
let mut healthy = control.healthy.write().await;
if *healthy != is_healthy {
if is_healthy {
info!("Control plane {} is now healthy", control.url);
} else {
error!("Control plane {} is now unhealthy", control.url);
}
}
*healthy = is_healthy;
});
}
}
}
async fn resolve_project_from_headers(
state: &ProxyState,
headers: &HeaderMap,
) -> Result<ProjectContext, StatusCode> {
let project_ref = if let Some(val) = headers.get("x-project-ref") {
val.to_str()
.map_err(|_| StatusCode::BAD_REQUEST)?
.to_string()
} else {
"default".to_string()
};
if let Some(ctx) = state.project_cache.get(&project_ref).await {
return Ok(ctx);
}
#[derive(sqlx::FromRow)]
struct ProjectRecord {
db_url: String,
jwt_secret: String,
anon_key: Option<String>,
service_role_key: Option<String>,
}
let record = if project_ref == "default" {
sqlx::query_as::<_, ProjectRecord>(
"SELECT db_url, jwt_secret, anon_key, service_role_key FROM projects LIMIT 1",
)
.fetch_optional(&state.control_db)
.await
.map_err(|e| {
error!("DB Error: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?
} else {
sqlx::query_as::<_, ProjectRecord>(
"SELECT db_url, jwt_secret, anon_key, service_role_key FROM projects WHERE name = $1",
)
.bind(&project_ref)
.fetch_optional(&state.control_db)
.await
.map_err(|e| {
error!("DB Error: {}", e);
StatusCode::INTERNAL_SERVER_ERROR
})?
};
if record.is_none() {
error!("Project not found: {}", project_ref);
return Err(StatusCode::NOT_FOUND);
}
let project = record.unwrap();
let ctx = ProjectContext {
project_ref: project_ref.clone(),
db_url: project.db_url,
redis_url: None,
jwt_secret: project.jwt_secret,
anon_key: project.anon_key,
service_role_key: project.service_role_key,
};
state.project_cache.insert(project_ref.clone(), ctx.clone()).await;
Ok(ctx)
}
async fn proxy_websocket(
State(state): State<ProxyState>,
ws: WebSocket,
headers: HeaderMap,
) {
let result = async {
let request_id = headers.get("x-request-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let span = tracing::info_span!("proxy_websocket", request_id = %request_id);
let _enter = span.enter();
let _project_ctx = resolve_project_from_headers(&state, &headers).await?;
let upstream = state.get_next_healthy_worker().await.ok_or_else(|| {
error!("No healthy workers available");
StatusCode::SERVICE_UNAVAILABLE
})?;
let target_url_str = format!("{}/realtime/v1/websocket", upstream.url.replace("http://", "ws://"));
debug!("Proxying WebSocket -> {}", target_url_str);
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
let mut req = target_url_str.clone().into_client_request().map_err(|e| {
error!("Failed to create WebSocket request: {}", e);
StatusCode::BAD_GATEWAY
})?;
for (name, value) in headers.iter() {
let name_str = name.as_str();
if name_str == "apikey" || name_str == "authorization" || name_str == "x-project-ref" {
info!("Forwarding header: {}", name_str);
req.headers_mut().insert(name, value.clone());
}
}
info!("Connecting to worker WebSocket at: {}", target_url_str);
let (server_ws, response) = tokio_tungstenite::connect_async(req)
.await
.map_err(|e| {
error!("Failed to connect to WebSocket upstream {}: {}", upstream.name, e);
StatusCode::BAD_GATEWAY
})?;
info!("Worker WebSocket connection established. Response status: {:?}", response.status());
let (ws_sender, ws_receiver) = ws.split();
let (server_sink, server_stream) = server_ws.split();
let tx_to_client = async move {
let mut ws_receiver = ws_receiver;
let mut server_sink = server_sink;
debug!("Starting tx_to_client loop");
while let Some(msg_result) = ws_receiver.next().await {
match msg_result {
Ok(msg) => {
debug!("Received message from client: {:?}", msg);
let tungstenite_msg = match msg {
Message::Text(text) => TungsteniteMessage::Text(text),
Message::Binary(data) => TungsteniteMessage::Binary(data),
Message::Close(_) => TungsteniteMessage::Close(None),
Message::Ping(data) => TungsteniteMessage::Ping(data),
Message::Pong(data) => TungsteniteMessage::Pong(data),
};
if server_sink.send(tungstenite_msg).await.is_err() {
debug!("Failed to send to upstream, closing tx_to_client");
break;
}
}
Err(e) => {
error!("Error receiving from client WebSocket: {}", e);
break;
}
}
}
debug!("tx_to_client loop ended");
};
let tx_to_upstream = async move {
let mut ws_sender = ws_sender;
let mut server_stream = server_stream;
debug!("Starting tx_to_upstream loop");
while let Some(msg_result) = server_stream.next().await {
match msg_result {
Ok(msg) => {
debug!("Received message from upstream: {:?}", msg);
let axum_msg = match msg {
TungsteniteMessage::Text(text) => Message::Text(text),
TungsteniteMessage::Binary(data) => Message::Binary(data),
TungsteniteMessage::Close(_) => Message::Close(None),
TungsteniteMessage::Ping(data) => Message::Ping(data),
TungsteniteMessage::Pong(data) => Message::Pong(data),
_ => continue,
};
if ws_sender.send(axum_msg).await.is_err() {
debug!("Failed to send to client, closing tx_to_upstream");
break;
}
}
Err(e) => {
error!("Error receiving from upstream WebSocket: {}", e);
break;
}
}
}
debug!("tx_to_upstream loop ended");
};
tokio::select! {
_ = tx_to_client => {},
_ = tx_to_upstream => {},
}
Ok::<(), StatusCode>(())
};
if let Err(e) = result.await {
error!("WebSocket proxy error: {:?}", e);
}
}
/// Fallback handler: chooses an upstream from the request path and forwards the request to it.
///
/// Paths under a data-plane prefix (`/auth/v1`, `/rest/v1`, `/storage/v1`, `/realtime/v1`,
/// `/functions/v1`) go to the next worker from `get_next_healthy_worker`. The handler answers
/// 503 when that returns `None`. Control-plane paths (`/platform*`, `/dashboard*`, `/login`) and
/// every other path go to the control plane.
async fn proxy_request(
    State(state): State<ProxyState>,
    req: Request,
) -> Result<Response, StatusCode> {
    const WORKER_PREFIXES: [&str; 5] = [
        "/auth/v1",
        "/rest/v1",
        "/storage/v1",
        "/realtime/v1",
        "/functions/v1",
    ];

    let path = req.uri().path();
    let is_control_route =
        path.starts_with("/platform") || path.starts_with("/dashboard") || path == "/login";
    let is_worker_route = !is_control_route
        && WORKER_PREFIXES
            .iter()
            .any(|&prefix| path.starts_with(prefix));

    let upstream = if is_worker_route {
        match state.get_next_healthy_worker().await {
            Some(worker) => worker,
            None => return Err(StatusCode::SERVICE_UNAVAILABLE),
        }
    } else {
        state.control_upstream.clone()
    };
    forward_request(&state, req, upstream).await
}
// Fixed: Include body forwarding (1.1.1) and response streaming (1.1.3)
// Changed to take reference to state to avoid move issues
async fn forward_request(
state: &ProxyState,
req: Request,
upstream: Upstream,
) -> Result<Response, StatusCode> {
let (parts, body) = req.into_parts();
let request_id = parts.headers.get("x-request-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let span = tracing::info_span!("forward_request", request_id = %request_id, path = %parts.uri.path());
let _enter = span.enter();
let body_bytes = axum::body::to_bytes(body, 1024 * 1024 * 100) // 100MB limit
.await
.map_err(|_| StatusCode::BAD_REQUEST)?;
// Update the request URI
let path_and_query = parts
.uri
.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/");
let target_url = format!("{}{}", upstream.url, path_and_query);
debug!("Proxying {} -> {}", parts.uri.path(), target_url);
// Convert axum (http 1.x) method to reqwest (http 0.2) method
let method_str = parts.method.as_str();
let reqwest_method = reqwest::Method::from_bytes(method_str.as_bytes())
.map_err(|_| StatusCode::BAD_REQUEST)?;
let mut request_builder = state.http_client.request(reqwest_method, &target_url);
// Forward headers
for (name, value) in parts.headers.iter() {
if let Ok(v) = value.to_str() {
request_builder = request_builder.header(name.as_str(), v);
}
}
request_builder = request_builder.header("x-request-id", &request_id);
// Attach body (1.1.1)
let request_builder = request_builder.body(body_bytes);
let response = request_builder
.send()
.await
.map_err(|e| {
error!("Failed to proxy request to {}: {}", upstream.name, e);
StatusCode::BAD_GATEWAY
})?;
let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let resp_headers = response.headers().clone();
// Stream the response (1.1.3) - use reqwest's streaming directly
let body = Body::from_stream(response.bytes_stream());
let mut response_builder = Response::builder().status(status);
for (name, value) in resp_headers.iter() {
let n = name.as_str();
if n != "connection" && n != "transfer-encoding" {
if let Ok(v) = value.to_str() {
response_builder = response_builder.header(n, v);
}
}
}
response_builder
.header("x-request-id", &request_id)
.body(body)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
/// Liveness endpoint for the gateway process itself; replies with the plain-text body `OK`.
async fn health_check() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}
pub async fn run() -> anyhow::Result<()> {
info!("Starting MadBase Proxy...");
let control_url = std::env::var("CONTROL_UPSTREAM_URL")
.unwrap_or_else(|_| "http://system:8001".to_string());
let control_db_url = std::env::var("CONTROL_DB_URL")
.unwrap_or_else(|_| "postgres://admin:admin_password@localhost:5433/madbase_control".to_string());
let worker_urls_str = std::env::var("WORKER_UPSTREAM_URLS")
.unwrap_or_else(|_| "http://worker:8002".to_string());
let worker_urls: Vec<String> = worker_urls_str
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
info!("Control upstream: {}", control_url);
info!("Control DB: {}", control_db_url);
info!("Worker upstreams: {:?}", worker_urls);
let control_db = init_pool(&control_db_url).await?;
let state = ProxyState::new(control_url, worker_urls, control_db);
// Start health check loop in background
let state_clone = state.clone();
tokio::spawn(async move {
state_clone.start_health_check_loop().await;
});
let app = Router::new()
.route("/health", get(health_check))
.route("/realtime/v1/websocket",
get(|ws: WebSocketUpgrade, State(state): State<ProxyState>, req: Request| async move {
let headers = req.headers().clone();
ws.on_upgrade(move |socket| proxy_websocket(State(state.clone()), socket, headers))
})
)
.fallback(proxy_request)
.with_state(state);
let port = std::env::var("PROXY_PORT")
.unwrap_or_else(|_| "8000".to_string())
.parse::<u16>()?;
let addr = SocketAddr::from(([0, 0, 0, 0], port));
info!("Proxy listening on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// URL for the lazily-connected control pool. No test here issues a query, so nothing has
    /// to be listening at this address.
    const TEST_DB_URL: &str = "postgres://postgres:postgres@localhost:5432/test";

    /// Builds a `ProxyState` with the given workers and a control pool that never connects.
    fn test_state(worker_urls: &[&str]) -> ProxyState {
        let control_pool = PgPool::connect_lazy(TEST_DB_URL)
            .expect("TEST_DB_URL should be a valid Postgres URL");
        ProxyState::new(
            "http://control:8001".to_string(),
            worker_urls.iter().map(|url| url.to_string()).collect(),
            control_pool,
        )
    }

    /// Sets worker `i`'s health flag to `flags[i]`.
    async fn set_health(state: &ProxyState, flags: &[bool]) {
        let workers = state.worker_upstreams.read().await;
        for (worker, &flag) in workers.iter().zip(flags) {
            *worker.healthy.write().await = flag;
        }
    }

    /// URL of the next worker chosen by round-robin selection.
    async fn next_url(state: &ProxyState) -> String {
        state
            .get_next_healthy_worker()
            .await
            .expect("at least one worker is configured")
            .url
    }

    #[tokio::test]
    async fn test_proxy_round_robin() {
        let state = test_state(&["http://worker1:8002", "http://worker2:8002"]);
        set_health(&state, &[true, true]).await;
        // Four picks over two healthy workers should alternate.
        assert_eq!(next_url(&state).await, "http://worker1:8002");
        assert_eq!(next_url(&state).await, "http://worker2:8002");
        assert_eq!(next_url(&state).await, "http://worker1:8002");
        assert_eq!(next_url(&state).await, "http://worker2:8002");
    }

    #[tokio::test]
    async fn test_proxy_skips_unhealthy_worker() {
        let state = test_state(&["http://worker1:8002", "http://worker2:8002"]);
        set_health(&state, &[true, false]).await;
        assert_eq!(next_url(&state).await, "http://worker1:8002");
        assert_eq!(next_url(&state).await, "http://worker1:8002");
    }

    #[tokio::test]
    async fn test_proxy_falls_back_when_all_unhealthy() {
        let state = test_state(&["http://worker1:8002", "http://worker2:8002"]);
        set_health(&state, &[false, false]).await;
        // With no healthy worker, selection still rotates instead of returning None.
        assert_eq!(next_url(&state).await, "http://worker1:8002");
        assert_eq!(next_url(&state).await, "http://worker2:8002");
    }

    #[tokio::test]
    async fn test_proxy_no_workers() {
        let state = test_state(&[]);
        assert!(state.get_next_healthy_worker().await.is_none());
    }

    #[tokio::test]
    async fn test_proxy_single_http_client() {
        // `reqwest::Client` exposes no identity to compare, so this checks the observable effect
        // of sharing state across clones: one worker list and one round-robin cursor.
        let state = test_state(&["http://worker1:8002", "http://worker2:8002"]);
        let clone = state.clone();
        assert!(Arc::ptr_eq(&state.worker_upstreams, &clone.worker_upstreams));
        assert!(Arc::ptr_eq(
            &state.current_worker_index,
            &clone.current_worker_index
        ));
        assert_eq!(next_url(&state).await, "http://worker1:8002");
        assert_eq!(next_url(&clone).await, "http://worker2:8002");
    }

    #[tokio::test]
    async fn test_proxy_forwards_body() {
        // Structural check of the buffering step used by forward_request
        // (axum::body::to_bytes with the 100 MiB limit); a full round trip needs an upstream.
        let body_data = vec![0u8; 1024 * 1024]; // 1 MiB body
        let body = Body::from(body_data);
        let bytes = axum::body::to_bytes(body, 1024 * 1024 * 100)
            .await
            .unwrap();
        assert_eq!(bytes.len(), 1024 * 1024, "Body should be 1MB");
    }

    #[tokio::test]
    async fn test_proxy_streams_response() {
        // A streamed body, built as in forward_request, must deliver its bytes intact.
        let data = b"hello world".to_vec();
        let stream = futures::stream::once(async move {
            Ok::<_, std::io::Error>(axum::body::Bytes::from(data))
        });
        let body = Body::from_stream(stream);
        let response = Response::builder()
            .status(StatusCode::OK)
            .body(body)
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(&bytes[..], b"hello world");
    }

    #[test]
    fn test_worker_tracing_init() {
        // Verify the tracing filter pattern used in all binaries works correctly.
        let filter = tracing_subscriber::EnvFilter::try_from_default_env()
            .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"));
        // Should not panic: the filter is valid.
        assert!(format!("{}", filter).contains("info") || std::env::var("RUST_LOG").is_ok());
    }
}