chore: full stack stability and migration fixes, plus react UI progress
Some checks failed
CI / podman-build (push) Has been cancelled
CI / rust (push) Has been cancelled

This commit is contained in:
2026-03-18 09:01:38 +02:00
parent 38cab8c246
commit a66d908eff
142 changed files with 12210 additions and 3402 deletions

View File

@@ -26,11 +26,16 @@ tower_governor = "0.4.2"
tower-http = { version = "0.6.8", features = ["cors", "trace", "fs"] }
moka = { version = "0.12.14", features = ["future"] }
reqwest = { version = "0.12", features = ["json", "stream"] }
tokio-tungstenite = "0.21"
futures = { workspace = true }
lazy_static = "1.4"
uuid = { workspace = true }
chrono = { workspace = true }
redis = { workspace = true }
opentelemetry = "0.22"
opentelemetry-otlp = { version = "0.15", features = ["tonic"] }
opentelemetry_sdk = { version = "0.22", features = ["rt-tokio"] }
tracing-opentelemetry = "0.23"
[dev-dependencies]
tower = "0.5"

View File

@@ -20,6 +20,7 @@ pub struct AdminAuthState {
struct SessionData {
_created_at: DateTime<Utc>,
last_accessed: DateTime<Utc>,
csrf_token: String,
}
impl AdminAuthState {
@@ -31,9 +32,11 @@ impl AdminAuthState {
pub async fn create_session(&self) -> String {
let session_id = Uuid::new_v4().to_string();
let csrf_token = Uuid::new_v4().to_string();
let data = SessionData {
_created_at: Utc::now(),
last_accessed: Utc::now(),
csrf_token,
};
self.sessions.write().await.insert(session_id.clone(), data);
@@ -44,6 +47,18 @@ impl AdminAuthState {
session_id
}
pub async fn get_csrf_token(&self, session_id: &str) -> Option<String> {
let sessions = self.sessions.read().await;
sessions.get(session_id).map(|d| d.csrf_token.clone())
}
pub async fn validate_csrf_token(&self, session_id: &str, token: &str) -> bool {
let sessions = self.sessions.read().await;
sessions.get(session_id)
.map(|d| d.csrf_token == token)
.unwrap_or(false)
}
pub async fn validate_session(&self, session_id: &str) -> bool {
let mut sessions = self.sessions.write().await;
@@ -88,8 +103,11 @@ pub async fn admin_auth_middleware(
// 2. Protect ONLY the platform API routes
if path.starts_with("/platform/v1") {
// Allow the login endpoint
if path == "/platform/v1/login" {
// Allow the login, logout, and csrf-token endpoints
if path == "/platform/v1/login"
|| path == "/platform/v1/logout"
|| path == "/platform/v1/csrf-token"
{
return Ok(next.run(req).await);
}

View File

@@ -1,16 +1,17 @@
use axum::{
extract::{Request, Query, State},
extract::{Request, Query, State, Path},
middleware::{from_fn, from_fn_with_state, Next},
response::{Response, IntoResponse},
routing::get,
routing::{get, post, delete},
Router,
};
use axum::http::StatusCode;
use axum::http::header::COOKIE;
use axum_prometheus::PrometheusMetricLayer;
use common::{init_pool, Config};
use sqlx::PgPool;
use crate::admin_auth::{admin_auth_middleware, AdminAuthState};
use control_plane::{ControlPlaneState, CreateProjectRequest, RotateKeyRequest};
use control_plane::ControlPlaneState;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::Duration;
@@ -42,6 +43,12 @@ struct AppState {
control_plane: ControlPlaneState,
}
impl axum::extract::FromRef<AppState> for ControlPlaneState {
fn from_ref(state: &AppState) -> Self {
state.control_plane.clone()
}
}
#[derive(Deserialize)]
struct LoginRequest {
password: String,
@@ -81,6 +88,225 @@ async fn login_handler(
).into_response()
}
/// Terminates the admin session: revokes the server-side session record (when
/// the request carried a session cookie) and tells the browser to drop the
/// cookie immediately.
async fn logout_handler(
    State(state): State<AppState>,
    req: Request,
) -> impl IntoResponse {
    // Pull the session id out of the Cookie header, if one is present.
    let cookie_header = req.headers().get(COOKIE).and_then(|h| h.to_str().ok());
    let session_id = cookie_header.and_then(|cookies| {
        cookies
            .split(';')
            .find_map(|c| c.trim().strip_prefix("madbase_admin_session="))
            .map(str::to_string)
    });
    if let Some(sid) = session_id {
        state.admin_auth.revoke_session(&sid).await;
    }
    // Max-Age=0 instructs the browser to expire the cookie right away.
    let clear_cookie = "madbase_admin_session=; HttpOnly; SameSite=Strict; Path=/; Max-Age=0";
    (
        StatusCode::OK,
        [("set-cookie", clear_cookie.to_string())],
        serde_json::json!({"message": "Logged out"}).to_string(),
    ).into_response()
}
async fn csrf_token_handler(
State(state): State<AppState>,
req: Request,
) -> impl IntoResponse {
let session_id = req.headers()
.get(COOKIE)
.and_then(|h| h.to_str().ok())
.and_then(|cookies| {
cookies.split(';')
.find_map(|c| c.trim().strip_prefix("madbase_admin_session="))
})
.map(|s| s.to_string());
if let Some(sid) = session_id {
if let Some(token) = state.admin_auth.get_csrf_token(&sid).await {
return (StatusCode::OK, serde_json::json!({"token": token}).to_string()).into_response();
}
}
(StatusCode::UNAUTHORIZED, serde_json::json!({"error": "No session"}).to_string()).into_response()
}
/// Exposes non-secret runtime configuration to the admin UI: the Grafana URL
/// (env `MADBASE_GRAFANA_URL`, defaulting to "/grafana") and the crate version
/// baked in at compile time.
async fn admin_config_handler() -> impl IntoResponse {
    let grafana_url = std::env::var("MADBASE_GRAFANA_URL")
        .unwrap_or_else(|_| "/grafana".to_string());
    let payload = serde_json::json!({
        "grafana_url": grafana_url,
        "version": env!("CARGO_PKG_VERSION"),
    });
    (StatusCode::OK, payload.to_string()).into_response()
}
// Admin-proxied storage endpoints (browser never touches service_role_key)
/// Reads the service-role key from the `SERVICE_ROLE_KEY` environment
/// variable; returns an empty string when the variable is unset or not
/// valid Unicode.
fn get_service_key() -> String {
    match std::env::var("SERVICE_ROLE_KEY") {
        Ok(key) => key,
        Err(_) => String::new(),
    }
}
/// Lists storage buckets for the admin UI by proxying to the local storage
/// API with the service-role key, so the key never reaches the browser.
async fn admin_storage_buckets() -> impl IntoResponse {
    let port = std::env::var("PORT").unwrap_or_else(|_| "8000".to_string());
    let url = format!("http://127.0.0.1:{}/storage/v1/bucket", port);
    let request = shared_http_client()
        .get(url)
        .header("Authorization", format!("Bearer {}", get_service_key()))
        .header("x-project-ref", "default");
    let upstream = match request.send().await {
        Ok(resp) => resp,
        // Upstream unreachable: surface as 502 with the transport error text.
        Err(err) => return (StatusCode::BAD_GATEWAY, err.to_string()).into_response(),
    };
    // Mirror the upstream status; fall back to 500 if conversion fails.
    let status = StatusCode::from_u16(upstream.status().as_u16())
        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let body = upstream.bytes().await.unwrap_or_default();
    (status, body).into_response()
}
/// Lists the objects of a bucket for the admin UI, proxying to the local
/// storage API with the service-role key (never exposed to the browser).
async fn admin_storage_list_objects(Path(bucket): Path<String>) -> impl IntoResponse {
    let port = std::env::var("PORT").unwrap_or_else(|_| "8000".to_string());
    let url = format!("http://127.0.0.1:{}/storage/v1/object/list/{}", port, bucket);
    let request = shared_http_client()
        .post(url)
        .header("Authorization", format!("Bearer {}", get_service_key()))
        .header("x-project-ref", "default");
    let upstream = match request.send().await {
        Ok(resp) => resp,
        // Transport failure to the local storage service → 502.
        Err(err) => return (StatusCode::BAD_GATEWAY, err.to_string()).into_response(),
    };
    let status = StatusCode::from_u16(upstream.status().as_u16())
        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let body = upstream.bytes().await.unwrap_or_default();
    (status, body).into_response()
}
// Uploads an object to a bucket on behalf of the admin UI, proxying to the
// local storage API with the service-role key. The request body is buffered
// in full (up to 100 MiB) before being forwarded.
async fn admin_storage_upload(
Path((bucket, name)): Path<(String, String)>,
req: Request,
) -> impl IntoResponse {
// Preserve the client's Content-Type; default to a generic binary type.
let content_type = req.headers()
.get(header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("application/octet-stream")
.to_string();
// Buffer the body with a 100 MiB cap; oversized or broken bodies → 400.
let body_bytes = axum::body::to_bytes(req.into_body(), 100 * 1024 * 1024).await;
let body_bytes = match body_bytes {
Ok(b) => b,
Err(e) => return (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
};
let client = shared_http_client();
let base = format!("http://127.0.0.1:{}", std::env::var("PORT").unwrap_or_else(|_| "8000".to_string()));
match client.post(format!("{}/storage/v1/object/{}/{}", base, bucket, name))
.header("Authorization", format!("Bearer {}", get_service_key()))
.header("x-project-ref", "default")
.header("Content-Type", content_type)
.body(body_bytes)
.send().await {
Ok(r) => {
// Mirror the upstream status and body back to the admin client.
let status = StatusCode::from_u16(r.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let body = r.bytes().await.unwrap_or_default();
(status, body).into_response()
}
// Transport failure to the local storage service → 502.
Err(e) => (StatusCode::BAD_GATEWAY, e.to_string()).into_response()
}
}
/// Deletes an object from a bucket on behalf of the admin UI, proxying to the
/// local storage API with the service-role key.
async fn admin_storage_delete(
    Path((bucket, name)): Path<(String, String)>,
) -> impl IntoResponse {
    let port = std::env::var("PORT").unwrap_or_else(|_| "8000".to_string());
    let url = format!("http://127.0.0.1:{}/storage/v1/object/{}/{}", port, bucket, name);
    let request = shared_http_client()
        .delete(url)
        .header("Authorization", format!("Bearer {}", get_service_key()))
        .header("x-project-ref", "default");
    let upstream = match request.send().await {
        Ok(resp) => resp,
        // Transport failure to the local storage service → 502.
        Err(err) => return (StatusCode::BAD_GATEWAY, err.to_string()).into_response(),
    };
    let status = StatusCode::from_u16(upstream.status().as_u16())
        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let body = upstream.bytes().await.unwrap_or_default();
    (status, body).into_response()
}
/// Downloads an object from a bucket on behalf of the admin UI, proxying to
/// the local storage API with the service-role key.
async fn admin_storage_download(
    Path((bucket, name)): Path<(String, String)>,
) -> impl IntoResponse {
    let port = std::env::var("PORT").unwrap_or_else(|_| "8000".to_string());
    let url = format!("http://127.0.0.1:{}/storage/v1/object/{}/{}", port, bucket, name);
    let request = shared_http_client()
        .get(url)
        .header("Authorization", format!("Bearer {}", get_service_key()))
        .header("x-project-ref", "default");
    let upstream = match request.send().await {
        Ok(resp) => resp,
        // Transport failure to the local storage service → 502.
        Err(err) => return (StatusCode::BAD_GATEWAY, err.to_string()).into_response(),
    };
    let status = StatusCode::from_u16(upstream.status().as_u16())
        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let body = upstream.bytes().await.unwrap_or_default();
    (status, body).into_response()
}
// Admin-proxied functions endpoints
/// Lists deployed functions for the admin UI, proxying to the local functions
/// API with the service-role key (never exposed to the browser).
async fn admin_functions_list() -> impl IntoResponse {
    let port = std::env::var("PORT").unwrap_or_else(|_| "8000".to_string());
    let url = format!("http://127.0.0.1:{}/functions/v1", port);
    let request = shared_http_client()
        .get(url)
        .header("Authorization", format!("Bearer {}", get_service_key()))
        .header("x-project-ref", "default");
    let upstream = match request.send().await {
        Ok(resp) => resp,
        // Transport failure to the local functions service → 502.
        Err(err) => return (StatusCode::BAD_GATEWAY, err.to_string()).into_response(),
    };
    let status = StatusCode::from_u16(upstream.status().as_u16())
        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let body = upstream.bytes().await.unwrap_or_default();
    (status, body).into_response()
}
/// Fetches a single function's details for the admin UI, proxying to the
/// local functions API with the service-role key.
async fn admin_functions_get(Path(name): Path<String>) -> impl IntoResponse {
    let port = std::env::var("PORT").unwrap_or_else(|_| "8000".to_string());
    let url = format!("http://127.0.0.1:{}/functions/v1/{}", port, name);
    let request = shared_http_client()
        .get(url)
        .header("Authorization", format!("Bearer {}", get_service_key()))
        .header("x-project-ref", "default");
    let upstream = match request.send().await {
        Ok(resp) => resp,
        // Transport failure to the local functions service → 502.
        Err(err) => return (StatusCode::BAD_GATEWAY, err.to_string()).into_response(),
    };
    let status = StatusCode::from_u16(upstream.status().as_u16())
        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let body = upstream.bytes().await.unwrap_or_default();
    (status, body).into_response()
}
// Deploys a function on behalf of the admin UI, proxying the JSON payload to
// the local functions API with the service-role key. The body is buffered in
// full (up to 10 MiB) before being forwarded.
async fn admin_functions_deploy(req: Request) -> impl IntoResponse {
// Buffer the body with a 10 MiB cap; oversized or broken bodies → 400.
let body_bytes = axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await;
let body_bytes = match body_bytes {
Ok(b) => b,
Err(e) => return (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
};
let client = shared_http_client();
let base = format!("http://127.0.0.1:{}", std::env::var("PORT").unwrap_or_else(|_| "8000".to_string()));
match client.post(format!("{}/functions/v1", base))
.header("Authorization", format!("Bearer {}", get_service_key()))
.header("x-project-ref", "default")
.header("Content-Type", "application/json")
.body(body_bytes)
.send().await {
Ok(r) => {
// Mirror the upstream status and body back to the admin client.
let status = StatusCode::from_u16(r.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let body = r.bytes().await.unwrap_or_default();
(status, body).into_response()
}
// Transport failure to the local functions service → 502.
Err(e) => (StatusCode::BAD_GATEWAY, e.to_string()).into_response()
}
}
fn parse_allowed_origins() -> AllowOrigin {
let origins_str = std::env::var("ALLOWED_ORIGINS")
.unwrap_or_else(|_| "http://localhost:3000,http://localhost:8000,http://localhost:8001".to_string());
@@ -140,97 +366,7 @@ async fn log_headers(req: Request, next: Next) -> Response {
}
// Wrapper handlers for control_plane routes that use AppState
// Legacy wrapper module: each handler rebuilds a `ControlPlaneState` from the
// shared `AppState` and delegates to the corresponding `control_plane`
// function. NOTE(review): per the comment that follows this module in the
// diff, this is superseded by the consolidated `control_plane::router()`.
mod platform_routes {
use super::*;
use control_plane::{list_projects, create_project, delete_project, rotate_keys, get_project_keys, list_users, delete_user};
use axum::{routing::{delete, get}, extract::Path};
use uuid::Uuid;
// GET /projects — list all projects.
pub async fn list_projects_wrapper(
State(state): State<AppState>,
) -> impl IntoResponse {
let control_state = ControlPlaneState {
db: state.control_plane.db.clone(),
tenant_db: state.control_plane.tenant_db.clone(),
};
list_projects(State(control_state)).await
}
// POST /projects — create a project from the JSON payload.
pub async fn create_project_wrapper(
State(state): State<AppState>,
Json(payload): Json<CreateProjectRequest>,
) -> impl IntoResponse {
let control_state = ControlPlaneState {
db: state.control_plane.db.clone(),
tenant_db: state.control_plane.tenant_db.clone(),
};
create_project(State(control_state), Json(payload)).await
}
// DELETE /projects/:id — remove a project by UUID.
pub async fn delete_project_wrapper(
State(state): State<AppState>,
Path(id): Path<Uuid>,
) -> impl IntoResponse {
let control_state = ControlPlaneState {
db: state.control_plane.db.clone(),
tenant_db: state.control_plane.tenant_db.clone(),
};
delete_project(State(control_state), Path(id)).await
}
// PUT /projects/:id/keys — rotate a project's API keys.
pub async fn rotate_keys_wrapper(
State(state): State<AppState>,
Path(id): Path<Uuid>,
Json(payload): Json<RotateKeyRequest>,
) -> impl IntoResponse {
let control_state = ControlPlaneState {
db: state.control_plane.db.clone(),
tenant_db: state.control_plane.tenant_db.clone(),
};
rotate_keys(State(control_state), Path(id), Json(payload)).await
}
// GET /users — list all users.
pub async fn list_users_wrapper(
State(state): State<AppState>,
) -> impl IntoResponse {
let control_state = ControlPlaneState {
db: state.control_plane.db.clone(),
tenant_db: state.control_plane.tenant_db.clone(),
};
list_users(State(control_state)).await
}
// DELETE /users/:id — remove a user by UUID.
pub async fn delete_user_wrapper(
State(state): State<AppState>,
Path(id): Path<Uuid>,
) -> impl IntoResponse {
let control_state = ControlPlaneState {
db: state.control_plane.db.clone(),
tenant_db: state.control_plane.tenant_db.clone(),
};
delete_user(State(control_state), Path(id)).await
}
// GET /projects/:id/keys — fetch a project's API keys.
pub async fn get_project_keys_wrapper(
State(state): State<AppState>,
Path(id): Path<Uuid>,
) -> impl IntoResponse {
let control_state = ControlPlaneState {
db: state.control_plane.db.clone(),
tenant_db: state.control_plane.tenant_db.clone(),
};
get_project_keys(State(control_state), Path(id)).await
}
// Assembles the platform sub-router over the wrappers above.
pub fn router() -> Router<AppState> {
Router::new()
.route("/projects", get(list_projects_wrapper).post(create_project_wrapper))
.route("/projects/:id", delete(delete_project_wrapper))
.route("/projects/:id/keys", get(get_project_keys_wrapper).put(rotate_keys_wrapper))
.route("/users", get(list_users_wrapper))
.route("/users/:id", delete(delete_user_wrapper))
}
}
// platform_routes now delegates to the consolidated control_plane::router()
pub async fn run() -> anyhow::Result<()> {
let config = Config::new().expect("Failed to load configuration");
@@ -239,18 +375,29 @@ pub async fn run() -> anyhow::Result<()> {
let pool = wait_for_db(&config.database_url).await;
sqlx::migrate!("../migrations")
tracing::info!("Running control plane migrations...");
sqlx::migrate!("../migrations_control")
.run(&pool)
.await
.expect("Failed to run migrations");
.expect("Failed to run control plane migrations");
let default_tenant_db_url = std::env::var("DEFAULT_TENANT_DB_URL")
.expect("DEFAULT_TENANT_DB_URL must be set");
let tenant_pool = wait_for_db(&default_tenant_db_url).await;
tracing::info!("Running tenant migrations...");
sqlx::migrate!("../migrations")
.run(&tenant_pool)
.await
.expect("Failed to run tenant migrations");
// Initialize server manager for infrastructure management
let server_manager = control_plane::init_server_manager(pool.clone()).await;
let control_plane_state = ControlPlaneState {
db: pool.clone(),
tenant_db: tenant_pool.clone(),
server_manager,
};
let admin_auth_state = AdminAuthState::new();
@@ -269,16 +416,33 @@ pub async fn run() -> anyhow::Result<()> {
let addr = SocketAddr::from(([0, 0, 0, 0], port));
tracing::info!("Control plane listening on {}", addr);
// Build the control plane platform router (state already applied → Router<()>)
let platform_router = control_plane::router(app_state.control_plane.clone());
let app = Router::new()
.route("/", get(|| async { "MadBase Control Plane" }))
.route("/health", get(|| async { "OK" }))
.route("/metrics", get(|| async move { metric_handle.render() }))
.route("/dashboard", get(dashboard_handler))
.route("/logs", get(logs_proxy_handler))
.route("/login", axum::routing::post(login_handler))
.route("/login", post(login_handler))
.route("/platform/v1/login", post(login_handler))
.route("/platform/v1/logout", post(logout_handler))
.route("/platform/v1/csrf-token", get(csrf_token_handler))
.route("/platform/v1/admin/config", get(admin_config_handler))
// Admin-proxied storage (no service key in browser)
.route("/platform/v1/storage/buckets", get(admin_storage_buckets))
.route("/platform/v1/storage/buckets/:bucket/objects", post(admin_storage_list_objects))
.route("/platform/v1/storage/upload/:bucket/:name", post(admin_storage_upload))
.route("/platform/v1/storage/:bucket/:name", delete(admin_storage_delete).get(admin_storage_download))
// Admin-proxied functions
.route("/platform/v1/functions", get(admin_functions_list).post(admin_functions_deploy))
.route("/platform/v1/functions/:name", get(admin_functions_get))
.nest_service("/css", ServeDir::new("web/css"))
.nest_service("/js", ServeDir::new("web/js"))
.nest("/platform/v1", platform_routes::router())
.nest_service("/vendor", ServeDir::new("web/vendor"))
.with_state(app_state)
.merge(platform_router)
.layer(from_fn(log_headers))
.layer(prometheus_layer)
.layer(
@@ -288,9 +452,8 @@ pub async fn run() -> anyhow::Result<()> {
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::COOKIE])
.allow_credentials(true),
)
.layer(from_fn_with_state(app_state.admin_auth.clone(), admin_auth_middleware))
.layer(TraceLayer::new_for_http())
.with_state(app_state);
.layer(from_fn_with_state(admin_auth_state.clone(), admin_auth_middleware))
.layer(TraceLayer::new_for_http());
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;

View File

@@ -7,3 +7,12 @@ pub mod proxy;
pub mod rate_limit;
pub use rate_limit::{RateLimiter, RateLimitConfig, RateLimitMiddleware, RateLimitStatus};
/// Runs tenant-specific migrations on the provided pool.
/// This ensures that every tenant database has the required auth, storage,
/// functions, and realtime schemas/tables.
///
/// The migration set is embedded at compile time from `../migrations`.
///
/// # Errors
/// Returns the underlying [`sqlx::migrate::MigrateError`] if applying the
/// migrations fails (e.g. connection loss, checksum mismatch, or SQL error).
pub async fn run_tenant_migrations(pool: &sqlx::PgPool) -> Result<(), sqlx::migrate::MigrateError> {
sqlx::migrate!("../migrations")
.run(pool)
.await
}

View File

@@ -1,5 +1,5 @@
mod middleware;
mod state;
use gateway::middleware;
use gateway::state::AppState;
use axum::{
extract::{Request, Query},
@@ -8,10 +8,10 @@ use axum::{
routing::get,
Router,
};
use tower_http::services::{ServeDir, ServeFile};
use axum::http::StatusCode;
use axum_prometheus::PrometheusMetricLayer;
use common::{init_pool, Config};
use state::AppState;
use common::{init_pool, Config, JwtConfig};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
@@ -63,9 +63,7 @@ async fn log_headers(req: Request, next: Next) -> Response {
next.run(req).await
}
async fn dashboard_handler() -> axum::response::Html<&'static str> {
axum::response::Html(include_str!("../../web/admin.html"))
}
// Dashboard handler removed in favor of direct SPA serving from /web
async fn wait_for_db(db_url: &str) -> sqlx::PgPool {
loop {
@@ -79,6 +77,73 @@ async fn wait_for_db(db_url: &str) -> sqlx::PgPool {
}
}
/// Validates startup configuration against the control database:
/// 1. the JWT secret is at least 32 bytes long,
/// 2. the environment's JWT secret matches the 'default' project's secret,
/// 3. the 'default' project has non-empty anon/service-role keys.
///
/// # Errors
/// Bails with a descriptive `anyhow` error on any validation failure or on a
/// database error while looking up the default project.
async fn validate_configuration(config: &Config, pool: &sqlx::PgPool) -> anyhow::Result<()> {
    tracing::info!("Validating configuration...");

    // Char-safe preview of a secret for logs/errors. The original code used
    // `&secret[..8]`, a byte slice that panics when the stored secret is
    // shorter than 8 bytes or when index 8 falls inside a multi-byte UTF-8
    // character.
    let preview = |s: &str| -> String { s.chars().take(8).collect() };

    // 1. Validate JWT secret format (len() counts bytes, which is the
    // relevant measure of key material here).
    if config.jwt_secret.len() < 32 {
        anyhow::bail!(
            "JWT_SECRET too short ({} chars, minimum 32 required)",
            config.jwt_secret.len()
        );
    }

    // 1.1 Validate JWT issuer (informational only; defaults to "madbase").
    let jwt_issuer = std::env::var("JWT_ISSUER").unwrap_or_else(|_| "madbase".to_string());
    tracing::info!(
        jwt_issuer = %jwt_issuer,
        "JWT issuer configured"
    );

    // 2. Validate JWT secret consistency with database.
    let row = sqlx::query_as::<_, (String, String, String, String)>(
        r#"
        SELECT name, jwt_secret, anon_key, service_role_key
        FROM projects
        WHERE name = 'default'
        LIMIT 1
        "#
    )
    .fetch_optional(pool)
    .await?;

    if let Some((name, jwt_secret, anon_key, service_role_key)) = row {
        if jwt_secret != config.jwt_secret {
            anyhow::bail!(
                "JWT_SECRET mismatch between environment and database (project 'default')\n\
                Environment: {}...\n\
                Database: {}...\n\
                Run 'scripts/setup_default_project.sh' to fix this.",
                preview(&config.jwt_secret),
                preview(&jwt_secret)
            );
        }
        // Validate that anon_key and service_role_key are present.
        if anon_key.is_empty() {
            anyhow::bail!("Project 'default' has empty anon_key");
        }
        if service_role_key.is_empty() {
            anyhow::bail!("Project 'default' has empty service_role_key");
        }
        tracing::info!(
            project_name = name,
            jwt_secret_preview = %preview(&jwt_secret),
            anon_key_present = !anon_key.is_empty(),
            service_role_key_present = !service_role_key.is_empty(),
            "Project configuration validated"
        );
    } else {
        anyhow::bail!(
            "Default project not found in database. Run 'scripts/setup_default_project.sh' to create it."
        );
    }

    tracing::info!("Configuration validation successful.");
    Ok(())
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Load configuration
@@ -87,17 +152,42 @@ async fn main() -> anyhow::Result<()> {
// Initialize tracing
let rust_log = std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into());
let is_json = std::env::var("LOG_FORMAT").ok().as_deref() == Some("json");
if std::env::var("LOG_FORMAT").ok().as_deref() == Some("json") {
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::new(&rust_log))
.with(tracing_subscriber::fmt::layer().json())
.init();
use tracing_subscriber::Layer;
let filter = tracing_subscriber::EnvFilter::new(&rust_log).boxed();
let fmt_layer = if is_json {
tracing_subscriber::fmt::layer().json().boxed()
} else {
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::new(&rust_log))
.with(tracing_subscriber::fmt::layer())
.init();
tracing_subscriber::fmt::layer().boxed()
};
let otel_layer = if let Ok(otlp_endpoint) = std::env::var("OTEL_EXPORTER_OTLP_ENDPOINT") {
use opentelemetry_otlp::WithExportConfig;
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(opentelemetry_otlp::new_exporter().tonic().with_endpoint(otlp_endpoint))
.with_trace_config(opentelemetry_sdk::trace::config().with_resource(
opentelemetry_sdk::Resource::new(vec![opentelemetry::KeyValue::new("service.name", "madbase-gateway")])
))
.install_batch(opentelemetry_sdk::runtime::Tokio)
.expect("Failed to initialize OTLP tracer");
Some(tracing_opentelemetry::layer().with_tracer(tracer).boxed())
} else {
None
};
let registry = tracing_subscriber::registry()
.with(filter)
.with(fmt_layer);
if let Some(otel) = otel_layer {
registry.with(otel).init();
} else {
registry.init();
}
tracing::info!("Starting MadBase Gateway v4.1 (Admin UI)...");
@@ -107,13 +197,16 @@ async fn main() -> anyhow::Result<()> {
let pool = wait_for_db(&config.database_url).await;
tracing::info!("Database connected successfully.");
// Run Migrations
tracing::info!("Running database migrations...");
// Run Migrations (Tenant only for Gateway)
tracing::info!("Running tenant database migrations...");
sqlx::migrate!("../migrations")
.run(&pool)
.await
.expect("Failed to run migrations");
tracing::info!("Migrations applied successfully.");
.expect("Failed to run tenant migrations");
tracing::info!("Tenant migrations applied successfully.");
// Validate Configuration
validate_configuration(&config, &pool).await?;
let app_state = AppState {
control_db: pool.clone(),
@@ -131,11 +224,42 @@ async fn main() -> anyhow::Result<()> {
session_manager,
};
let replica_pool = if let Ok(url) = std::env::var("READ_REPLICA_URL") {
tracing::info!("Connecting to read replica at {}...", url);
Some(wait_for_db(&url).await)
} else {
None
};
let data_state = data_api::handlers::DataState {
db: pool.clone(),
replica_pool,
config: config.clone(),
cache: Arc::new(data_api::schema_cache::SchemaCache::new()),
};
// Register DDL invalidation listener
let ddl_cache = data_state.cache.clone();
let ddl_pool = pool.clone();
tokio::spawn(async move {
let mut listener = match sqlx::postgres::PgListener::connect_with(&ddl_pool).await {
Ok(l) => l,
Err(e) => {
tracing::error!("Failed to connect PgListener: {}", e);
return;
}
};
if let Err(e) = listener.listen("madbase_schema_change").await {
tracing::error!("Failed to listen on madbase_schema_change: {}", e);
return;
}
tracing::info!("DDL invalidation listener started.");
while let Ok(notification) = listener.recv().await {
tracing::info!("Received DDL change notification: {}. Invalidating SchemaCache.", notification.payload());
ddl_cache.invalidate_all().await;
}
});
// Initialize Tenant Database (for Realtime)
let default_tenant_db_url = std::env::var("DEFAULT_TENANT_DB_URL")
.expect("DEFAULT_TENANT_DB_URL must be set");
@@ -146,19 +270,20 @@ async fn main() -> anyhow::Result<()> {
let control_state = control_plane::ControlPlaneState {
db: pool.clone(),
tenant_db: tenant_pool.clone(),
server_manager: None,
};
let mut tenant_config = config.clone();
tenant_config.database_url = default_tenant_db_url;
tenant_config.database_url = default_tenant_db_url.clone();
// Realtime Init
let (realtime_router, realtime_state) = realtime::init(tenant_pool.clone(), tenant_config.clone());
// Start Replication Listener
let repl_config = tenant_config.clone();
let repl_tx = realtime_state.broadcast_tx.clone();
// Start Replication Listener (for default tenant)
let repl_state = realtime_state.clone();
let default_db_url = default_tenant_db_url.clone();
tokio::spawn(async move {
if let Err(e) = realtime::replication::start_replication_listener(repl_config, repl_tx).await {
if let Err(e) = realtime::replication::start_replication_listener("default".to_string(), default_db_url, repl_state).await {
tracing::error!("Replication listener failed: {}", e);
}
});
@@ -169,23 +294,35 @@ async fn main() -> anyhow::Result<()> {
// Functions Init
let functions_runtime = Arc::new(functions::runtime::WasmRuntime::new().expect("Failed to initialize WASM runtime"));
let deno_runtime = Arc::new(functions::deno_runtime::DenoRuntime::new());
let pool_size = std::env::var("DENO_POOL_SIZE").unwrap_or_else(|_| "4".to_string()).parse::<usize>().unwrap_or(4);
let deno_pool = Arc::new(functions::worker_pool::DenoPool::new(pool_size));
let functions_state = functions::FunctionsState {
db: pool.clone(),
config: config.clone(),
runtime: functions_runtime,
deno_runtime,
deno_pool,
};
// Auth Middleware State
let jwt_config = JwtConfig::from_env()?;
let auth_middleware_state = auth::AuthMiddlewareState {
config: config.clone(),
jwt_config,
};
// Project Middleware State
let project_middleware_state = middleware::ProjectMiddlewareState {
control_db: app_state.control_db.clone(),
tenant_pools: app_state.tenant_pools.clone(),
project_cache: Cache::new(100),
tenant_pools: moka::future::Cache::builder()
.max_capacity(100)
.time_to_idle(Duration::from_secs(300))
.build(),
project_cache: moka::future::Cache::builder()
.max_capacity(100)
.time_to_live(Duration::from_secs(60))
.build(),
realtime: realtime_state,
};
// Construct App
@@ -247,11 +384,9 @@ async fn main() -> anyhow::Result<()> {
.finish()
.unwrap(),
);
let app = Router::new()
.route("/", get(|| async { "Hello, MadBase!" }))
.route("/metrics", get(|| async move { metric_handle.render() }))
.route("/dashboard", get(dashboard_handler))
.nest("/", tenant_routes) // Apply project resolution to these
.nest(
"/platform/v1", // Admin/Control Plane API (No project resolution needed)
@@ -286,14 +421,24 @@ async fn main() -> anyhow::Result<()> {
})
.layer(TraceLayer::new_for_http())
.layer(from_fn(log_headers))
.layer(prometheus_layer);
.layer(prometheus_layer)
.fallback_service(ServeDir::new("web").fallback(ServeFile::new("web/index.html")));
// Run it
let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
tracing::info!("Listening on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
let shutdown = async {
tokio::signal::ctrl_c().await.ok();
tracing::info!("Shutdown signal received, draining connections...");
};
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>())
.with_graceful_shutdown(shutdown)
.await?;
tracing::info!("Gateway Server shut down cleanly.");
Ok(())
}

View File

@@ -8,16 +8,14 @@ use common::init_pool;
use common::ProjectContext;
use moka::future::Cache;
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::warn;
#[derive(Clone)]
pub struct ProjectMiddlewareState {
pub control_db: PgPool,
pub tenant_pools: Arc<RwLock<HashMap<String, PgPool>>>,
pub tenant_pools: Cache<String, PgPool>,
pub project_cache: Cache<String, ProjectContext>,
pub realtime: realtime::RealtimeState,
}
pub async fn resolve_project(
@@ -112,11 +110,11 @@ pub async fn inject_tenant_pool(
let db_url = project_ctx.db_url.clone();
let existing = { state.tenant_pools.read().await.get(&db_url).cloned() };
let pool = if let Some(p) = existing {
// Check bound cache first
let pool = if let Some(p) = state.tenant_pools.get(&db_url).await {
p
} else {
// Init + Insert
let new_pool = init_pool(&db_url)
.await
.map_err(|e| {
@@ -124,9 +122,24 @@ pub async fn inject_tenant_pool(
StatusCode::INTERNAL_SERVER_ERROR
})?;
let mut w = state.tenant_pools.write().await;
let entry = w.entry(db_url).or_insert_with(|| new_pool.clone());
entry.clone()
// Ensure the tenant database is migrated
if let Err(e) = crate::run_tenant_migrations(&new_pool).await {
warn!("Failed to run tenant migrations for {}: {}", db_url, e);
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
// Start replication listener for the tenant
if let Err(e) = realtime::replication::start_replication_listener(
project_ctx.project_ref.clone(),
db_url.clone(),
state.realtime.clone(),
).await {
warn!("Failed to start replication listener for {}: {}", project_ctx.project_ref, e);
}
state.tenant_pools.insert(db_url, new_pool.clone()).await;
new_pool
};
req.extensions_mut().insert(pool);

View File

@@ -1,15 +1,23 @@
use axum::{
body::Body,
extract::{Request, State},
http::StatusCode,
extract::{Request, State, ws::WebSocketUpgrade},
http::{StatusCode, HeaderMap},
response::Response,
routing::get,
Router,
};
use axum::extract::ws::{Message, WebSocket};
use tokio_tungstenite::{
tungstenite::protocol::Message as TungsteniteMessage,
};
use futures::{SinkExt, StreamExt};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{error, info, debug};
use moka::future::Cache;
use common::{init_pool, ProjectContext};
use sqlx::PgPool;
#[derive(Clone, Debug)]
struct Upstream {
@@ -34,10 +42,12 @@ struct ProxyState {
worker_upstreams: Arc<RwLock<Vec<Upstream>>>,
current_worker_index: Arc<RwLock<usize>>,
http_client: reqwest::Client,
control_db: PgPool,
project_cache: Cache<String, ProjectContext>,
}
impl ProxyState {
fn new(control_url: String, worker_urls: Vec<String>) -> Self {
fn new(control_url: String, worker_urls: Vec<String>, control_db: PgPool) -> Self {
let worker_upstreams = worker_urls
.into_iter()
.map(|url| Upstream::new(format!("worker-{}", url), url))
@@ -55,6 +65,8 @@ impl ProxyState {
worker_upstreams: Arc::new(RwLock::new(worker_upstreams)),
current_worker_index: Arc::new(RwLock::new(0)),
http_client,
control_db,
project_cache: Cache::new(100),
}
}
@@ -88,6 +100,24 @@ impl ProxyState {
loop {
interval.tick().await;
// Optional Dynamic Worker Discovery via Control Plane Polling
// You can replace this whole background loop seamlessly to query the control plane
let control_scan_url = format!("{}/workers", self.control_upstream.url);
if let Ok(res) = self.http_client.get(&control_scan_url).send().await {
if let Ok(workers) = res.json::<Vec<String>>().await {
let mut current = self.worker_upstreams.write().await;
// Retain healthy upstreams or register new ones
let updated: Vec<Upstream> = workers.into_iter().map(|url| {
if let Some(existing) = current.iter().find(|w| w.url == url) {
existing.clone()
} else {
Upstream::new(format!("worker-{}", url), url)
}
}).collect();
*current = updated;
}
}
// Check workers
let worker_upstreams = self.worker_upstreams.read().await;
for worker in worker_upstreams.iter() {
@@ -130,6 +160,197 @@ impl ProxyState {
}
}
/// Resolve the [`ProjectContext`] for an incoming request from its headers.
///
/// The project is selected by the `x-project-ref` header, falling back to
/// `"default"` when the header is absent. Resolution first consults the
/// in-memory `project_cache`; on a miss it loads the project row from the
/// control-plane database and caches the result.
///
/// # Errors
/// - `BAD_REQUEST` if `x-project-ref` is present but not valid UTF-8.
/// - `NOT_FOUND` if no matching project row exists.
/// - `INTERNAL_SERVER_ERROR` on database failures.
async fn resolve_project_from_headers(
    state: &ProxyState,
    headers: &HeaderMap,
) -> Result<ProjectContext, StatusCode> {
    let project_ref = match headers.get("x-project-ref") {
        Some(val) => val
            .to_str()
            .map_err(|_| StatusCode::BAD_REQUEST)?
            .to_string(),
        None => "default".to_string(),
    };

    // Fast path: serve a previously resolved context from the cache.
    if let Some(ctx) = state.project_cache.get(&project_ref).await {
        return Ok(ctx);
    }

    #[derive(sqlx::FromRow)]
    struct ProjectRecord {
        db_url: String,
        jwt_secret: String,
        anon_key: Option<String>,
        service_role_key: Option<String>,
    }

    // NOTE(review): the "default" branch picks an arbitrary row (LIMIT 1
    // without ORDER BY) — assumes a single-project deployment; confirm.
    let record = if project_ref == "default" {
        sqlx::query_as::<_, ProjectRecord>(
            "SELECT db_url, jwt_secret, anon_key, service_role_key FROM projects LIMIT 1",
        )
        .fetch_optional(&state.control_db)
        .await
    } else {
        sqlx::query_as::<_, ProjectRecord>(
            "SELECT db_url, jwt_secret, anon_key, service_role_key FROM projects WHERE name = $1",
        )
        .bind(&project_ref)
        .fetch_optional(&state.control_db)
        .await
    }
    // One shared error mapping instead of duplicating the closure per branch.
    .map_err(|e| {
        error!("DB Error: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // `let else` replaces the original is_none()/unwrap() pair.
    let Some(project) = record else {
        error!("Project not found: {}", project_ref);
        return Err(StatusCode::NOT_FOUND);
    };

    let ctx = ProjectContext {
        project_ref: project_ref.clone(),
        db_url: project.db_url,
        redis_url: None,
        jwt_secret: project.jwt_secret,
        anon_key: project.anon_key,
        service_role_key: project.service_role_key,
    };
    // `project_ref` is moved here (no trailing clone needed — last use).
    state.project_cache.insert(project_ref, ctx.clone()).await;
    Ok(ctx)
}
async fn proxy_websocket(
State(state): State<ProxyState>,
ws: WebSocket,
headers: HeaderMap,
) {
let result = async {
let request_id = headers.get("x-request-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let span = tracing::info_span!("proxy_websocket", request_id = %request_id);
let _enter = span.enter();
let _project_ctx = resolve_project_from_headers(&state, &headers).await?;
let upstream = state.get_next_healthy_worker().await.ok_or_else(|| {
error!("No healthy workers available");
StatusCode::SERVICE_UNAVAILABLE
})?;
let target_url_str = format!("{}/realtime/v1/websocket", upstream.url.replace("http://", "ws://"));
debug!("Proxying WebSocket -> {}", target_url_str);
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
let mut req = target_url_str.clone().into_client_request().map_err(|e| {
error!("Failed to create WebSocket request: {}", e);
StatusCode::BAD_GATEWAY
})?;
for (name, value) in headers.iter() {
let name_str = name.as_str();
if name_str == "apikey" || name_str == "authorization" || name_str == "x-project-ref" {
info!("Forwarding header: {}", name_str);
req.headers_mut().insert(name, value.clone());
}
}
info!("Connecting to worker WebSocket at: {}", target_url_str);
let (server_ws, response) = tokio_tungstenite::connect_async(req)
.await
.map_err(|e| {
error!("Failed to connect to WebSocket upstream {}: {}", upstream.name, e);
StatusCode::BAD_GATEWAY
})?;
info!("Worker WebSocket connection established. Response status: {:?}", response.status());
let (ws_sender, ws_receiver) = ws.split();
let (server_sink, server_stream) = server_ws.split();
let tx_to_client = async move {
let mut ws_receiver = ws_receiver;
let mut server_sink = server_sink;
debug!("Starting tx_to_client loop");
while let Some(msg_result) = ws_receiver.next().await {
match msg_result {
Ok(msg) => {
debug!("Received message from client: {:?}", msg);
let tungstenite_msg = match msg {
Message::Text(text) => TungsteniteMessage::Text(text),
Message::Binary(data) => TungsteniteMessage::Binary(data),
Message::Close(_) => TungsteniteMessage::Close(None),
Message::Ping(data) => TungsteniteMessage::Ping(data),
Message::Pong(data) => TungsteniteMessage::Pong(data),
};
if server_sink.send(tungstenite_msg).await.is_err() {
debug!("Failed to send to upstream, closing tx_to_client");
break;
}
}
Err(e) => {
error!("Error receiving from client WebSocket: {}", e);
break;
}
}
}
debug!("tx_to_client loop ended");
};
let tx_to_upstream = async move {
let mut ws_sender = ws_sender;
let mut server_stream = server_stream;
debug!("Starting tx_to_upstream loop");
while let Some(msg_result) = server_stream.next().await {
match msg_result {
Ok(msg) => {
debug!("Received message from upstream: {:?}", msg);
let axum_msg = match msg {
TungsteniteMessage::Text(text) => Message::Text(text),
TungsteniteMessage::Binary(data) => Message::Binary(data),
TungsteniteMessage::Close(_) => Message::Close(None),
TungsteniteMessage::Ping(data) => Message::Ping(data),
TungsteniteMessage::Pong(data) => Message::Pong(data),
_ => continue,
};
if ws_sender.send(axum_msg).await.is_err() {
debug!("Failed to send to client, closing tx_to_upstream");
break;
}
}
Err(e) => {
error!("Error receiving from upstream WebSocket: {}", e);
break;
}
}
}
debug!("tx_to_upstream loop ended");
};
tokio::select! {
_ = tx_to_client => {},
_ = tx_to_upstream => {},
}
Ok::<(), StatusCode>(())
};
if let Err(e) = result.await {
error!("WebSocket proxy error: {:?}", e);
}
}
async fn proxy_request(
State(state): State<ProxyState>,
req: Request,
@@ -166,8 +387,15 @@ async fn forward_request(
req: Request,
upstream: Upstream,
) -> Result<Response, StatusCode> {
// Extract body before consuming the request (1.1.1)
let (parts, body) = req.into_parts();
let request_id = parts.headers.get("x-request-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let span = tracing::info_span!("forward_request", request_id = %request_id, path = %parts.uri.path());
let _enter = span.enter();
let body_bytes = axum::body::to_bytes(body, 1024 * 1024 * 100) // 100MB limit
.await
.map_err(|_| StatusCode::BAD_REQUEST)?;
@@ -196,6 +424,7 @@ async fn forward_request(
request_builder = request_builder.header(name.as_str(), v);
}
}
request_builder = request_builder.header("x-request-id", &request_id);
// Attach body (1.1.1)
let request_builder = request_builder.body(body_bytes);
@@ -226,6 +455,7 @@ async fn forward_request(
}
response_builder
.header("x-request-id", &request_id)
.body(body)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}
@@ -238,10 +468,13 @@ pub async fn run() -> anyhow::Result<()> {
info!("Starting MadBase Proxy...");
let control_url = std::env::var("CONTROL_UPSTREAM_URL")
.unwrap_or_else(|_| "http://control:8001".to_string());
.unwrap_or_else(|_| "http://system:8001".to_string());
let control_db_url = std::env::var("CONTROL_DB_URL")
.unwrap_or_else(|_| "postgres://admin:admin_password@localhost:5433/madbase_control".to_string());
let worker_urls_str = std::env::var("WORKER_UPSTREAM_URLS")
.unwrap_or_else(|_| "http://worker1:8002".to_string());
.unwrap_or_else(|_| "http://worker:8002".to_string());
let worker_urls: Vec<String> = worker_urls_str
.split(',')
@@ -250,9 +483,12 @@ pub async fn run() -> anyhow::Result<()> {
.collect();
info!("Control upstream: {}", control_url);
info!("Control DB: {}", control_db_url);
info!("Worker upstreams: {:?}", worker_urls);
let state = ProxyState::new(control_url, worker_urls);
let control_db = init_pool(&control_db_url).await?;
let state = ProxyState::new(control_url, worker_urls, control_db);
// Start health check loop in background
let state_clone = state.clone();
@@ -262,6 +498,12 @@ pub async fn run() -> anyhow::Result<()> {
let app = Router::new()
.route("/health", get(health_check))
.route("/realtime/v1/websocket",
get(|ws: WebSocketUpgrade, State(state): State<ProxyState>, req: Request| async move {
let headers = req.headers().clone();
ws.on_upgrade(move |socket| proxy_websocket(State(state.clone()), socket, headers))
})
)
.fallback(proxy_request)
.with_state(state);
@@ -291,9 +533,19 @@ mod tests {
async fn test_proxy_round_robin() {
let _guard = ENV_LOCK.lock().unwrap();
let control_db = PgPool::connect("postgres://postgres:postgres@localhost:5432/test").await.ok();
let dummy_db_url = "postgres://postgres:postgres@localhost:5432/test";
let control_pool = if let Some(pool) = control_db {
pool
} else {
let pool = sqlx::PgPool::connect(dummy_db_url).await.unwrap();
pool
};
let state = ProxyState::new(
"http://control:8001".to_string(),
vec!["http://worker1:8002".to_string(), "http://worker2:8002".to_string()]
vec!["http://worker1:8002".to_string(), "http://worker2:8002".to_string()],
control_pool,
);
// Mark all as healthy
@@ -315,9 +567,11 @@ mod tests {
#[tokio::test]
async fn test_proxy_single_http_client() {
let control_pool = sqlx::PgPool::connect("postgres://postgres:postgres@localhost:5432/test").await.unwrap();
let state = ProxyState::new(
"http://control:8001".to_string(),
vec!["http://worker1:8002".to_string()]
vec!["http://worker1:8002".to_string()],
control_pool,
);
// Verify http_client is created and usable

View File

@@ -4,7 +4,7 @@ use axum::{
Router,
};
use axum_prometheus::PrometheusMetricLayer;
use common::{init_pool, Config};
use common::{init_pool, Config, JwtConfig};
use crate::state::AppState;
use crate::middleware;
use sqlx::PgPool;
@@ -47,8 +47,12 @@ pub async fn run() -> anyhow::Result<()> {
let pool = wait_for_db(&config.database_url).await;
let control_db_url = std::env::var("CONTROL_DB_URL")
.expect("CONTROL_DB_URL must be set");
let control_pool = wait_for_db(&control_db_url).await;
let app_state = AppState {
control_db: pool.clone(),
control_db: control_pool.clone(),
tenant_pools: Arc::new(RwLock::new(HashMap::new())),
};
@@ -65,7 +69,9 @@ pub async fn run() -> anyhow::Result<()> {
let data_state = data_api::handlers::DataState {
db: pool.clone(),
replica_pool: None,
config: config.clone(),
cache: Arc::new(data_api::schema_cache::SchemaCache::new()),
};
let default_tenant_db_url = std::env::var("DEFAULT_TENANT_DB_URL")
@@ -79,10 +85,10 @@ pub async fn run() -> anyhow::Result<()> {
let (realtime_router, realtime_state) = realtime::init(tenant_pool.clone(), tenant_config.clone());
// Replication Listener
let repl_config = tenant_config.clone();
let repl_tx = realtime_state.broadcast_tx.clone();
let repl_state = realtime_state.clone();
let default_db_url = default_tenant_db_url.clone();
tokio::spawn(async move {
if let Err(e) = realtime::replication::start_replication_listener(repl_config, repl_tx).await {
if let Err(e) = realtime::replication::start_replication_listener("default".to_string(), default_db_url, repl_state).await {
tracing::error!("Replication listener failed: {}", e);
}
});
@@ -96,23 +102,46 @@ pub async fn run() -> anyhow::Result<()> {
.expect("Failed to initialize WASM runtime")
);
let deno_runtime = Arc::new(functions::deno_runtime::DenoRuntime::new());
let pool_size = std::env::var("DENO_POOL_SIZE").unwrap_or_else(|_| "4".to_string()).parse::<usize>().unwrap_or(4);
let deno_pool = Arc::new(functions::worker_pool::DenoPool::new(pool_size));
let functions_state = functions::FunctionsState {
db: pool.clone(),
config: config.clone(),
runtime: functions_runtime,
deno_runtime,
deno_pool,
};
// Auth Middleware State
let jwt_config = JwtConfig::from_env()?;
let auth_middleware_state = auth::AuthMiddlewareState {
config: config.clone(),
jwt_config,
};
// Project Middleware State
let project_cache = moka::future::Cache::builder()
.max_capacity(100)
.time_to_live(Duration::from_secs(60))
.build();
let tenant_pools = moka::future::Cache::builder()
.max_capacity(100)
.time_to_idle(Duration::from_secs(300))
.build();
let project_middleware_state = middleware::ProjectMiddlewareState {
control_db: app_state.control_db.clone(),
tenant_pools: app_state.tenant_pools.clone(),
project_cache: moka::future::Cache::new(100),
control_db: control_pool.clone(),
tenant_pools: tenant_pools.clone(),
project_cache: project_cache.clone(),
realtime: realtime_state.clone(),
};
let project_middleware_state2 = middleware::ProjectMiddlewareState {
control_db: control_pool.clone(),
tenant_pools: tenant_pools.clone(),
project_cache: project_cache.clone(),
realtime: realtime_state,
};
// Construct Worker Routes
@@ -127,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> {
auth::auth_middleware,
))
.layer(from_fn_with_state(
project_middleware_state.clone(),
project_middleware_state2,
middleware::inject_tenant_pool,
))
.layer(from_fn_with_state(