chore: full stack stability and migration fixes, plus react UI progress
This commit is contained in:
@@ -26,11 +26,16 @@ tower_governor = "0.4.2"
|
||||
tower-http = { version = "0.6.8", features = ["cors", "trace", "fs"] }
|
||||
moka = { version = "0.12.14", features = ["future"] }
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
tokio-tungstenite = "0.21"
|
||||
futures = { workspace = true }
|
||||
lazy_static = "1.4"
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
redis = { workspace = true }
|
||||
opentelemetry = "0.22"
|
||||
opentelemetry-otlp = { version = "0.15", features = ["tonic"] }
|
||||
opentelemetry_sdk = { version = "0.22", features = ["rt-tokio"] }
|
||||
tracing-opentelemetry = "0.23"
|
||||
|
||||
[dev-dependencies]
|
||||
tower = "0.5"
|
||||
|
||||
@@ -20,6 +20,7 @@ pub struct AdminAuthState {
|
||||
struct SessionData {
|
||||
_created_at: DateTime<Utc>,
|
||||
last_accessed: DateTime<Utc>,
|
||||
csrf_token: String,
|
||||
}
|
||||
|
||||
impl AdminAuthState {
|
||||
@@ -31,9 +32,11 @@ impl AdminAuthState {
|
||||
|
||||
pub async fn create_session(&self) -> String {
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
let csrf_token = Uuid::new_v4().to_string();
|
||||
let data = SessionData {
|
||||
_created_at: Utc::now(),
|
||||
last_accessed: Utc::now(),
|
||||
csrf_token,
|
||||
};
|
||||
|
||||
self.sessions.write().await.insert(session_id.clone(), data);
|
||||
@@ -44,6 +47,18 @@ impl AdminAuthState {
|
||||
session_id
|
||||
}
|
||||
|
||||
pub async fn get_csrf_token(&self, session_id: &str) -> Option<String> {
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions.get(session_id).map(|d| d.csrf_token.clone())
|
||||
}
|
||||
|
||||
pub async fn validate_csrf_token(&self, session_id: &str, token: &str) -> bool {
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions.get(session_id)
|
||||
.map(|d| d.csrf_token == token)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub async fn validate_session(&self, session_id: &str) -> bool {
|
||||
let mut sessions = self.sessions.write().await;
|
||||
|
||||
@@ -88,8 +103,11 @@ pub async fn admin_auth_middleware(
|
||||
|
||||
// 2. Protect ONLY the platform API routes
|
||||
if path.starts_with("/platform/v1") {
|
||||
// Allow the login endpoint
|
||||
if path == "/platform/v1/login" {
|
||||
// Allow the login, logout, and csrf-token endpoints
|
||||
if path == "/platform/v1/login"
|
||||
|| path == "/platform/v1/logout"
|
||||
|| path == "/platform/v1/csrf-token"
|
||||
{
|
||||
return Ok(next.run(req).await);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
use axum::{
|
||||
extract::{Request, Query, State},
|
||||
extract::{Request, Query, State, Path},
|
||||
middleware::{from_fn, from_fn_with_state, Next},
|
||||
response::{Response, IntoResponse},
|
||||
routing::get,
|
||||
routing::{get, post, delete},
|
||||
Router,
|
||||
};
|
||||
use axum::http::StatusCode;
|
||||
use axum::http::header::COOKIE;
|
||||
use axum_prometheus::PrometheusMetricLayer;
|
||||
use common::{init_pool, Config};
|
||||
use sqlx::PgPool;
|
||||
use crate::admin_auth::{admin_auth_middleware, AdminAuthState};
|
||||
use control_plane::{ControlPlaneState, CreateProjectRequest, RotateKeyRequest};
|
||||
use control_plane::ControlPlaneState;
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
@@ -42,6 +43,12 @@ struct AppState {
|
||||
control_plane: ControlPlaneState,
|
||||
}
|
||||
|
||||
impl axum::extract::FromRef<AppState> for ControlPlaneState {
|
||||
fn from_ref(state: &AppState) -> Self {
|
||||
state.control_plane.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LoginRequest {
|
||||
password: String,
|
||||
@@ -81,6 +88,225 @@ async fn login_handler(
|
||||
).into_response()
|
||||
}
|
||||
|
||||
async fn logout_handler(
|
||||
State(state): State<AppState>,
|
||||
req: Request,
|
||||
) -> impl IntoResponse {
|
||||
// Extract session from cookie and revoke it
|
||||
let session_id = req.headers()
|
||||
.get(COOKIE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|cookies| {
|
||||
cookies.split(';')
|
||||
.find_map(|c| c.trim().strip_prefix("madbase_admin_session="))
|
||||
})
|
||||
.map(|s| s.to_string());
|
||||
|
||||
if let Some(sid) = session_id {
|
||||
state.admin_auth.revoke_session(&sid).await;
|
||||
}
|
||||
|
||||
let clear_cookie = "madbase_admin_session=; HttpOnly; SameSite=Strict; Path=/; Max-Age=0";
|
||||
(
|
||||
StatusCode::OK,
|
||||
[("set-cookie", clear_cookie.to_string())],
|
||||
serde_json::json!({"message": "Logged out"}).to_string(),
|
||||
).into_response()
|
||||
}
|
||||
|
||||
async fn csrf_token_handler(
|
||||
State(state): State<AppState>,
|
||||
req: Request,
|
||||
) -> impl IntoResponse {
|
||||
let session_id = req.headers()
|
||||
.get(COOKIE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|cookies| {
|
||||
cookies.split(';')
|
||||
.find_map(|c| c.trim().strip_prefix("madbase_admin_session="))
|
||||
})
|
||||
.map(|s| s.to_string());
|
||||
|
||||
if let Some(sid) = session_id {
|
||||
if let Some(token) = state.admin_auth.get_csrf_token(&sid).await {
|
||||
return (StatusCode::OK, serde_json::json!({"token": token}).to_string()).into_response();
|
||||
}
|
||||
}
|
||||
|
||||
(StatusCode::UNAUTHORIZED, serde_json::json!({"error": "No session"}).to_string()).into_response()
|
||||
}
|
||||
|
||||
async fn admin_config_handler() -> impl IntoResponse {
|
||||
let grafana_url = std::env::var("MADBASE_GRAFANA_URL")
|
||||
.unwrap_or_else(|_| "/grafana".to_string());
|
||||
let version = env!("CARGO_PKG_VERSION");
|
||||
|
||||
(StatusCode::OK, serde_json::json!({
|
||||
"grafana_url": grafana_url,
|
||||
"version": version
|
||||
}).to_string()).into_response()
|
||||
}
|
||||
|
||||
// Admin-proxied storage endpoints (browser never touches service_role_key)
|
||||
fn get_service_key() -> String {
|
||||
std::env::var("SERVICE_ROLE_KEY").unwrap_or_default()
|
||||
}
|
||||
|
||||
async fn admin_storage_buckets() -> impl IntoResponse {
|
||||
let client = shared_http_client();
|
||||
let base = format!("http://127.0.0.1:{}", std::env::var("PORT").unwrap_or_else(|_| "8000".to_string()));
|
||||
match client.get(format!("{}/storage/v1/bucket", base))
|
||||
.header("Authorization", format!("Bearer {}", get_service_key()))
|
||||
.header("x-project-ref", "default")
|
||||
.send().await {
|
||||
Ok(r) => {
|
||||
let status = StatusCode::from_u16(r.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let body = r.bytes().await.unwrap_or_default();
|
||||
(status, body).into_response()
|
||||
}
|
||||
Err(e) => (StatusCode::BAD_GATEWAY, e.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn admin_storage_list_objects(Path(bucket): Path<String>) -> impl IntoResponse {
|
||||
let client = shared_http_client();
|
||||
let base = format!("http://127.0.0.1:{}", std::env::var("PORT").unwrap_or_else(|_| "8000".to_string()));
|
||||
match client.post(format!("{}/storage/v1/object/list/{}", base, bucket))
|
||||
.header("Authorization", format!("Bearer {}", get_service_key()))
|
||||
.header("x-project-ref", "default")
|
||||
.send().await {
|
||||
Ok(r) => {
|
||||
let status = StatusCode::from_u16(r.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let body = r.bytes().await.unwrap_or_default();
|
||||
(status, body).into_response()
|
||||
}
|
||||
Err(e) => (StatusCode::BAD_GATEWAY, e.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn admin_storage_upload(
|
||||
Path((bucket, name)): Path<(String, String)>,
|
||||
req: Request,
|
||||
) -> impl IntoResponse {
|
||||
let content_type = req.headers()
|
||||
.get(header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("application/octet-stream")
|
||||
.to_string();
|
||||
let body_bytes = axum::body::to_bytes(req.into_body(), 100 * 1024 * 1024).await;
|
||||
let body_bytes = match body_bytes {
|
||||
Ok(b) => b,
|
||||
Err(e) => return (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
|
||||
};
|
||||
let client = shared_http_client();
|
||||
let base = format!("http://127.0.0.1:{}", std::env::var("PORT").unwrap_or_else(|_| "8000".to_string()));
|
||||
match client.post(format!("{}/storage/v1/object/{}/{}", base, bucket, name))
|
||||
.header("Authorization", format!("Bearer {}", get_service_key()))
|
||||
.header("x-project-ref", "default")
|
||||
.header("Content-Type", content_type)
|
||||
.body(body_bytes)
|
||||
.send().await {
|
||||
Ok(r) => {
|
||||
let status = StatusCode::from_u16(r.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let body = r.bytes().await.unwrap_or_default();
|
||||
(status, body).into_response()
|
||||
}
|
||||
Err(e) => (StatusCode::BAD_GATEWAY, e.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn admin_storage_delete(
|
||||
Path((bucket, name)): Path<(String, String)>,
|
||||
) -> impl IntoResponse {
|
||||
let client = shared_http_client();
|
||||
let base = format!("http://127.0.0.1:{}", std::env::var("PORT").unwrap_or_else(|_| "8000".to_string()));
|
||||
match client.delete(format!("{}/storage/v1/object/{}/{}", base, bucket, name))
|
||||
.header("Authorization", format!("Bearer {}", get_service_key()))
|
||||
.header("x-project-ref", "default")
|
||||
.send().await {
|
||||
Ok(r) => {
|
||||
let status = StatusCode::from_u16(r.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let body = r.bytes().await.unwrap_or_default();
|
||||
(status, body).into_response()
|
||||
}
|
||||
Err(e) => (StatusCode::BAD_GATEWAY, e.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn admin_storage_download(
|
||||
Path((bucket, name)): Path<(String, String)>,
|
||||
) -> impl IntoResponse {
|
||||
let client = shared_http_client();
|
||||
let base = format!("http://127.0.0.1:{}", std::env::var("PORT").unwrap_or_else(|_| "8000".to_string()));
|
||||
match client.get(format!("{}/storage/v1/object/{}/{}", base, bucket, name))
|
||||
.header("Authorization", format!("Bearer {}", get_service_key()))
|
||||
.header("x-project-ref", "default")
|
||||
.send().await {
|
||||
Ok(r) => {
|
||||
let status = StatusCode::from_u16(r.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let body = r.bytes().await.unwrap_or_default();
|
||||
(status, body).into_response()
|
||||
}
|
||||
Err(e) => (StatusCode::BAD_GATEWAY, e.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// Admin-proxied functions endpoints
|
||||
async fn admin_functions_list() -> impl IntoResponse {
|
||||
let client = shared_http_client();
|
||||
let base = format!("http://127.0.0.1:{}", std::env::var("PORT").unwrap_or_else(|_| "8000".to_string()));
|
||||
match client.get(format!("{}/functions/v1", base))
|
||||
.header("Authorization", format!("Bearer {}", get_service_key()))
|
||||
.header("x-project-ref", "default")
|
||||
.send().await {
|
||||
Ok(r) => {
|
||||
let status = StatusCode::from_u16(r.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let body = r.bytes().await.unwrap_or_default();
|
||||
(status, body).into_response()
|
||||
}
|
||||
Err(e) => (StatusCode::BAD_GATEWAY, e.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn admin_functions_get(Path(name): Path<String>) -> impl IntoResponse {
|
||||
let client = shared_http_client();
|
||||
let base = format!("http://127.0.0.1:{}", std::env::var("PORT").unwrap_or_else(|_| "8000".to_string()));
|
||||
match client.get(format!("{}/functions/v1/{}", base, name))
|
||||
.header("Authorization", format!("Bearer {}", get_service_key()))
|
||||
.header("x-project-ref", "default")
|
||||
.send().await {
|
||||
Ok(r) => {
|
||||
let status = StatusCode::from_u16(r.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let body = r.bytes().await.unwrap_or_default();
|
||||
(status, body).into_response()
|
||||
}
|
||||
Err(e) => (StatusCode::BAD_GATEWAY, e.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn admin_functions_deploy(req: Request) -> impl IntoResponse {
|
||||
let body_bytes = axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await;
|
||||
let body_bytes = match body_bytes {
|
||||
Ok(b) => b,
|
||||
Err(e) => return (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
|
||||
};
|
||||
let client = shared_http_client();
|
||||
let base = format!("http://127.0.0.1:{}", std::env::var("PORT").unwrap_or_else(|_| "8000".to_string()));
|
||||
match client.post(format!("{}/functions/v1", base))
|
||||
.header("Authorization", format!("Bearer {}", get_service_key()))
|
||||
.header("x-project-ref", "default")
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body_bytes)
|
||||
.send().await {
|
||||
Ok(r) => {
|
||||
let status = StatusCode::from_u16(r.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let body = r.bytes().await.unwrap_or_default();
|
||||
(status, body).into_response()
|
||||
}
|
||||
Err(e) => (StatusCode::BAD_GATEWAY, e.to_string()).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_allowed_origins() -> AllowOrigin {
|
||||
let origins_str = std::env::var("ALLOWED_ORIGINS")
|
||||
.unwrap_or_else(|_| "http://localhost:3000,http://localhost:8000,http://localhost:8001".to_string());
|
||||
@@ -140,97 +366,7 @@ async fn log_headers(req: Request, next: Next) -> Response {
|
||||
}
|
||||
|
||||
// Wrapper handlers for control_plane routes that use AppState
|
||||
mod platform_routes {
|
||||
use super::*;
|
||||
use control_plane::{list_projects, create_project, delete_project, rotate_keys, get_project_keys, list_users, delete_user};
|
||||
use axum::{routing::{delete, get}, extract::Path};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn list_projects_wrapper(
|
||||
State(state): State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let control_state = ControlPlaneState {
|
||||
db: state.control_plane.db.clone(),
|
||||
tenant_db: state.control_plane.tenant_db.clone(),
|
||||
};
|
||||
list_projects(State(control_state)).await
|
||||
}
|
||||
|
||||
pub async fn create_project_wrapper(
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<CreateProjectRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let control_state = ControlPlaneState {
|
||||
db: state.control_plane.db.clone(),
|
||||
tenant_db: state.control_plane.tenant_db.clone(),
|
||||
};
|
||||
create_project(State(control_state), Json(payload)).await
|
||||
}
|
||||
|
||||
pub async fn delete_project_wrapper(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> impl IntoResponse {
|
||||
let control_state = ControlPlaneState {
|
||||
db: state.control_plane.db.clone(),
|
||||
tenant_db: state.control_plane.tenant_db.clone(),
|
||||
};
|
||||
delete_project(State(control_state), Path(id)).await
|
||||
}
|
||||
|
||||
pub async fn rotate_keys_wrapper(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(payload): Json<RotateKeyRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let control_state = ControlPlaneState {
|
||||
db: state.control_plane.db.clone(),
|
||||
tenant_db: state.control_plane.tenant_db.clone(),
|
||||
};
|
||||
rotate_keys(State(control_state), Path(id), Json(payload)).await
|
||||
}
|
||||
|
||||
pub async fn list_users_wrapper(
|
||||
State(state): State<AppState>,
|
||||
) -> impl IntoResponse {
|
||||
let control_state = ControlPlaneState {
|
||||
db: state.control_plane.db.clone(),
|
||||
tenant_db: state.control_plane.tenant_db.clone(),
|
||||
};
|
||||
list_users(State(control_state)).await
|
||||
}
|
||||
|
||||
pub async fn delete_user_wrapper(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> impl IntoResponse {
|
||||
let control_state = ControlPlaneState {
|
||||
db: state.control_plane.db.clone(),
|
||||
tenant_db: state.control_plane.tenant_db.clone(),
|
||||
};
|
||||
delete_user(State(control_state), Path(id)).await
|
||||
}
|
||||
|
||||
pub async fn get_project_keys_wrapper(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> impl IntoResponse {
|
||||
let control_state = ControlPlaneState {
|
||||
db: state.control_plane.db.clone(),
|
||||
tenant_db: state.control_plane.tenant_db.clone(),
|
||||
};
|
||||
get_project_keys(State(control_state), Path(id)).await
|
||||
}
|
||||
|
||||
pub fn router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/projects", get(list_projects_wrapper).post(create_project_wrapper))
|
||||
.route("/projects/:id", delete(delete_project_wrapper))
|
||||
.route("/projects/:id/keys", get(get_project_keys_wrapper).put(rotate_keys_wrapper))
|
||||
.route("/users", get(list_users_wrapper))
|
||||
.route("/users/:id", delete(delete_user_wrapper))
|
||||
}
|
||||
}
|
||||
// platform_routes now delegates to the consolidated control_plane::router()
|
||||
|
||||
pub async fn run() -> anyhow::Result<()> {
|
||||
let config = Config::new().expect("Failed to load configuration");
|
||||
@@ -239,18 +375,29 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
|
||||
let pool = wait_for_db(&config.database_url).await;
|
||||
|
||||
sqlx::migrate!("../migrations")
|
||||
tracing::info!("Running control plane migrations...");
|
||||
sqlx::migrate!("../migrations_control")
|
||||
.run(&pool)
|
||||
.await
|
||||
.expect("Failed to run migrations");
|
||||
.expect("Failed to run control plane migrations");
|
||||
|
||||
let default_tenant_db_url = std::env::var("DEFAULT_TENANT_DB_URL")
|
||||
.expect("DEFAULT_TENANT_DB_URL must be set");
|
||||
let tenant_pool = wait_for_db(&default_tenant_db_url).await;
|
||||
|
||||
tracing::info!("Running tenant migrations...");
|
||||
sqlx::migrate!("../migrations")
|
||||
.run(&tenant_pool)
|
||||
.await
|
||||
.expect("Failed to run tenant migrations");
|
||||
|
||||
// Initialize server manager for infrastructure management
|
||||
let server_manager = control_plane::init_server_manager(pool.clone()).await;
|
||||
|
||||
let control_plane_state = ControlPlaneState {
|
||||
db: pool.clone(),
|
||||
tenant_db: tenant_pool.clone(),
|
||||
server_manager,
|
||||
};
|
||||
|
||||
let admin_auth_state = AdminAuthState::new();
|
||||
@@ -269,16 +416,33 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], port));
|
||||
tracing::info!("Control plane listening on {}", addr);
|
||||
|
||||
// Build the control plane platform router (state already applied → Router<()>)
|
||||
let platform_router = control_plane::router(app_state.control_plane.clone());
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "MadBase Control Plane" }))
|
||||
.route("/health", get(|| async { "OK" }))
|
||||
.route("/metrics", get(|| async move { metric_handle.render() }))
|
||||
.route("/dashboard", get(dashboard_handler))
|
||||
.route("/logs", get(logs_proxy_handler))
|
||||
.route("/login", axum::routing::post(login_handler))
|
||||
.route("/login", post(login_handler))
|
||||
.route("/platform/v1/login", post(login_handler))
|
||||
.route("/platform/v1/logout", post(logout_handler))
|
||||
.route("/platform/v1/csrf-token", get(csrf_token_handler))
|
||||
.route("/platform/v1/admin/config", get(admin_config_handler))
|
||||
// Admin-proxied storage (no service key in browser)
|
||||
.route("/platform/v1/storage/buckets", get(admin_storage_buckets))
|
||||
.route("/platform/v1/storage/buckets/:bucket/objects", post(admin_storage_list_objects))
|
||||
.route("/platform/v1/storage/upload/:bucket/:name", post(admin_storage_upload))
|
||||
.route("/platform/v1/storage/:bucket/:name", delete(admin_storage_delete).get(admin_storage_download))
|
||||
// Admin-proxied functions
|
||||
.route("/platform/v1/functions", get(admin_functions_list).post(admin_functions_deploy))
|
||||
.route("/platform/v1/functions/:name", get(admin_functions_get))
|
||||
.nest_service("/css", ServeDir::new("web/css"))
|
||||
.nest_service("/js", ServeDir::new("web/js"))
|
||||
.nest("/platform/v1", platform_routes::router())
|
||||
.nest_service("/vendor", ServeDir::new("web/vendor"))
|
||||
.with_state(app_state)
|
||||
.merge(platform_router)
|
||||
.layer(from_fn(log_headers))
|
||||
.layer(prometheus_layer)
|
||||
.layer(
|
||||
@@ -288,9 +452,8 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
.allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION, header::COOKIE])
|
||||
.allow_credentials(true),
|
||||
)
|
||||
.layer(from_fn_with_state(app_state.admin_auth.clone(), admin_auth_middleware))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.with_state(app_state);
|
||||
.layer(from_fn_with_state(admin_auth_state.clone(), admin_auth_middleware))
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
|
||||
|
||||
@@ -7,3 +7,12 @@ pub mod proxy;
|
||||
pub mod rate_limit;
|
||||
|
||||
pub use rate_limit::{RateLimiter, RateLimitConfig, RateLimitMiddleware, RateLimitStatus};
|
||||
|
||||
/// Runs tenant-specific migrations on the provided pool.
|
||||
/// This ensures that every tenant database has the required auth, storage,
|
||||
/// functions, and realtime schemas/tables.
|
||||
pub async fn run_tenant_migrations(pool: &sqlx::PgPool) -> Result<(), sqlx::migrate::MigrateError> {
|
||||
sqlx::migrate!("../migrations")
|
||||
.run(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
mod middleware;
|
||||
mod state;
|
||||
use gateway::middleware;
|
||||
use gateway::state::AppState;
|
||||
|
||||
use axum::{
|
||||
extract::{Request, Query},
|
||||
@@ -8,10 +8,10 @@ use axum::{
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use tower_http::services::{ServeDir, ServeFile};
|
||||
use axum::http::StatusCode;
|
||||
use axum_prometheus::PrometheusMetricLayer;
|
||||
use common::{init_pool, Config};
|
||||
use state::AppState;
|
||||
use common::{init_pool, Config, JwtConfig};
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
@@ -63,9 +63,7 @@ async fn log_headers(req: Request, next: Next) -> Response {
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
async fn dashboard_handler() -> axum::response::Html<&'static str> {
|
||||
axum::response::Html(include_str!("../../web/admin.html"))
|
||||
}
|
||||
// Dashboard handler removed in favor of direct SPA serving from /web
|
||||
|
||||
async fn wait_for_db(db_url: &str) -> sqlx::PgPool {
|
||||
loop {
|
||||
@@ -79,6 +77,73 @@ async fn wait_for_db(db_url: &str) -> sqlx::PgPool {
|
||||
}
|
||||
}
|
||||
|
||||
async fn validate_configuration(config: &Config, pool: &sqlx::PgPool) -> anyhow::Result<()> {
|
||||
tracing::info!("Validating configuration...");
|
||||
|
||||
// 1. Validate JWT secret format
|
||||
if config.jwt_secret.len() < 32 {
|
||||
anyhow::bail!(
|
||||
"JWT_SECRET too short ({} chars, minimum 32 required)",
|
||||
config.jwt_secret.len()
|
||||
);
|
||||
}
|
||||
|
||||
// 1.1 Validate JWT issuer
|
||||
let jwt_issuer = std::env::var("JWT_ISSUER").unwrap_or_else(|_| "madbase".to_string());
|
||||
tracing::info!(
|
||||
jwt_issuer = %jwt_issuer,
|
||||
"JWT issuer configured"
|
||||
);
|
||||
|
||||
// 2. Validate JWT secret consistency with database
|
||||
let row = sqlx::query_as::<_, (String, String, String, String)>(
|
||||
r#"
|
||||
SELECT name, jwt_secret, anon_key, service_role_key
|
||||
FROM projects
|
||||
WHERE name = 'default'
|
||||
LIMIT 1
|
||||
"#
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
if let Some((name, jwt_secret, anon_key, service_role_key)) = row {
|
||||
if jwt_secret != config.jwt_secret {
|
||||
anyhow::bail!(
|
||||
"JWT_SECRET mismatch between environment and database (project 'default')\n\
|
||||
Environment: {}...\n\
|
||||
Database: {}...\n\
|
||||
Run 'scripts/setup_default_project.sh' to fix this.",
|
||||
&config.jwt_secret[..8],
|
||||
&jwt_secret[..8]
|
||||
);
|
||||
}
|
||||
|
||||
// Validate that anon_key and service_role_key are present
|
||||
if anon_key.is_empty() {
|
||||
anyhow::bail!("Project 'default' has empty anon_key");
|
||||
}
|
||||
if service_role_key.is_empty() {
|
||||
anyhow::bail!("Project 'default' has empty service_role_key");
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
project_name = name,
|
||||
jwt_secret_preview = &jwt_secret[..8],
|
||||
anon_key_present = !anon_key.is_empty(),
|
||||
service_role_key_present = !service_role_key.is_empty(),
|
||||
"Project configuration validated"
|
||||
);
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"Default project not found in database. Run 'scripts/setup_default_project.sh' to create it."
|
||||
);
|
||||
}
|
||||
|
||||
tracing::info!("Configuration validation successful.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Load configuration
|
||||
@@ -87,17 +152,42 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Initialize tracing
|
||||
let rust_log = std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into());
|
||||
let is_json = std::env::var("LOG_FORMAT").ok().as_deref() == Some("json");
|
||||
|
||||
if std::env::var("LOG_FORMAT").ok().as_deref() == Some("json") {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::EnvFilter::new(&rust_log))
|
||||
.with(tracing_subscriber::fmt::layer().json())
|
||||
.init();
|
||||
use tracing_subscriber::Layer;
|
||||
|
||||
let filter = tracing_subscriber::EnvFilter::new(&rust_log).boxed();
|
||||
let fmt_layer = if is_json {
|
||||
tracing_subscriber::fmt::layer().json().boxed()
|
||||
} else {
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::EnvFilter::new(&rust_log))
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
tracing_subscriber::fmt::layer().boxed()
|
||||
};
|
||||
|
||||
let otel_layer = if let Ok(otlp_endpoint) = std::env::var("OTEL_EXPORTER_OTLP_ENDPOINT") {
|
||||
use opentelemetry_otlp::WithExportConfig;
|
||||
|
||||
let tracer = opentelemetry_otlp::new_pipeline()
|
||||
.tracing()
|
||||
.with_exporter(opentelemetry_otlp::new_exporter().tonic().with_endpoint(otlp_endpoint))
|
||||
.with_trace_config(opentelemetry_sdk::trace::config().with_resource(
|
||||
opentelemetry_sdk::Resource::new(vec![opentelemetry::KeyValue::new("service.name", "madbase-gateway")])
|
||||
))
|
||||
.install_batch(opentelemetry_sdk::runtime::Tokio)
|
||||
.expect("Failed to initialize OTLP tracer");
|
||||
|
||||
Some(tracing_opentelemetry::layer().with_tracer(tracer).boxed())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let registry = tracing_subscriber::registry()
|
||||
.with(filter)
|
||||
.with(fmt_layer);
|
||||
|
||||
if let Some(otel) = otel_layer {
|
||||
registry.with(otel).init();
|
||||
} else {
|
||||
registry.init();
|
||||
}
|
||||
|
||||
tracing::info!("Starting MadBase Gateway v4.1 (Admin UI)...");
|
||||
@@ -107,13 +197,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
let pool = wait_for_db(&config.database_url).await;
|
||||
tracing::info!("Database connected successfully.");
|
||||
|
||||
// Run Migrations
|
||||
tracing::info!("Running database migrations...");
|
||||
// Run Migrations (Tenant only for Gateway)
|
||||
tracing::info!("Running tenant database migrations...");
|
||||
sqlx::migrate!("../migrations")
|
||||
.run(&pool)
|
||||
.await
|
||||
.expect("Failed to run migrations");
|
||||
tracing::info!("Migrations applied successfully.");
|
||||
.expect("Failed to run tenant migrations");
|
||||
tracing::info!("Tenant migrations applied successfully.");
|
||||
|
||||
// Validate Configuration
|
||||
validate_configuration(&config, &pool).await?;
|
||||
|
||||
let app_state = AppState {
|
||||
control_db: pool.clone(),
|
||||
@@ -131,11 +224,42 @@ async fn main() -> anyhow::Result<()> {
|
||||
session_manager,
|
||||
};
|
||||
|
||||
let replica_pool = if let Ok(url) = std::env::var("READ_REPLICA_URL") {
|
||||
tracing::info!("Connecting to read replica at {}...", url);
|
||||
Some(wait_for_db(&url).await)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let data_state = data_api::handlers::DataState {
|
||||
db: pool.clone(),
|
||||
replica_pool,
|
||||
config: config.clone(),
|
||||
cache: Arc::new(data_api::schema_cache::SchemaCache::new()),
|
||||
};
|
||||
|
||||
// Register DDL invalidation listener
|
||||
let ddl_cache = data_state.cache.clone();
|
||||
let ddl_pool = pool.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut listener = match sqlx::postgres::PgListener::connect_with(&ddl_pool).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to connect PgListener: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
if let Err(e) = listener.listen("madbase_schema_change").await {
|
||||
tracing::error!("Failed to listen on madbase_schema_change: {}", e);
|
||||
return;
|
||||
}
|
||||
tracing::info!("DDL invalidation listener started.");
|
||||
while let Ok(notification) = listener.recv().await {
|
||||
tracing::info!("Received DDL change notification: {}. Invalidating SchemaCache.", notification.payload());
|
||||
ddl_cache.invalidate_all().await;
|
||||
}
|
||||
});
|
||||
|
||||
// Initialize Tenant Database (for Realtime)
|
||||
let default_tenant_db_url = std::env::var("DEFAULT_TENANT_DB_URL")
|
||||
.expect("DEFAULT_TENANT_DB_URL must be set");
|
||||
@@ -146,19 +270,20 @@ async fn main() -> anyhow::Result<()> {
|
||||
let control_state = control_plane::ControlPlaneState {
|
||||
db: pool.clone(),
|
||||
tenant_db: tenant_pool.clone(),
|
||||
server_manager: None,
|
||||
};
|
||||
|
||||
let mut tenant_config = config.clone();
|
||||
tenant_config.database_url = default_tenant_db_url;
|
||||
tenant_config.database_url = default_tenant_db_url.clone();
|
||||
|
||||
// Realtime Init
|
||||
let (realtime_router, realtime_state) = realtime::init(tenant_pool.clone(), tenant_config.clone());
|
||||
|
||||
// Start Replication Listener
|
||||
let repl_config = tenant_config.clone();
|
||||
let repl_tx = realtime_state.broadcast_tx.clone();
|
||||
// Start Replication Listener (for default tenant)
|
||||
let repl_state = realtime_state.clone();
|
||||
let default_db_url = default_tenant_db_url.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = realtime::replication::start_replication_listener(repl_config, repl_tx).await {
|
||||
if let Err(e) = realtime::replication::start_replication_listener("default".to_string(), default_db_url, repl_state).await {
|
||||
tracing::error!("Replication listener failed: {}", e);
|
||||
}
|
||||
});
|
||||
@@ -169,23 +294,35 @@ async fn main() -> anyhow::Result<()> {
|
||||
// Functions Init
|
||||
let functions_runtime = Arc::new(functions::runtime::WasmRuntime::new().expect("Failed to initialize WASM runtime"));
|
||||
let deno_runtime = Arc::new(functions::deno_runtime::DenoRuntime::new());
|
||||
let pool_size = std::env::var("DENO_POOL_SIZE").unwrap_or_else(|_| "4".to_string()).parse::<usize>().unwrap_or(4);
|
||||
let deno_pool = Arc::new(functions::worker_pool::DenoPool::new(pool_size));
|
||||
let functions_state = functions::FunctionsState {
|
||||
db: pool.clone(),
|
||||
config: config.clone(),
|
||||
runtime: functions_runtime,
|
||||
deno_runtime,
|
||||
deno_pool,
|
||||
};
|
||||
|
||||
// Auth Middleware State
|
||||
let jwt_config = JwtConfig::from_env()?;
|
||||
let auth_middleware_state = auth::AuthMiddlewareState {
|
||||
config: config.clone(),
|
||||
jwt_config,
|
||||
};
|
||||
|
||||
// Project Middleware State
|
||||
let project_middleware_state = middleware::ProjectMiddlewareState {
|
||||
control_db: app_state.control_db.clone(),
|
||||
tenant_pools: app_state.tenant_pools.clone(),
|
||||
project_cache: Cache::new(100),
|
||||
tenant_pools: moka::future::Cache::builder()
|
||||
.max_capacity(100)
|
||||
.time_to_idle(Duration::from_secs(300))
|
||||
.build(),
|
||||
project_cache: moka::future::Cache::builder()
|
||||
.max_capacity(100)
|
||||
.time_to_live(Duration::from_secs(60))
|
||||
.build(),
|
||||
realtime: realtime_state,
|
||||
};
|
||||
|
||||
// Construct App
|
||||
@@ -247,11 +384,9 @@ async fn main() -> anyhow::Result<()> {
|
||||
.finish()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "Hello, MadBase!" }))
|
||||
.route("/metrics", get(|| async move { metric_handle.render() }))
|
||||
.route("/dashboard", get(dashboard_handler))
|
||||
.nest("/", tenant_routes) // Apply project resolution to these
|
||||
.nest(
|
||||
"/platform/v1", // Admin/Control Plane API (No project resolution needed)
|
||||
@@ -286,14 +421,24 @@ async fn main() -> anyhow::Result<()> {
|
||||
})
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(from_fn(log_headers))
|
||||
.layer(prometheus_layer);
|
||||
.layer(prometheus_layer)
|
||||
.fallback_service(ServeDir::new("web").fallback(ServeFile::new("web/index.html")));
|
||||
|
||||
// Run it
|
||||
let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
|
||||
tracing::info!("Listening on {}", addr);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
|
||||
|
||||
let shutdown = async {
|
||||
tokio::signal::ctrl_c().await.ok();
|
||||
tracing::info!("Shutdown signal received, draining connections...");
|
||||
};
|
||||
|
||||
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>())
|
||||
.with_graceful_shutdown(shutdown)
|
||||
.await?;
|
||||
|
||||
tracing::info!("Gateway Server shut down cleanly.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -8,16 +8,14 @@ use common::init_pool;
|
||||
use common::ProjectContext;
|
||||
use moka::future::Cache;
|
||||
use sqlx::PgPool;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProjectMiddlewareState {
|
||||
pub control_db: PgPool,
|
||||
pub tenant_pools: Arc<RwLock<HashMap<String, PgPool>>>,
|
||||
pub tenant_pools: Cache<String, PgPool>,
|
||||
pub project_cache: Cache<String, ProjectContext>,
|
||||
pub realtime: realtime::RealtimeState,
|
||||
}
|
||||
|
||||
pub async fn resolve_project(
|
||||
@@ -112,11 +110,11 @@ pub async fn inject_tenant_pool(
|
||||
|
||||
let db_url = project_ctx.db_url.clone();
|
||||
|
||||
let existing = { state.tenant_pools.read().await.get(&db_url).cloned() };
|
||||
|
||||
let pool = if let Some(p) = existing {
|
||||
// Check bound cache first
|
||||
let pool = if let Some(p) = state.tenant_pools.get(&db_url).await {
|
||||
p
|
||||
} else {
|
||||
// Init + Insert
|
||||
let new_pool = init_pool(&db_url)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
@@ -124,9 +122,24 @@ pub async fn inject_tenant_pool(
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let mut w = state.tenant_pools.write().await;
|
||||
let entry = w.entry(db_url).or_insert_with(|| new_pool.clone());
|
||||
entry.clone()
|
||||
// Ensure the tenant database is migrated
|
||||
if let Err(e) = crate::run_tenant_migrations(&new_pool).await {
|
||||
warn!("Failed to run tenant migrations for {}: {}", db_url, e);
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
|
||||
// Start replication listener for the tenant
|
||||
if let Err(e) = realtime::replication::start_replication_listener(
|
||||
project_ctx.project_ref.clone(),
|
||||
db_url.clone(),
|
||||
state.realtime.clone(),
|
||||
).await {
|
||||
warn!("Failed to start replication listener for {}: {}", project_ctx.project_ref, e);
|
||||
}
|
||||
|
||||
|
||||
state.tenant_pools.insert(db_url, new_pool.clone()).await;
|
||||
new_pool
|
||||
};
|
||||
|
||||
req.extensions_mut().insert(pool);
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
extract::{Request, State, ws::WebSocketUpgrade},
|
||||
http::{StatusCode, HeaderMap},
|
||||
response::Response,
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use tokio_tungstenite::{
|
||||
tungstenite::protocol::Message as TungsteniteMessage,
|
||||
};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{error, info, debug};
|
||||
use moka::future::Cache;
|
||||
use common::{init_pool, ProjectContext};
|
||||
use sqlx::PgPool;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Upstream {
|
||||
@@ -34,10 +42,12 @@ struct ProxyState {
|
||||
worker_upstreams: Arc<RwLock<Vec<Upstream>>>,
|
||||
current_worker_index: Arc<RwLock<usize>>,
|
||||
http_client: reqwest::Client,
|
||||
control_db: PgPool,
|
||||
project_cache: Cache<String, ProjectContext>,
|
||||
}
|
||||
|
||||
impl ProxyState {
|
||||
fn new(control_url: String, worker_urls: Vec<String>) -> Self {
|
||||
fn new(control_url: String, worker_urls: Vec<String>, control_db: PgPool) -> Self {
|
||||
let worker_upstreams = worker_urls
|
||||
.into_iter()
|
||||
.map(|url| Upstream::new(format!("worker-{}", url), url))
|
||||
@@ -55,6 +65,8 @@ impl ProxyState {
|
||||
worker_upstreams: Arc::new(RwLock::new(worker_upstreams)),
|
||||
current_worker_index: Arc::new(RwLock::new(0)),
|
||||
http_client,
|
||||
control_db,
|
||||
project_cache: Cache::new(100),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,6 +100,24 @@ impl ProxyState {
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
// Optional Dynamic Worker Discovery via Control Plane Polling
|
||||
// You can replace this whole background loop seamlessly to query the control plane
|
||||
let control_scan_url = format!("{}/workers", self.control_upstream.url);
|
||||
if let Ok(res) = self.http_client.get(&control_scan_url).send().await {
|
||||
if let Ok(workers) = res.json::<Vec<String>>().await {
|
||||
let mut current = self.worker_upstreams.write().await;
|
||||
// Retain healthy upstreams or register new ones
|
||||
let updated: Vec<Upstream> = workers.into_iter().map(|url| {
|
||||
if let Some(existing) = current.iter().find(|w| w.url == url) {
|
||||
existing.clone()
|
||||
} else {
|
||||
Upstream::new(format!("worker-{}", url), url)
|
||||
}
|
||||
}).collect();
|
||||
*current = updated;
|
||||
}
|
||||
}
|
||||
|
||||
// Check workers
|
||||
let worker_upstreams = self.worker_upstreams.read().await;
|
||||
for worker in worker_upstreams.iter() {
|
||||
@@ -130,6 +160,197 @@ impl ProxyState {
|
||||
}
|
||||
}
|
||||
|
||||
async fn resolve_project_from_headers(
|
||||
state: &ProxyState,
|
||||
headers: &HeaderMap,
|
||||
) -> Result<ProjectContext, StatusCode> {
|
||||
let project_ref = if let Some(val) = headers.get("x-project-ref") {
|
||||
val.to_str()
|
||||
.map_err(|_| StatusCode::BAD_REQUEST)?
|
||||
.to_string()
|
||||
} else {
|
||||
"default".to_string()
|
||||
};
|
||||
|
||||
if let Some(ctx) = state.project_cache.get(&project_ref).await {
|
||||
return Ok(ctx);
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct ProjectRecord {
|
||||
db_url: String,
|
||||
jwt_secret: String,
|
||||
anon_key: Option<String>,
|
||||
service_role_key: Option<String>,
|
||||
}
|
||||
|
||||
let record = if project_ref == "default" {
|
||||
sqlx::query_as::<_, ProjectRecord>(
|
||||
"SELECT db_url, jwt_secret, anon_key, service_role_key FROM projects LIMIT 1",
|
||||
)
|
||||
.fetch_optional(&state.control_db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("DB Error: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
} else {
|
||||
sqlx::query_as::<_, ProjectRecord>(
|
||||
"SELECT db_url, jwt_secret, anon_key, service_role_key FROM projects WHERE name = $1",
|
||||
)
|
||||
.bind(&project_ref)
|
||||
.fetch_optional(&state.control_db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("DB Error: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
};
|
||||
|
||||
if record.is_none() {
|
||||
error!("Project not found: {}", project_ref);
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
let project = record.unwrap();
|
||||
|
||||
let ctx = ProjectContext {
|
||||
project_ref: project_ref.clone(),
|
||||
db_url: project.db_url,
|
||||
redis_url: None,
|
||||
jwt_secret: project.jwt_secret,
|
||||
anon_key: project.anon_key,
|
||||
service_role_key: project.service_role_key,
|
||||
};
|
||||
|
||||
state.project_cache.insert(project_ref.clone(), ctx.clone()).await;
|
||||
Ok(ctx)
|
||||
}
|
||||
|
||||
async fn proxy_websocket(
|
||||
State(state): State<ProxyState>,
|
||||
ws: WebSocket,
|
||||
headers: HeaderMap,
|
||||
) {
|
||||
let result = async {
|
||||
let request_id = headers.get("x-request-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
|
||||
|
||||
let span = tracing::info_span!("proxy_websocket", request_id = %request_id);
|
||||
let _enter = span.enter();
|
||||
|
||||
let _project_ctx = resolve_project_from_headers(&state, &headers).await?;
|
||||
|
||||
let upstream = state.get_next_healthy_worker().await.ok_or_else(|| {
|
||||
error!("No healthy workers available");
|
||||
StatusCode::SERVICE_UNAVAILABLE
|
||||
})?;
|
||||
|
||||
let target_url_str = format!("{}/realtime/v1/websocket", upstream.url.replace("http://", "ws://"));
|
||||
debug!("Proxying WebSocket -> {}", target_url_str);
|
||||
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
|
||||
let mut req = target_url_str.clone().into_client_request().map_err(|e| {
|
||||
error!("Failed to create WebSocket request: {}", e);
|
||||
StatusCode::BAD_GATEWAY
|
||||
})?;
|
||||
|
||||
for (name, value) in headers.iter() {
|
||||
let name_str = name.as_str();
|
||||
if name_str == "apikey" || name_str == "authorization" || name_str == "x-project-ref" {
|
||||
info!("Forwarding header: {}", name_str);
|
||||
req.headers_mut().insert(name, value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
info!("Connecting to worker WebSocket at: {}", target_url_str);
|
||||
let (server_ws, response) = tokio_tungstenite::connect_async(req)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Failed to connect to WebSocket upstream {}: {}", upstream.name, e);
|
||||
StatusCode::BAD_GATEWAY
|
||||
})?;
|
||||
info!("Worker WebSocket connection established. Response status: {:?}", response.status());
|
||||
|
||||
let (ws_sender, ws_receiver) = ws.split();
|
||||
let (server_sink, server_stream) = server_ws.split();
|
||||
|
||||
let tx_to_client = async move {
|
||||
let mut ws_receiver = ws_receiver;
|
||||
let mut server_sink = server_sink;
|
||||
|
||||
debug!("Starting tx_to_client loop");
|
||||
while let Some(msg_result) = ws_receiver.next().await {
|
||||
match msg_result {
|
||||
Ok(msg) => {
|
||||
debug!("Received message from client: {:?}", msg);
|
||||
let tungstenite_msg = match msg {
|
||||
Message::Text(text) => TungsteniteMessage::Text(text),
|
||||
Message::Binary(data) => TungsteniteMessage::Binary(data),
|
||||
Message::Close(_) => TungsteniteMessage::Close(None),
|
||||
Message::Ping(data) => TungsteniteMessage::Ping(data),
|
||||
Message::Pong(data) => TungsteniteMessage::Pong(data),
|
||||
};
|
||||
if server_sink.send(tungstenite_msg).await.is_err() {
|
||||
debug!("Failed to send to upstream, closing tx_to_client");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error receiving from client WebSocket: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!("tx_to_client loop ended");
|
||||
};
|
||||
|
||||
let tx_to_upstream = async move {
|
||||
let mut ws_sender = ws_sender;
|
||||
let mut server_stream = server_stream;
|
||||
|
||||
debug!("Starting tx_to_upstream loop");
|
||||
while let Some(msg_result) = server_stream.next().await {
|
||||
match msg_result {
|
||||
Ok(msg) => {
|
||||
debug!("Received message from upstream: {:?}", msg);
|
||||
let axum_msg = match msg {
|
||||
TungsteniteMessage::Text(text) => Message::Text(text),
|
||||
TungsteniteMessage::Binary(data) => Message::Binary(data),
|
||||
TungsteniteMessage::Close(_) => Message::Close(None),
|
||||
TungsteniteMessage::Ping(data) => Message::Ping(data),
|
||||
TungsteniteMessage::Pong(data) => Message::Pong(data),
|
||||
_ => continue,
|
||||
};
|
||||
if ws_sender.send(axum_msg).await.is_err() {
|
||||
debug!("Failed to send to client, closing tx_to_upstream");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error receiving from upstream WebSocket: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!("tx_to_upstream loop ended");
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
_ = tx_to_client => {},
|
||||
_ = tx_to_upstream => {},
|
||||
}
|
||||
|
||||
Ok::<(), StatusCode>(())
|
||||
};
|
||||
|
||||
if let Err(e) = result.await {
|
||||
error!("WebSocket proxy error: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
async fn proxy_request(
|
||||
State(state): State<ProxyState>,
|
||||
req: Request,
|
||||
@@ -166,8 +387,15 @@ async fn forward_request(
|
||||
req: Request,
|
||||
upstream: Upstream,
|
||||
) -> Result<Response, StatusCode> {
|
||||
// Extract body before consuming the request (1.1.1)
|
||||
let (parts, body) = req.into_parts();
|
||||
|
||||
let request_id = parts.headers.get("x-request-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
|
||||
|
||||
let span = tracing::info_span!("forward_request", request_id = %request_id, path = %parts.uri.path());
|
||||
let _enter = span.enter();
|
||||
let body_bytes = axum::body::to_bytes(body, 1024 * 1024 * 100) // 100MB limit
|
||||
.await
|
||||
.map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
@@ -196,6 +424,7 @@ async fn forward_request(
|
||||
request_builder = request_builder.header(name.as_str(), v);
|
||||
}
|
||||
}
|
||||
request_builder = request_builder.header("x-request-id", &request_id);
|
||||
|
||||
// Attach body (1.1.1)
|
||||
let request_builder = request_builder.body(body_bytes);
|
||||
@@ -226,6 +455,7 @@ async fn forward_request(
|
||||
}
|
||||
|
||||
response_builder
|
||||
.header("x-request-id", &request_id)
|
||||
.body(body)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
@@ -238,10 +468,13 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
info!("Starting MadBase Proxy...");
|
||||
|
||||
let control_url = std::env::var("CONTROL_UPSTREAM_URL")
|
||||
.unwrap_or_else(|_| "http://control:8001".to_string());
|
||||
.unwrap_or_else(|_| "http://system:8001".to_string());
|
||||
|
||||
let control_db_url = std::env::var("CONTROL_DB_URL")
|
||||
.unwrap_or_else(|_| "postgres://admin:admin_password@localhost:5433/madbase_control".to_string());
|
||||
|
||||
let worker_urls_str = std::env::var("WORKER_UPSTREAM_URLS")
|
||||
.unwrap_or_else(|_| "http://worker1:8002".to_string());
|
||||
.unwrap_or_else(|_| "http://worker:8002".to_string());
|
||||
|
||||
let worker_urls: Vec<String> = worker_urls_str
|
||||
.split(',')
|
||||
@@ -250,9 +483,12 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
.collect();
|
||||
|
||||
info!("Control upstream: {}", control_url);
|
||||
info!("Control DB: {}", control_db_url);
|
||||
info!("Worker upstreams: {:?}", worker_urls);
|
||||
|
||||
let state = ProxyState::new(control_url, worker_urls);
|
||||
let control_db = init_pool(&control_db_url).await?;
|
||||
|
||||
let state = ProxyState::new(control_url, worker_urls, control_db);
|
||||
|
||||
// Start health check loop in background
|
||||
let state_clone = state.clone();
|
||||
@@ -262,6 +498,12 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
|
||||
let app = Router::new()
|
||||
.route("/health", get(health_check))
|
||||
.route("/realtime/v1/websocket",
|
||||
get(|ws: WebSocketUpgrade, State(state): State<ProxyState>, req: Request| async move {
|
||||
let headers = req.headers().clone();
|
||||
ws.on_upgrade(move |socket| proxy_websocket(State(state.clone()), socket, headers))
|
||||
})
|
||||
)
|
||||
.fallback(proxy_request)
|
||||
.with_state(state);
|
||||
|
||||
@@ -291,9 +533,19 @@ mod tests {
|
||||
async fn test_proxy_round_robin() {
|
||||
let _guard = ENV_LOCK.lock().unwrap();
|
||||
|
||||
let control_db = PgPool::connect("postgres://postgres:postgres@localhost:5432/test").await.ok();
|
||||
let dummy_db_url = "postgres://postgres:postgres@localhost:5432/test";
|
||||
let control_pool = if let Some(pool) = control_db {
|
||||
pool
|
||||
} else {
|
||||
let pool = sqlx::PgPool::connect(dummy_db_url).await.unwrap();
|
||||
pool
|
||||
};
|
||||
|
||||
let state = ProxyState::new(
|
||||
"http://control:8001".to_string(),
|
||||
vec!["http://worker1:8002".to_string(), "http://worker2:8002".to_string()]
|
||||
vec!["http://worker1:8002".to_string(), "http://worker2:8002".to_string()],
|
||||
control_pool,
|
||||
);
|
||||
|
||||
// Mark all as healthy
|
||||
@@ -315,9 +567,11 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_proxy_single_http_client() {
|
||||
let control_pool = sqlx::PgPool::connect("postgres://postgres:postgres@localhost:5432/test").await.unwrap();
|
||||
let state = ProxyState::new(
|
||||
"http://control:8001".to_string(),
|
||||
vec!["http://worker1:8002".to_string()]
|
||||
vec!["http://worker1:8002".to_string()],
|
||||
control_pool,
|
||||
);
|
||||
|
||||
// Verify http_client is created and usable
|
||||
|
||||
@@ -4,7 +4,7 @@ use axum::{
|
||||
Router,
|
||||
};
|
||||
use axum_prometheus::PrometheusMetricLayer;
|
||||
use common::{init_pool, Config};
|
||||
use common::{init_pool, Config, JwtConfig};
|
||||
use crate::state::AppState;
|
||||
use crate::middleware;
|
||||
use sqlx::PgPool;
|
||||
@@ -47,8 +47,12 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
|
||||
let pool = wait_for_db(&config.database_url).await;
|
||||
|
||||
let control_db_url = std::env::var("CONTROL_DB_URL")
|
||||
.expect("CONTROL_DB_URL must be set");
|
||||
let control_pool = wait_for_db(&control_db_url).await;
|
||||
|
||||
let app_state = AppState {
|
||||
control_db: pool.clone(),
|
||||
control_db: control_pool.clone(),
|
||||
tenant_pools: Arc::new(RwLock::new(HashMap::new())),
|
||||
};
|
||||
|
||||
@@ -65,7 +69,9 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
|
||||
let data_state = data_api::handlers::DataState {
|
||||
db: pool.clone(),
|
||||
replica_pool: None,
|
||||
config: config.clone(),
|
||||
cache: Arc::new(data_api::schema_cache::SchemaCache::new()),
|
||||
};
|
||||
|
||||
let default_tenant_db_url = std::env::var("DEFAULT_TENANT_DB_URL")
|
||||
@@ -79,10 +85,10 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
let (realtime_router, realtime_state) = realtime::init(tenant_pool.clone(), tenant_config.clone());
|
||||
|
||||
// Replication Listener
|
||||
let repl_config = tenant_config.clone();
|
||||
let repl_tx = realtime_state.broadcast_tx.clone();
|
||||
let repl_state = realtime_state.clone();
|
||||
let default_db_url = default_tenant_db_url.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = realtime::replication::start_replication_listener(repl_config, repl_tx).await {
|
||||
if let Err(e) = realtime::replication::start_replication_listener("default".to_string(), default_db_url, repl_state).await {
|
||||
tracing::error!("Replication listener failed: {}", e);
|
||||
}
|
||||
});
|
||||
@@ -96,23 +102,46 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
.expect("Failed to initialize WASM runtime")
|
||||
);
|
||||
let deno_runtime = Arc::new(functions::deno_runtime::DenoRuntime::new());
|
||||
let pool_size = std::env::var("DENO_POOL_SIZE").unwrap_or_else(|_| "4".to_string()).parse::<usize>().unwrap_or(4);
|
||||
let deno_pool = Arc::new(functions::worker_pool::DenoPool::new(pool_size));
|
||||
let functions_state = functions::FunctionsState {
|
||||
db: pool.clone(),
|
||||
config: config.clone(),
|
||||
runtime: functions_runtime,
|
||||
deno_runtime,
|
||||
deno_pool,
|
||||
};
|
||||
|
||||
// Auth Middleware State
|
||||
let jwt_config = JwtConfig::from_env()?;
|
||||
let auth_middleware_state = auth::AuthMiddlewareState {
|
||||
config: config.clone(),
|
||||
jwt_config,
|
||||
};
|
||||
|
||||
// Project Middleware State
|
||||
let project_cache = moka::future::Cache::builder()
|
||||
.max_capacity(100)
|
||||
.time_to_live(Duration::from_secs(60))
|
||||
.build();
|
||||
|
||||
let tenant_pools = moka::future::Cache::builder()
|
||||
.max_capacity(100)
|
||||
.time_to_idle(Duration::from_secs(300))
|
||||
.build();
|
||||
|
||||
let project_middleware_state = middleware::ProjectMiddlewareState {
|
||||
control_db: app_state.control_db.clone(),
|
||||
tenant_pools: app_state.tenant_pools.clone(),
|
||||
project_cache: moka::future::Cache::new(100),
|
||||
control_db: control_pool.clone(),
|
||||
tenant_pools: tenant_pools.clone(),
|
||||
project_cache: project_cache.clone(),
|
||||
realtime: realtime_state.clone(),
|
||||
};
|
||||
|
||||
let project_middleware_state2 = middleware::ProjectMiddlewareState {
|
||||
control_db: control_pool.clone(),
|
||||
tenant_pools: tenant_pools.clone(),
|
||||
project_cache: project_cache.clone(),
|
||||
realtime: realtime_state,
|
||||
};
|
||||
|
||||
// Construct Worker Routes
|
||||
@@ -127,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
auth::auth_middleware,
|
||||
))
|
||||
.layer(from_fn_with_state(
|
||||
project_middleware_state.clone(),
|
||||
project_middleware_state2,
|
||||
middleware::inject_tenant_pool,
|
||||
))
|
||||
.layer(from_fn_with_state(
|
||||
|
||||
Reference in New Issue
Block a user