445 lines
15 KiB
Rust
445 lines
15 KiB
Rust
use gateway::middleware;
|
|
use gateway::state::AppState;
|
|
|
|
use axum::{
|
|
extract::{Request, Query},
|
|
middleware::{from_fn, from_fn_with_state, Next},
|
|
response::{Response, IntoResponse},
|
|
routing::get,
|
|
Router,
|
|
};
|
|
use tower_http::services::{ServeDir, ServeFile};
|
|
use axum::http::StatusCode;
|
|
use axum_prometheus::PrometheusMetricLayer;
|
|
use common::{init_pool, Config, JwtConfig};
|
|
use std::collections::HashMap;
|
|
use std::net::SocketAddr;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
use tokio::sync::RwLock;
|
|
use tower_governor::{governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer};
|
|
use tower_http::cors::{AllowOrigin, CorsLayer};
|
|
use tower_http::trace::TraceLayer;
|
|
use moka::future::Cache;
|
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
|
|
|
fn shared_http_client() -> &'static reqwest::Client {
|
|
static CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
|
|
CLIENT.get_or_init(|| {
|
|
reqwest::Client::builder()
|
|
.timeout(std::time::Duration::from_secs(30))
|
|
.pool_max_idle_per_host(10)
|
|
.build()
|
|
.unwrap()
|
|
})
|
|
}
|
|
|
|
async fn logs_proxy_handler(Query(params): Query<HashMap<String, String>>) -> impl IntoResponse {
|
|
let loki_url = std::env::var("LOKI_URL")
|
|
.unwrap_or_else(|_| "http://loki:3100".to_string());
|
|
let query_url = format!("{}/loki/api/v1/query_range", loki_url);
|
|
|
|
let resp = shared_http_client()
|
|
.get(&query_url)
|
|
.query(¶ms)
|
|
.send()
|
|
.await;
|
|
|
|
match resp {
|
|
Ok(r) => {
|
|
let status = StatusCode::from_u16(r.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
|
let body = r.bytes().await.unwrap_or_default();
|
|
(status, body).into_response()
|
|
},
|
|
Err(e) => {
|
|
tracing::error!("Loki proxy error: {}", e);
|
|
(StatusCode::BAD_GATEWAY, e.to_string()).into_response()
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn log_headers(req: Request, next: Next) -> Response {
|
|
tracing::debug!("Request Headers: {:?}", req.headers());
|
|
next.run(req).await
|
|
}
|
|
|
|
// Dashboard handler removed in favor of direct SPA serving from /web
|
|
|
|
async fn wait_for_db(db_url: &str) -> sqlx::PgPool {
|
|
loop {
|
|
match init_pool(db_url).await {
|
|
Ok(pool) => return pool,
|
|
Err(e) => {
|
|
tracing::warn!("Database not ready yet, retrying in 2s: {}", e);
|
|
tokio::time::sleep(Duration::from_secs(2)).await;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn validate_configuration(config: &Config, pool: &sqlx::PgPool) -> anyhow::Result<()> {
|
|
tracing::info!("Validating configuration...");
|
|
|
|
// 1. Validate JWT secret format
|
|
if config.jwt_secret.len() < 32 {
|
|
anyhow::bail!(
|
|
"JWT_SECRET too short ({} chars, minimum 32 required)",
|
|
config.jwt_secret.len()
|
|
);
|
|
}
|
|
|
|
// 1.1 Validate JWT issuer
|
|
let jwt_issuer = std::env::var("JWT_ISSUER").unwrap_or_else(|_| "madbase".to_string());
|
|
tracing::info!(
|
|
jwt_issuer = %jwt_issuer,
|
|
"JWT issuer configured"
|
|
);
|
|
|
|
// 2. Validate JWT secret consistency with database
|
|
let row = sqlx::query_as::<_, (String, String, String, String)>(
|
|
r#"
|
|
SELECT name, jwt_secret, anon_key, service_role_key
|
|
FROM projects
|
|
WHERE name = 'default'
|
|
LIMIT 1
|
|
"#
|
|
)
|
|
.fetch_optional(pool)
|
|
.await?;
|
|
|
|
if let Some((name, jwt_secret, anon_key, service_role_key)) = row {
|
|
if jwt_secret != config.jwt_secret {
|
|
anyhow::bail!(
|
|
"JWT_SECRET mismatch between environment and database (project 'default')\n\
|
|
Environment: {}...\n\
|
|
Database: {}...\n\
|
|
Run 'scripts/setup_default_project.sh' to fix this.",
|
|
&config.jwt_secret[..8],
|
|
&jwt_secret[..8]
|
|
);
|
|
}
|
|
|
|
// Validate that anon_key and service_role_key are present
|
|
if anon_key.is_empty() {
|
|
anyhow::bail!("Project 'default' has empty anon_key");
|
|
}
|
|
if service_role_key.is_empty() {
|
|
anyhow::bail!("Project 'default' has empty service_role_key");
|
|
}
|
|
|
|
tracing::info!(
|
|
project_name = name,
|
|
jwt_secret_preview = &jwt_secret[..8],
|
|
anon_key_present = !anon_key.is_empty(),
|
|
service_role_key_present = !service_role_key.is_empty(),
|
|
"Project configuration validated"
|
|
);
|
|
} else {
|
|
anyhow::bail!(
|
|
"Default project not found in database. Run 'scripts/setup_default_project.sh' to create it."
|
|
);
|
|
}
|
|
|
|
tracing::info!("Configuration validation successful.");
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> anyhow::Result<()> {
|
|
// Load configuration
|
|
dotenvy::dotenv().ok();
|
|
let config = Config::new().expect("Failed to load configuration");
|
|
|
|
// Initialize tracing
|
|
let rust_log = std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into());
|
|
let is_json = std::env::var("LOG_FORMAT").ok().as_deref() == Some("json");
|
|
|
|
use tracing_subscriber::Layer;
|
|
|
|
let filter = tracing_subscriber::EnvFilter::new(&rust_log).boxed();
|
|
let fmt_layer = if is_json {
|
|
tracing_subscriber::fmt::layer().json().boxed()
|
|
} else {
|
|
tracing_subscriber::fmt::layer().boxed()
|
|
};
|
|
|
|
let otel_layer = if let Ok(otlp_endpoint) = std::env::var("OTEL_EXPORTER_OTLP_ENDPOINT") {
|
|
use opentelemetry_otlp::WithExportConfig;
|
|
|
|
let tracer = opentelemetry_otlp::new_pipeline()
|
|
.tracing()
|
|
.with_exporter(opentelemetry_otlp::new_exporter().tonic().with_endpoint(otlp_endpoint))
|
|
.with_trace_config(opentelemetry_sdk::trace::config().with_resource(
|
|
opentelemetry_sdk::Resource::new(vec![opentelemetry::KeyValue::new("service.name", "madbase-gateway")])
|
|
))
|
|
.install_batch(opentelemetry_sdk::runtime::Tokio)
|
|
.expect("Failed to initialize OTLP tracer");
|
|
|
|
Some(tracing_opentelemetry::layer().with_tracer(tracer).boxed())
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let registry = tracing_subscriber::registry()
|
|
.with(filter)
|
|
.with(fmt_layer);
|
|
|
|
if let Some(otel) = otel_layer {
|
|
registry.with(otel).init();
|
|
} else {
|
|
registry.init();
|
|
}
|
|
|
|
tracing::info!("Starting MadBase Gateway v4.1 (Admin UI)...");
|
|
|
|
// Initialize Database (Control Plane / Main DB)
|
|
tracing::info!("Connecting to database at {}...", config.database_url);
|
|
let pool = wait_for_db(&config.database_url).await;
|
|
tracing::info!("Database connected successfully.");
|
|
|
|
// Run Migrations (Tenant only for Gateway)
|
|
tracing::info!("Running tenant database migrations...");
|
|
sqlx::migrate!("../migrations")
|
|
.run(&pool)
|
|
.await
|
|
.expect("Failed to run tenant migrations");
|
|
tracing::info!("Tenant migrations applied successfully.");
|
|
|
|
// Validate Configuration
|
|
validate_configuration(&config, &pool).await?;
|
|
|
|
let app_state = AppState {
|
|
control_db: pool.clone(),
|
|
tenant_pools: Arc::new(RwLock::new(HashMap::new())),
|
|
};
|
|
|
|
let session_manager = config.redis_url.as_ref().map(|url| {
|
|
let cache = common::CacheLayer::new(Some(url.clone()), 86400);
|
|
auth::SessionManager::new(cache, 86400)
|
|
});
|
|
|
|
let auth_state = auth::AuthState {
|
|
db: pool.clone(),
|
|
config: config.clone(),
|
|
session_manager,
|
|
};
|
|
|
|
let replica_pool = if let Ok(url) = std::env::var("READ_REPLICA_URL") {
|
|
tracing::info!("Connecting to read replica at {}...", url);
|
|
Some(wait_for_db(&url).await)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let data_state = data_api::handlers::DataState {
|
|
db: pool.clone(),
|
|
replica_pool,
|
|
config: config.clone(),
|
|
cache: Arc::new(data_api::schema_cache::SchemaCache::new()),
|
|
};
|
|
|
|
// Register DDL invalidation listener
|
|
let ddl_cache = data_state.cache.clone();
|
|
let ddl_pool = pool.clone();
|
|
tokio::spawn(async move {
|
|
let mut listener = match sqlx::postgres::PgListener::connect_with(&ddl_pool).await {
|
|
Ok(l) => l,
|
|
Err(e) => {
|
|
tracing::error!("Failed to connect PgListener: {}", e);
|
|
return;
|
|
}
|
|
};
|
|
if let Err(e) = listener.listen("madbase_schema_change").await {
|
|
tracing::error!("Failed to listen on madbase_schema_change: {}", e);
|
|
return;
|
|
}
|
|
tracing::info!("DDL invalidation listener started.");
|
|
while let Ok(notification) = listener.recv().await {
|
|
tracing::info!("Received DDL change notification: {}. Invalidating SchemaCache.", notification.payload());
|
|
ddl_cache.invalidate_all().await;
|
|
}
|
|
});
|
|
|
|
// Initialize Tenant Database (for Realtime)
|
|
let default_tenant_db_url = std::env::var("DEFAULT_TENANT_DB_URL")
|
|
.expect("DEFAULT_TENANT_DB_URL must be set");
|
|
tracing::info!("Connecting to default tenant database...");
|
|
let tenant_pool = wait_for_db(&default_tenant_db_url).await;
|
|
tracing::info!("Tenant Database connected successfully.");
|
|
|
|
let control_state = control_plane::ControlPlaneState {
|
|
db: pool.clone(),
|
|
tenant_db: tenant_pool.clone(),
|
|
server_manager: None,
|
|
};
|
|
|
|
let mut tenant_config = config.clone();
|
|
tenant_config.database_url = default_tenant_db_url.clone();
|
|
|
|
// Realtime Init
|
|
let (realtime_router, realtime_state) = realtime::init(tenant_pool.clone(), tenant_config.clone());
|
|
|
|
// Start Replication Listener (for default tenant)
|
|
let repl_state = realtime_state.clone();
|
|
let default_db_url = default_tenant_db_url.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) = realtime::replication::start_replication_listener("default".to_string(), default_db_url, repl_state).await {
|
|
tracing::error!("Replication listener failed: {}", e);
|
|
}
|
|
});
|
|
|
|
// Storage Init
|
|
let storage_router = storage::init(pool.clone(), config.clone()).await;
|
|
|
|
// Functions Init
|
|
let functions_runtime = Arc::new(functions::runtime::WasmRuntime::new().expect("Failed to initialize WASM runtime"));
|
|
let deno_runtime = Arc::new(functions::deno_runtime::DenoRuntime::new());
|
|
let pool_size = std::env::var("DENO_POOL_SIZE").unwrap_or_else(|_| "4".to_string()).parse::<usize>().unwrap_or(4);
|
|
let deno_pool = Arc::new(functions::worker_pool::DenoPool::new(pool_size));
|
|
let functions_state = functions::FunctionsState {
|
|
db: pool.clone(),
|
|
config: config.clone(),
|
|
runtime: functions_runtime,
|
|
deno_runtime,
|
|
deno_pool,
|
|
};
|
|
|
|
// Auth Middleware State
|
|
let jwt_config = JwtConfig::from_env()?;
|
|
let auth_middleware_state = auth::AuthMiddlewareState {
|
|
config: config.clone(),
|
|
jwt_config,
|
|
};
|
|
|
|
// Project Middleware State
|
|
let project_middleware_state = middleware::ProjectMiddlewareState {
|
|
control_db: app_state.control_db.clone(),
|
|
tenant_pools: moka::future::Cache::builder()
|
|
.max_capacity(100)
|
|
.time_to_idle(Duration::from_secs(300))
|
|
.build(),
|
|
project_cache: moka::future::Cache::builder()
|
|
.max_capacity(100)
|
|
.time_to_live(Duration::from_secs(60))
|
|
.build(),
|
|
realtime: realtime_state,
|
|
};
|
|
|
|
// Construct App
|
|
// We apply `resolve_project` middleware to /auth, /rest, /storage, /realtime
|
|
// But NOT /platform (admin)
|
|
|
|
let tenant_routes = Router::new()
|
|
.nest(
|
|
"/auth/v1",
|
|
auth::router()
|
|
.layer(from_fn_with_state(
|
|
auth_middleware_state.clone(),
|
|
auth::auth_middleware,
|
|
))
|
|
.with_state(auth_state),
|
|
)
|
|
.nest(
|
|
"/rest/v1",
|
|
data_api::router()
|
|
.layer(from_fn_with_state(
|
|
auth_middleware_state.clone(),
|
|
auth::auth_middleware,
|
|
))
|
|
.with_state(data_state),
|
|
)
|
|
.nest("/realtime/v1", realtime_router)
|
|
.nest(
|
|
"/storage/v1",
|
|
storage_router.layer(from_fn_with_state(
|
|
auth_middleware_state.clone(),
|
|
auth::auth_middleware,
|
|
)),
|
|
)
|
|
.nest(
|
|
"/functions/v1",
|
|
functions::router(functions_state).layer(from_fn_with_state(
|
|
auth_middleware_state.clone(),
|
|
auth::auth_middleware,
|
|
)),
|
|
)
|
|
.layer(from_fn_with_state(
|
|
project_middleware_state.clone(),
|
|
middleware::inject_tenant_pool,
|
|
))
|
|
.layer(from_fn_with_state(
|
|
project_middleware_state,
|
|
middleware::resolve_project,
|
|
));
|
|
|
|
// Metrics
|
|
let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair();
|
|
|
|
// Rate Limiting Configuration
|
|
let governor_conf = Arc::new(
|
|
GovernorConfigBuilder::default()
|
|
.per_second(config.rate_limit_per_second)
|
|
.burst_size(config.rate_limit_per_second as u32 * 2)
|
|
.key_extractor(SmartIpKeyExtractor)
|
|
.finish()
|
|
.unwrap(),
|
|
);
|
|
let app = Router::new()
|
|
.route("/", get(|| async { "Hello, MadBase!" }))
|
|
.route("/metrics", get(|| async move { metric_handle.render() }))
|
|
.nest("/", tenant_routes) // Apply project resolution to these
|
|
.nest(
|
|
"/platform/v1", // Admin/Control Plane API (No project resolution needed)
|
|
control_plane::router(control_state)
|
|
.route("/logs", get(logs_proxy_handler)),
|
|
)
|
|
.layer(GovernorLayer {
|
|
config: governor_conf,
|
|
})
|
|
.layer({
|
|
let origins_str = std::env::var("ALLOWED_ORIGINS")
|
|
.unwrap_or_else(|_| "http://localhost:3000,http://localhost:8000".to_string());
|
|
let origins: Vec<axum::http::HeaderValue> = origins_str
|
|
.split(',')
|
|
.filter_map(|s| s.trim().parse().ok())
|
|
.collect();
|
|
CorsLayer::new()
|
|
.allow_origin(AllowOrigin::list(origins))
|
|
.allow_methods([
|
|
axum::http::Method::GET,
|
|
axum::http::Method::POST,
|
|
axum::http::Method::PUT,
|
|
axum::http::Method::DELETE,
|
|
axum::http::Method::OPTIONS,
|
|
])
|
|
.allow_headers([
|
|
axum::http::header::CONTENT_TYPE,
|
|
axum::http::header::AUTHORIZATION,
|
|
axum::http::HeaderName::from_static("apikey"),
|
|
])
|
|
.allow_credentials(true)
|
|
})
|
|
.layer(TraceLayer::new_for_http())
|
|
.layer(from_fn(log_headers))
|
|
.layer(prometheus_layer)
|
|
.fallback_service(ServeDir::new("web").fallback(ServeFile::new("web/index.html")));
|
|
|
|
// Run it
|
|
let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
|
|
tracing::info!("Listening on {}", addr);
|
|
|
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
|
|
|
let shutdown = async {
|
|
tokio::signal::ctrl_c().await.ok();
|
|
tracing::info!("Shutdown signal received, draining connections...");
|
|
};
|
|
|
|
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>())
|
|
.with_graceful_shutdown(shutdown)
|
|
.await?;
|
|
|
|
tracing::info!("Gateway Server shut down cleanly.");
|
|
Ok(())
|
|
}
|