use gateway::middleware; use gateway::state::AppState; use axum::{ extract::{Request, Query}, middleware::{from_fn, from_fn_with_state, Next}, response::{Response, IntoResponse}, routing::get, Router, }; use tower_http::services::{ServeDir, ServeFile}; use axum::http::StatusCode; use axum_prometheus::PrometheusMetricLayer; use common::{init_pool, Config, JwtConfig}; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::sync::RwLock; use tower_governor::{governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer}; use tower_http::cors::{AllowOrigin, CorsLayer}; use tower_http::trace::TraceLayer; use moka::future::Cache; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; fn shared_http_client() -> &'static reqwest::Client { static CLIENT: std::sync::OnceLock = std::sync::OnceLock::new(); CLIENT.get_or_init(|| { reqwest::Client::builder() .timeout(std::time::Duration::from_secs(30)) .pool_max_idle_per_host(10) .build() .unwrap() }) } async fn logs_proxy_handler(Query(params): Query>) -> impl IntoResponse { let loki_url = std::env::var("LOKI_URL") .unwrap_or_else(|_| "http://loki:3100".to_string()); let query_url = format!("{}/loki/api/v1/query_range", loki_url); let resp = shared_http_client() .get(&query_url) .query(¶ms) .send() .await; match resp { Ok(r) => { let status = StatusCode::from_u16(r.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); let body = r.bytes().await.unwrap_or_default(); (status, body).into_response() }, Err(e) => { tracing::error!("Loki proxy error: {}", e); (StatusCode::BAD_GATEWAY, e.to_string()).into_response() } } } async fn log_headers(req: Request, next: Next) -> Response { tracing::debug!("Request Headers: {:?}", req.headers()); next.run(req).await } // Dashboard handler removed in favor of direct SPA serving from /web async fn wait_for_db(db_url: &str) -> sqlx::PgPool { loop { match init_pool(db_url).await { Ok(pool) => return pool, Err(e) => { tracing::warn!("Database not ready yet, retrying in 2s: {}", e); tokio::time::sleep(Duration::from_secs(2)).await; } } } } async fn validate_configuration(config: &Config, pool: &sqlx::PgPool) -> anyhow::Result<()> { tracing::info!("Validating configuration..."); // 1. Validate JWT secret format if config.jwt_secret.len() < 32 { anyhow::bail!( "JWT_SECRET too short ({} chars, minimum 32 required)", config.jwt_secret.len() ); } // 1.1 Validate JWT issuer let jwt_issuer = std::env::var("JWT_ISSUER").unwrap_or_else(|_| "madbase".to_string()); tracing::info!( jwt_issuer = %jwt_issuer, "JWT issuer configured" ); // 2. Validate JWT secret consistency with database let row = sqlx::query_as::<_, (String, String, String, String)>( r#" SELECT name, jwt_secret, anon_key, service_role_key FROM projects WHERE name = 'default' LIMIT 1 "# ) .fetch_optional(pool) .await?; if let Some((name, jwt_secret, anon_key, service_role_key)) = row { if jwt_secret != config.jwt_secret { anyhow::bail!( "JWT_SECRET mismatch between environment and database (project 'default')\n\ Environment: {}...\n\ Database: {}...\n\ Run 'scripts/setup_default_project.sh' to fix this.", &config.jwt_secret[..8], &jwt_secret[..8] ); } // Validate that anon_key and service_role_key are present if anon_key.is_empty() { anyhow::bail!("Project 'default' has empty anon_key"); } if service_role_key.is_empty() { anyhow::bail!("Project 'default' has empty service_role_key"); } tracing::info!( project_name = name, jwt_secret_preview = &jwt_secret[..8], anon_key_present = !anon_key.is_empty(), service_role_key_present = !service_role_key.is_empty(), "Project configuration validated" ); } else { anyhow::bail!( "Default project not found in database. Run 'scripts/setup_default_project.sh' to create it." ); } tracing::info!("Configuration validation successful."); Ok(()) } #[tokio::main] async fn main() -> anyhow::Result<()> { // Load configuration dotenvy::dotenv().ok(); let config = Config::new().expect("Failed to load configuration"); // Initialize tracing let rust_log = std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into()); let is_json = std::env::var("LOG_FORMAT").ok().as_deref() == Some("json"); use tracing_subscriber::Layer; let filter = tracing_subscriber::EnvFilter::new(&rust_log).boxed(); let fmt_layer = if is_json { tracing_subscriber::fmt::layer().json().boxed() } else { tracing_subscriber::fmt::layer().boxed() }; let otel_layer = if let Ok(otlp_endpoint) = std::env::var("OTEL_EXPORTER_OTLP_ENDPOINT") { use opentelemetry_otlp::WithExportConfig; let tracer = opentelemetry_otlp::new_pipeline() .tracing() .with_exporter(opentelemetry_otlp::new_exporter().tonic().with_endpoint(otlp_endpoint)) .with_trace_config(opentelemetry_sdk::trace::config().with_resource( opentelemetry_sdk::Resource::new(vec![opentelemetry::KeyValue::new("service.name", "madbase-gateway")]) )) .install_batch(opentelemetry_sdk::runtime::Tokio) .expect("Failed to initialize OTLP tracer"); Some(tracing_opentelemetry::layer().with_tracer(tracer).boxed()) } else { None }; let registry = tracing_subscriber::registry() .with(filter) .with(fmt_layer); if let Some(otel) = otel_layer { registry.with(otel).init(); } else { registry.init(); } tracing::info!("Starting MadBase Gateway v4.1 (Admin UI)..."); // Initialize Database (Control Plane / Main DB) tracing::info!("Connecting to database at {}...", config.database_url); let pool = wait_for_db(&config.database_url).await; tracing::info!("Database connected successfully."); // Run Migrations (Tenant only for Gateway) tracing::info!("Running tenant database migrations..."); sqlx::migrate!("../migrations") .run(&pool) .await .expect("Failed to run tenant migrations"); tracing::info!("Tenant migrations applied successfully."); // Validate Configuration validate_configuration(&config, &pool).await?; let app_state = AppState { control_db: pool.clone(), tenant_pools: Arc::new(RwLock::new(HashMap::new())), }; let session_manager = config.redis_url.as_ref().map(|url| { let cache = common::CacheLayer::new(Some(url.clone()), 86400); auth::SessionManager::new(cache, 86400) }); let auth_state = auth::AuthState { db: pool.clone(), config: config.clone(), session_manager, }; let replica_pool = if let Ok(url) = std::env::var("READ_REPLICA_URL") { tracing::info!("Connecting to read replica at {}...", url); Some(wait_for_db(&url).await) } else { None }; let data_state = data_api::handlers::DataState { db: pool.clone(), replica_pool, config: config.clone(), cache: Arc::new(data_api::schema_cache::SchemaCache::new()), }; // Register DDL invalidation listener let ddl_cache = data_state.cache.clone(); let ddl_pool = pool.clone(); tokio::spawn(async move { let mut listener = match sqlx::postgres::PgListener::connect_with(&ddl_pool).await { Ok(l) => l, Err(e) => { tracing::error!("Failed to connect PgListener: {}", e); return; } }; if let Err(e) = listener.listen("madbase_schema_change").await { tracing::error!("Failed to listen on madbase_schema_change: {}", e); return; } tracing::info!("DDL invalidation listener started."); while let Ok(notification) = listener.recv().await { tracing::info!("Received DDL change notification: {}. Invalidating SchemaCache.", notification.payload()); ddl_cache.invalidate_all().await; } }); // Initialize Tenant Database (for Realtime) let default_tenant_db_url = std::env::var("DEFAULT_TENANT_DB_URL") .expect("DEFAULT_TENANT_DB_URL must be set"); tracing::info!("Connecting to default tenant database..."); let tenant_pool = wait_for_db(&default_tenant_db_url).await; tracing::info!("Tenant Database connected successfully."); let control_state = control_plane::ControlPlaneState { db: pool.clone(), tenant_db: tenant_pool.clone(), server_manager: None, }; let mut tenant_config = config.clone(); tenant_config.database_url = default_tenant_db_url.clone(); // Realtime Init let (realtime_router, realtime_state) = realtime::init(tenant_pool.clone(), tenant_config.clone()); // Start Replication Listener (for default tenant) let repl_state = realtime_state.clone(); let default_db_url = default_tenant_db_url.clone(); tokio::spawn(async move { if let Err(e) = realtime::replication::start_replication_listener("default".to_string(), default_db_url, repl_state).await { tracing::error!("Replication listener failed: {}", e); } }); // Storage Init let storage_router = storage::init(pool.clone(), config.clone()).await; // Functions Init let functions_runtime = Arc::new(functions::runtime::WasmRuntime::new().expect("Failed to initialize WASM runtime")); let deno_runtime = Arc::new(functions::deno_runtime::DenoRuntime::new()); let pool_size = std::env::var("DENO_POOL_SIZE").unwrap_or_else(|_| "4".to_string()).parse::().unwrap_or(4); let deno_pool = Arc::new(functions::worker_pool::DenoPool::new(pool_size)); let functions_state = functions::FunctionsState { db: pool.clone(), config: config.clone(), runtime: functions_runtime, deno_runtime, deno_pool, }; // Auth Middleware State let jwt_config = JwtConfig::from_env()?; let auth_middleware_state = auth::AuthMiddlewareState { config: config.clone(), jwt_config, }; // Project Middleware State let project_middleware_state = middleware::ProjectMiddlewareState { control_db: app_state.control_db.clone(), tenant_pools: moka::future::Cache::builder() .max_capacity(100) .time_to_idle(Duration::from_secs(300)) .build(), project_cache: moka::future::Cache::builder() .max_capacity(100) .time_to_live(Duration::from_secs(60)) .build(), realtime: realtime_state, }; // Construct App // We apply `resolve_project` middleware to /auth, /rest, /storage, /realtime // But NOT /platform (admin) let tenant_routes = Router::new() .nest( "/auth/v1", auth::router() .layer(from_fn_with_state( auth_middleware_state.clone(), auth::auth_middleware, )) .with_state(auth_state), ) .nest( "/rest/v1", data_api::router() .layer(from_fn_with_state( auth_middleware_state.clone(), auth::auth_middleware, )) .with_state(data_state), ) .nest("/realtime/v1", realtime_router) .nest( "/storage/v1", storage_router.layer(from_fn_with_state( auth_middleware_state.clone(), auth::auth_middleware, )), ) .nest( "/functions/v1", functions::router(functions_state).layer(from_fn_with_state( auth_middleware_state.clone(), auth::auth_middleware, )), ) .layer(from_fn_with_state( project_middleware_state.clone(), middleware::inject_tenant_pool, )) .layer(from_fn_with_state( project_middleware_state, middleware::resolve_project, )); // Metrics let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair(); // Rate Limiting Configuration let governor_conf = Arc::new( GovernorConfigBuilder::default() .per_second(config.rate_limit_per_second) .burst_size(config.rate_limit_per_second as u32 * 2) .key_extractor(SmartIpKeyExtractor) .finish() .unwrap(), ); let app = Router::new() .route("/", get(|| async { "Hello, MadBase!" })) .route("/metrics", get(|| async move { metric_handle.render() })) .nest("/", tenant_routes) // Apply project resolution to these .nest( "/platform/v1", // Admin/Control Plane API (No project resolution needed) control_plane::router(control_state) .route("/logs", get(logs_proxy_handler)), ) .layer(GovernorLayer { config: governor_conf, }) .layer({ let origins_str = std::env::var("ALLOWED_ORIGINS") .unwrap_or_else(|_| "http://localhost:3000,http://localhost:8000".to_string()); let origins: Vec = origins_str .split(',') .filter_map(|s| s.trim().parse().ok()) .collect(); CorsLayer::new() .allow_origin(AllowOrigin::list(origins)) .allow_methods([ axum::http::Method::GET, axum::http::Method::POST, axum::http::Method::PUT, axum::http::Method::DELETE, axum::http::Method::OPTIONS, ]) .allow_headers([ axum::http::header::CONTENT_TYPE, axum::http::header::AUTHORIZATION, axum::http::HeaderName::from_static("apikey"), ]) .allow_credentials(true) }) .layer(TraceLayer::new_for_http()) .layer(from_fn(log_headers)) .layer(prometheus_layer) .fallback_service(ServeDir::new("web").fallback(ServeFile::new("web/index.html"))); // Run it let addr = SocketAddr::from(([0, 0, 0, 0], config.port)); tracing::info!("Listening on {}", addr); let listener = tokio::net::TcpListener::bind(addr).await?; let shutdown = async { tokio::signal::ctrl_c().await.ok(); tracing::info!("Shutdown signal received, draining connections..."); }; axum::serve(listener, app.into_make_service_with_connect_info::()) .with_graceful_shutdown(shutdown) .await?; tracing::info!("Gateway Server shut down cleanly."); Ok(()) }