use axum::{ middleware::{from_fn_with_state}, routing::get, Router, }; use axum_prometheus::PrometheusMetricLayer; use common::{init_pool, Config}; use crate::state::AppState; use crate::middleware; use sqlx::PgPool; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::sync::RwLock; use tower_http::cors::{AllowOrigin, CorsLayer}; use axum::http::{HeaderValue, Method}; use axum::http::header; use tower_http::trace::TraceLayer; fn parse_allowed_origins() -> AllowOrigin { let origins_str = std::env::var("ALLOWED_ORIGINS") .unwrap_or_else(|_| "http://localhost:3000,http://localhost:8000,http://localhost:8001".to_string()); let origins: Vec = origins_str .split(',') .filter_map(|s| s.trim().parse().ok()) .collect(); AllowOrigin::list(origins) } async fn wait_for_db(db_url: &str) -> PgPool { loop { match init_pool(db_url).await { Ok(pool) => return pool, Err(e) => { tracing::warn!("Database not ready yet, retrying in 2s: {}", e); tokio::time::sleep(Duration::from_secs(2)).await; } } } } pub async fn run() -> anyhow::Result<()> { let config = Config::new().expect("Failed to load configuration"); tracing::info!("Starting MadBase Worker..."); let pool = wait_for_db(&config.database_url).await; let app_state = AppState { control_db: pool.clone(), tenant_pools: Arc::new(RwLock::new(HashMap::new())), }; let auth_state = auth::AuthState { db: pool.clone(), config: config.clone(), }; let data_state = data_api::handlers::DataState { db: pool.clone(), config: config.clone(), }; let default_tenant_db_url = std::env::var("DEFAULT_TENANT_DB_URL") .expect("DEFAULT_TENANT_DB_URL must be set"); let tenant_pool = wait_for_db(&default_tenant_db_url).await; let mut tenant_config = config.clone(); tenant_config.database_url = default_tenant_db_url.clone(); // Realtime Init let (realtime_router, realtime_state) = realtime::init(tenant_pool.clone(), tenant_config.clone()); // Replication Listener let repl_config = tenant_config.clone(); let repl_tx = realtime_state.broadcast_tx.clone(); tokio::spawn(async move { if let Err(e) = realtime::replication::start_replication_listener(repl_config, repl_tx).await { tracing::error!("Replication listener failed: {}", e); } }); // Storage Init let storage_router = storage::init(pool.clone(), config.clone()).await; // Functions Init let functions_runtime = Arc::new( functions::runtime::WasmRuntime::new() .expect("Failed to initialize WASM runtime") ); let deno_runtime = Arc::new(functions::deno_runtime::DenoRuntime::new()); let functions_state = functions::FunctionsState { db: pool.clone(), config: config.clone(), runtime: functions_runtime, deno_runtime, }; // Auth Middleware State let auth_middleware_state = auth::AuthMiddlewareState { config: config.clone(), }; // Project Middleware State let project_middleware_state = middleware::ProjectMiddlewareState { control_db: app_state.control_db.clone(), tenant_pools: app_state.tenant_pools.clone(), project_cache: moka::future::Cache::new(100), }; // Construct Worker Routes let tenant_routes = Router::new() .nest("/auth/v1", auth::router().with_state(auth_state)) .nest("/rest/v1", data_api::router().with_state(data_state)) .nest("/realtime/v1", realtime_router) .nest("/storage/v1", storage_router) .nest("/functions/v1", functions::router(functions_state)) .layer(from_fn_with_state( auth_middleware_state, auth::auth_middleware, )) .layer(from_fn_with_state( project_middleware_state.clone(), middleware::inject_tenant_pool, )) .layer(from_fn_with_state( project_middleware_state, middleware::resolve_project, )); let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair(); let app = Router::new() .route("/health", get(|| async { "OK" })) .route("/metrics", get(|| async move { metric_handle.render() })) .route("/ready", get(|| async { "Ready" })) .nest("/", tenant_routes) .layer( CorsLayer::new() .allow_origin(parse_allowed_origins()) .allow_methods([Method::GET, Method::POST, Method::PUT, Method::PATCH, Method::DELETE, Method::OPTIONS]) .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, axum::http::HeaderName::from_static("apikey")]) .allow_credentials(true), ) .layer(TraceLayer::new_for_http()) .layer(prometheus_layer); let port = std::env::var("WORKER_PORT") .unwrap_or_else(|_| "8002".to_string()) .parse::()?; let addr = SocketAddr::from(([0, 0, 0, 0], port)); tracing::info!("Worker listening on {}", addr); let listener = tokio::net::TcpListener::bind(addr).await?; axum::serve(listener, app.into_make_service_with_connect_info::()).await?; Ok(()) }