use axum::{ extract::{Request, Query}, middleware::{from_fn, from_fn_with_state, Next}, response::{Response, IntoResponse}, routing::get, Router, }; use axum::http::StatusCode; use axum_prometheus::PrometheusMetricLayer; use common::{init_pool, Config}; use sqlx::PgPool; use crate::admin_auth::{admin_auth_middleware, AdminAuthState}; use std::collections::HashMap; use std::net::SocketAddr; use std::time::Duration; use tower_http::services::ServeDir; use tower_http::cors::{AllowOrigin, CorsLayer}; use axum::http::{HeaderValue, Method}; use axum::http::header; use tower_http::trace::TraceLayer; use axum::Json; use serde::Deserialize; #[derive(Deserialize)] struct LoginRequest { password: String, } async fn login_handler( axum::extract::State(admin_state): axum::extract::State, Json(payload): Json, ) -> impl IntoResponse { let expected = std::env::var("ADMIN_PASSWORD") .expect("ADMIN_PASSWORD must be set"); if payload.password != expected { return ( StatusCode::UNAUTHORIZED, [("set-cookie", String::new())], serde_json::json!({"error": "Invalid password"}).to_string(), ).into_response(); } let session_id = admin_state.create_session().await; let cookie = format!( "madbase_admin_session={}; HttpOnly; SameSite=Strict; Path=/; Max-Age=86400", session_id ); ( StatusCode::OK, [("set-cookie", cookie)], serde_json::json!({"message": "Login successful"}).to_string(), ).into_response() } fn parse_allowed_origins() -> AllowOrigin { let origins_str = std::env::var("ALLOWED_ORIGINS") .unwrap_or_else(|_| "http://localhost:3000,http://localhost:8000,http://localhost:8001".to_string()); let origins: Vec = origins_str .split(',') .filter_map(|s| s.trim().parse().ok()) .collect(); AllowOrigin::list(origins) } async fn logs_proxy_handler( Query(params): Query>, ) -> impl IntoResponse { let client = reqwest::Client::new(); let loki_url = std::env::var("LOKI_URL") .unwrap_or_else(|_| "http://loki:3100".to_string()); let query_url = format!("{}/loki/api/v1/query_range", loki_url); let resp = client .get(&query_url) .query(¶ms) .send() .await; match resp { Ok(r) => { let status = StatusCode::from_u16(r.status().as_u16()) .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); let body = r.bytes().await.unwrap_or_default(); (status, body).into_response() }, Err(e) => { tracing::error!("Loki proxy error: {}", e); (StatusCode::BAD_GATEWAY, e.to_string()).into_response() } } } async fn dashboard_handler() -> axum::response::Html<&'static str> { axum::response::Html(include_str!("../../web/admin.html")) } async fn wait_for_db(db_url: &str) -> PgPool { loop { match init_pool(db_url).await { Ok(pool) => return pool, Err(e) => { tracing::warn!("Database not ready yet, retrying in 2s: {}", e); tokio::time::sleep(Duration::from_secs(2)).await; } } } } async fn log_headers(req: Request, next: Next) -> Response { tracing::debug!("Request Headers: {:?}", req.headers()); next.run(req).await } pub async fn run() -> anyhow::Result<()> { let config = Config::new().expect("Failed to load configuration"); tracing::info!("Starting MadBase Control Plane..."); let pool = wait_for_db(&config.database_url).await; sqlx::migrate!("../migrations") .run(&pool) .await .expect("Failed to run migrations"); let default_tenant_db_url = std::env::var("DEFAULT_TENANT_DB_URL") .expect("DEFAULT_TENANT_DB_URL must be set"); let tenant_pool = wait_for_db(&default_tenant_db_url).await; let control_state = control_plane::ControlPlaneState { db: pool.clone(), tenant_db: tenant_pool.clone(), }; let admin_auth_state = AdminAuthState::new(); let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair(); let platform_router = control_plane::router(control_state) .route("/logs", get(logs_proxy_handler)) .route("/login", axum::routing::post(login_handler).with_state(admin_auth_state.clone())); let app = Router::new() .route("/", get(|| async { "MadBase Control Plane" })) .route("/health", get(|| async { "OK" })) .route("/metrics", get(|| async move { metric_handle.render() })) .route("/dashboard", get(dashboard_handler)) .nest_service("/css", ServeDir::new("web/css")) .nest_service("/js", ServeDir::new("web/js")) .nest("/platform/v1", platform_router) .layer(from_fn_with_state(admin_auth_state, admin_auth_middleware)) .layer( CorsLayer::new() .allow_origin(parse_allowed_origins()) .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS]) .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::COOKIE]) .allow_credentials(true), ) .layer(TraceLayer::new_for_http()) .layer(from_fn(log_headers)) .layer(prometheus_layer); let port = std::env::var("CONTROL_PORT") .unwrap_or_else(|_| "8001".to_string()) .parse::()?; let addr = SocketAddr::from(([0, 0, 0, 0], port)); tracing::info!("Control plane listening on {}", addr); let listener = tokio::net::TcpListener::bind(addr).await?; axum::serve(listener, app.into_make_service_with_connect_info::()).await?; Ok(()) } #[cfg(test)] mod tests { use super::*; use axum::{body::Body, http::Request, routing::get}; use tower::ServiceExt; use std::sync::Mutex; static ENV_LOCK: Mutex<()> = Mutex::new(()); #[tokio::test] async fn test_cors_blocks_unknown_origin() { let _guard = ENV_LOCK.lock().unwrap(); unsafe { std::env::set_var("ALLOWED_ORIGINS", "http://localhost:3000") }; let app = Router::new() .route("/test", get(|| async { "ok" })) .layer( CorsLayer::new() .allow_origin(parse_allowed_origins()) .allow_methods([Method::GET]) .allow_credentials(true), ); let response = app .oneshot( Request::builder() .method("OPTIONS") .uri("/test") .header("Origin", "http://evil.com") .header("Access-Control-Request-Method", "GET") .body(Body::empty()) .unwrap(), ) .await .unwrap(); let acao = response .headers() .get("access-control-allow-origin") .map(|v| v.to_str().unwrap_or("")); assert!(acao.is_none() || acao == Some(""), "CORS should not allow http://evil.com"); unsafe { std::env::remove_var("ALLOWED_ORIGINS") }; } #[tokio::test] async fn test_cors_allows_configured_origin() { let _guard = ENV_LOCK.lock().unwrap(); unsafe { std::env::set_var("ALLOWED_ORIGINS", "http://localhost:3000,http://mydomain.com") }; let app = Router::new() .route("/test", get(|| async { "ok" })) .layer( CorsLayer::new() .allow_origin(parse_allowed_origins()) .allow_methods([Method::GET]) .allow_credentials(true), ); let response = app .oneshot( Request::builder() .method("OPTIONS") .uri("/test") .header("Origin", "http://mydomain.com") .header("Access-Control-Request-Method", "GET") .body(Body::empty()) .unwrap(), ) .await .unwrap(); let acao = response .headers() .get("access-control-allow-origin") .map(|v| v.to_str().unwrap_or("")); assert_eq!(acao, Some("http://mydomain.com")); unsafe { std::env::remove_var("ALLOWED_ORIGINS") }; } #[tokio::test] async fn test_login_rejects_wrong_password() { let _guard = ENV_LOCK.lock().unwrap(); unsafe { std::env::set_var("ADMIN_PASSWORD", "correct-horse-battery-staple") }; let admin_state = AdminAuthState::new(); let app = Router::new() .route("/login", axum::routing::post(login_handler).with_state(admin_state)); let response = app .oneshot( Request::builder() .method("POST") .uri("/login") .header("Content-Type", "application/json") .body(Body::from(r#"{"password":"wrong"}"#)) .unwrap(), ) .await .unwrap(); assert_eq!(response.status(), StatusCode::UNAUTHORIZED); unsafe { std::env::remove_var("ADMIN_PASSWORD") }; } #[tokio::test] async fn test_login_accepts_correct_password() { let _guard = ENV_LOCK.lock().unwrap(); unsafe { std::env::set_var("ADMIN_PASSWORD", "correct-horse-battery-staple") }; let admin_state = AdminAuthState::new(); let app = Router::new() .route("/login", axum::routing::post(login_handler).with_state(admin_state)); let response = app .oneshot( Request::builder() .method("POST") .uri("/login") .header("Content-Type", "application/json") .body(Body::from(r#"{"password":"correct-horse-battery-staple"}"#)) .unwrap(), ) .await .unwrap(); assert_eq!(response.status(), StatusCode::OK); let cookie = response.headers().get("set-cookie").unwrap().to_str().unwrap(); assert!(cookie.contains("madbase_admin_session=")); assert!(cookie.contains("HttpOnly")); assert!(cookie.contains("SameSite=Strict")); unsafe { std::env::remove_var("ADMIN_PASSWORD") }; } }