Some checks failed
CI/CD Pipeline / lint (push) Successful in 3m45s
CI/CD Pipeline / integration-tests (push) Failing after 55s
CI/CD Pipeline / unit-tests (push) Failing after 1m1s
CI/CD Pipeline / e2e-tests (push) Has been skipped
CI/CD Pipeline / build (push) Has been skipped
315 lines
10 KiB
Rust
315 lines
10 KiB
Rust
use axum::{
|
|
extract::{Request, Query},
|
|
middleware::{from_fn, from_fn_with_state, Next},
|
|
response::{Response, IntoResponse},
|
|
routing::get,
|
|
Router,
|
|
};
|
|
use axum::http::StatusCode;
|
|
use axum_prometheus::PrometheusMetricLayer;
|
|
use common::{init_pool, Config};
|
|
use sqlx::PgPool;
|
|
use crate::admin_auth::{admin_auth_middleware, AdminAuthState};
|
|
use std::collections::HashMap;
|
|
use std::net::SocketAddr;
|
|
use std::time::Duration;
|
|
use tower_http::services::ServeDir;
|
|
use tower_http::cors::{AllowOrigin, CorsLayer};
|
|
use axum::http::{HeaderValue, Method};
|
|
use axum::http::header;
|
|
use tower_http::trace::TraceLayer;
|
|
|
|
use axum::Json;
|
|
use serde::Deserialize;
|
|
|
|
#[derive(Deserialize)]
|
|
struct LoginRequest {
|
|
password: String,
|
|
}
|
|
|
|
async fn login_handler(
|
|
axum::extract::State(admin_state): axum::extract::State<AdminAuthState>,
|
|
Json(payload): Json<LoginRequest>,
|
|
) -> impl IntoResponse {
|
|
let expected = std::env::var("ADMIN_PASSWORD")
|
|
.expect("ADMIN_PASSWORD must be set");
|
|
|
|
if payload.password != expected {
|
|
return (
|
|
StatusCode::UNAUTHORIZED,
|
|
[("set-cookie", String::new())],
|
|
serde_json::json!({"error": "Invalid password"}).to_string(),
|
|
).into_response();
|
|
}
|
|
|
|
let session_id = admin_state.create_session().await;
|
|
let cookie = format!(
|
|
"madbase_admin_session={}; HttpOnly; SameSite=Strict; Path=/; Max-Age=86400",
|
|
session_id
|
|
);
|
|
|
|
(
|
|
StatusCode::OK,
|
|
[("set-cookie", cookie)],
|
|
serde_json::json!({"message": "Login successful"}).to_string(),
|
|
).into_response()
|
|
}
|
|
|
|
fn parse_allowed_origins() -> AllowOrigin {
|
|
let origins_str = std::env::var("ALLOWED_ORIGINS")
|
|
.unwrap_or_else(|_| "http://localhost:3000,http://localhost:8000,http://localhost:8001".to_string());
|
|
let origins: Vec<HeaderValue> = origins_str
|
|
.split(',')
|
|
.filter_map(|s| s.trim().parse().ok())
|
|
.collect();
|
|
AllowOrigin::list(origins)
|
|
}
|
|
|
|
async fn logs_proxy_handler(
|
|
Query(params): Query<HashMap<String, String>>,
|
|
) -> impl IntoResponse {
|
|
let client = reqwest::Client::new();
|
|
let loki_url = std::env::var("LOKI_URL")
|
|
.unwrap_or_else(|_| "http://loki:3100".to_string());
|
|
let query_url = format!("{}/loki/api/v1/query_range", loki_url);
|
|
|
|
let resp = client
|
|
.get(&query_url)
|
|
.query(¶ms)
|
|
.send()
|
|
.await;
|
|
|
|
match resp {
|
|
Ok(r) => {
|
|
let status = StatusCode::from_u16(r.status().as_u16())
|
|
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
|
let body = r.bytes().await.unwrap_or_default();
|
|
(status, body).into_response()
|
|
},
|
|
Err(e) => {
|
|
tracing::error!("Loki proxy error: {}", e);
|
|
(StatusCode::BAD_GATEWAY, e.to_string()).into_response()
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn dashboard_handler() -> axum::response::Html<&'static str> {
|
|
axum::response::Html(include_str!("../../web/admin.html"))
|
|
}
|
|
|
|
async fn wait_for_db(db_url: &str) -> PgPool {
|
|
loop {
|
|
match init_pool(db_url).await {
|
|
Ok(pool) => return pool,
|
|
Err(e) => {
|
|
tracing::warn!("Database not ready yet, retrying in 2s: {}", e);
|
|
tokio::time::sleep(Duration::from_secs(2)).await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn log_headers(req: Request, next: Next) -> Response {
|
|
tracing::debug!("Request Headers: {:?}", req.headers());
|
|
next.run(req).await
|
|
}
|
|
|
|
pub async fn run() -> anyhow::Result<()> {
|
|
let config = Config::new().expect("Failed to load configuration");
|
|
|
|
tracing::info!("Starting MadBase Control Plane...");
|
|
|
|
let pool = wait_for_db(&config.database_url).await;
|
|
|
|
sqlx::migrate!("../migrations")
|
|
.run(&pool)
|
|
.await
|
|
.expect("Failed to run migrations");
|
|
|
|
let default_tenant_db_url = std::env::var("DEFAULT_TENANT_DB_URL")
|
|
.expect("DEFAULT_TENANT_DB_URL must be set");
|
|
let tenant_pool = wait_for_db(&default_tenant_db_url).await;
|
|
|
|
let control_state = control_plane::ControlPlaneState {
|
|
db: pool.clone(),
|
|
tenant_db: tenant_pool.clone(),
|
|
};
|
|
|
|
let admin_auth_state = AdminAuthState::new();
|
|
|
|
let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair();
|
|
|
|
let platform_router = control_plane::router(control_state)
|
|
.route("/logs", get(logs_proxy_handler))
|
|
.route("/login", axum::routing::post(login_handler).with_state(admin_auth_state.clone()));
|
|
|
|
let app = Router::new()
|
|
.route("/", get(|| async { "MadBase Control Plane" }))
|
|
.route("/health", get(|| async { "OK" }))
|
|
.route("/metrics", get(|| async move { metric_handle.render() }))
|
|
.route("/dashboard", get(dashboard_handler))
|
|
.nest_service("/css", ServeDir::new("web/css"))
|
|
.nest_service("/js", ServeDir::new("web/js"))
|
|
.nest("/platform/v1", platform_router)
|
|
.layer(from_fn_with_state(admin_auth_state, admin_auth_middleware))
|
|
.layer(
|
|
CorsLayer::new()
|
|
.allow_origin(parse_allowed_origins())
|
|
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS])
|
|
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::COOKIE])
|
|
.allow_credentials(true),
|
|
)
|
|
.layer(TraceLayer::new_for_http())
|
|
.layer(from_fn(log_headers))
|
|
.layer(prometheus_layer);
|
|
|
|
let port = std::env::var("CONTROL_PORT")
|
|
.unwrap_or_else(|_| "8001".to_string())
|
|
.parse::<u16>()?;
|
|
|
|
let addr = SocketAddr::from(([0, 0, 0, 0], port));
|
|
tracing::info!("Control plane listening on {}", addr);
|
|
|
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
|
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request, routing::get};
    use std::sync::{Mutex, MutexGuard};
    use tower::ServiceExt;

    /// Serialises the tests in this module that mutate process-wide environment variables.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` and one environment variable for the duration of a test.
    ///
    /// The variable is removed on drop, so it is cleaned up even when an assertion fails
    /// partway through the test and unwinds.
    struct ScopedEnv {
        key: &'static str,
        _lock: MutexGuard<'static, ()>,
    }

    impl ScopedEnv {
        fn set(key: &'static str, value: &str) -> Self {
            // A failing test poisons the mutex. The guarded data is `()`, so recovering
            // the guard is sound and stops one failure from cascading into other tests.
            let lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            // SAFETY: every env-mutating test in this module holds ENV_LOCK while touching
            // the environment, so these tests never race with each other on it.
            unsafe { std::env::set_var(key, value) };
            Self { key, _lock: lock }
        }
    }

    impl Drop for ScopedEnv {
        fn drop(&mut self) {
            // SAFETY: fields are dropped only after this body runs, so `_lock` (and thus
            // ENV_LOCK) is still held while the variable is removed.
            unsafe { std::env::remove_var(self.key) };
        }
    }

    /// Router with a single `GET /test` route behind the production CORS origin policy.
    fn cors_test_app() -> Router {
        Router::new()
            .route("/test", get(|| async { "ok" }))
            .layer(
                CorsLayer::new()
                    .allow_origin(parse_allowed_origins())
                    .allow_methods([Method::GET])
                    .allow_credentials(true),
            )
    }

    /// Send a CORS preflight for `GET /test` from `origin`.
    ///
    /// Returns the `Access-Control-Allow-Origin` header value, if any. A non-UTF-8 value
    /// is reported as `""`.
    async fn preflight_allow_origin(origin: &str) -> Option<String> {
        let response = cors_test_app()
            .oneshot(
                Request::builder()
                    .method("OPTIONS")
                    .uri("/test")
                    .header("Origin", origin)
                    .header("Access-Control-Request-Method", "GET")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        response
            .headers()
            .get("access-control-allow-origin")
            .map(|v| v.to_str().unwrap_or("").to_owned())
    }

    /// POST `json_body` to a router that exposes only `login_handler` at `/login`.
    async fn post_login(json_body: &'static str) -> Response {
        Router::new()
            .route(
                "/login",
                axum::routing::post(login_handler).with_state(AdminAuthState::new()),
            )
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/login")
                    .header("Content-Type", "application/json")
                    .body(Body::from(json_body))
                    .unwrap(),
            )
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_cors_blocks_unknown_origin() {
        let _env = ScopedEnv::set("ALLOWED_ORIGINS", "http://localhost:3000");

        let acao = preflight_allow_origin("http://evil.com").await;

        assert!(
            matches!(acao.as_deref(), None | Some("")),
            "CORS should not allow http://evil.com"
        );
    }

    #[tokio::test]
    async fn test_cors_allows_configured_origin() {
        let _env = ScopedEnv::set("ALLOWED_ORIGINS", "http://localhost:3000,http://mydomain.com");

        let acao = preflight_allow_origin("http://mydomain.com").await;

        assert_eq!(acao.as_deref(), Some("http://mydomain.com"));
    }

    #[tokio::test]
    async fn test_login_rejects_wrong_password() {
        let _env = ScopedEnv::set("ADMIN_PASSWORD", "correct-horse-battery-staple");

        let response = post_login(r#"{"password":"wrong"}"#).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_login_accepts_correct_password() {
        let _env = ScopedEnv::set("ADMIN_PASSWORD", "correct-horse-battery-staple");

        let response = post_login(r#"{"password":"correct-horse-battery-staple"}"#).await;

        assert_eq!(response.status(), StatusCode::OK);
        let cookie = response.headers().get("set-cookie").unwrap().to_str().unwrap();
        assert!(cookie.contains("madbase_admin_session="));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Strict"));
    }
}
|