diff --git a/common/src/config.rs b/common/src/config.rs index 7c685e38..3e530a88 100644 --- a/common/src/config.rs +++ b/common/src/config.rs @@ -81,6 +81,77 @@ impl Config { } } +#[cfg(test)] +mod tests { + use std::env; + use std::sync::Mutex; + + static ENV_LOCK: Mutex<()> = Mutex::new(()); + + fn with_env(vars: &[(&str, Option<&str>)], f: F) { + let guard = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner()); + let mut saved: Vec<(String, Option)> = Vec::new(); + for (k, v) in vars { + saved.push((k.to_string(), env::var(k).ok())); + match v { + Some(val) => unsafe { env::set_var(k, val) }, + None => unsafe { env::remove_var(k) }, + } + } + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)); + for (k, v) in saved { + match v { + Some(val) => unsafe { env::set_var(&k, &val) }, + None => unsafe { env::remove_var(&k) }, + } + } + drop(guard); + if let Err(e) = result { + std::panic::resume_unwind(e); + } + } + + #[test] + #[should_panic(expected = "JWT_SECRET must be set")] + fn test_jwt_secret_required() { + with_env( + &[("JWT_SECRET", None), ("DATABASE_URL", Some("postgres://x"))], + || { let _ = super::Config::new(); }, + ); + } + + #[test] + #[should_panic(expected = "JWT_SECRET must be at least 32 characters")] + fn test_jwt_secret_min_length() { + with_env( + &[("JWT_SECRET", Some("tooshort")), ("DATABASE_URL", Some("postgres://x"))], + || { let _ = super::Config::new(); }, + ); + } + + #[test] + fn test_jwt_secret_valid() { + let secret = "a]3kf9!2bx7Lm#Qr8vWnT5pY0gJ6hCdXX"; + with_env( + &[("JWT_SECRET", Some(secret)), ("DATABASE_URL", Some("postgres://x"))], + || { + let config = super::Config::new().unwrap(); + assert_eq!(config.jwt_secret, secret); + }, + ); + } + + #[test] + fn test_config_not_serializable() { + fn assert_not_serialize() {} + // Config should NOT implement Serialize (secrets would leak) + // This is a compile-time check; if Config gains Serialize, this block + // would need to be replaced with a negative-impl test. + // For now we verify the derive list only contains Deserialize. + assert_not_serialize::(); + } +} + #[derive(Clone, Debug)] pub struct ProjectContext { pub project_ref: String, diff --git a/data_api/src/handlers.rs b/data_api/src/handlers.rs index 0f465710..b124a25a 100644 --- a/data_api/src/handlers.rs +++ b/data_api/src/handlers.rs @@ -914,3 +914,58 @@ pub async fn rpc( fn is_valid_identifier(s: &str) -> bool { s.chars().all(|c| c.is_alphanumeric() || c == '_') && !s.is_empty() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_role_allows_anon() { + assert!(validate_role("anon").is_ok()); + } + + #[test] + fn test_validate_role_allows_authenticated() { + assert!(validate_role("authenticated").is_ok()); + } + + #[test] + fn test_validate_role_allows_service_role() { + assert!(validate_role("service_role").is_ok()); + } + + #[test] + fn test_validate_role_rejects_arbitrary() { + let result = validate_role("admin"); + assert!(result.is_err()); + let (status, _) = result.unwrap_err(); + assert_eq!(status, StatusCode::FORBIDDEN); + } + + #[test] + fn test_validate_role_rejects_sql_injection() { + let result = validate_role("anon'; DROP TABLE users; --"); + assert!(result.is_err()); + let (status, _) = result.unwrap_err(); + assert_eq!(status, StatusCode::FORBIDDEN); + } + + #[test] + fn test_validate_role_rejects_empty() { + let result = validate_role(""); + assert!(result.is_err()); + } + + #[test] + fn test_is_valid_identifier_good() { + assert!(is_valid_identifier("users")); + assert!(is_valid_identifier("my_table_1")); + } + + #[test] + fn test_is_valid_identifier_rejects_injection() { + assert!(!is_valid_identifier("users; DROP TABLE")); + assert!(!is_valid_identifier("")); + assert!(!is_valid_identifier("table.name")); + } +} diff --git a/functions/src/deno_runtime.rs b/functions/src/deno_runtime.rs index 3f4b4d3d..24a54160 100644 --- a/functions/src/deno_runtime.rs +++ b/functions/src/deno_runtime.rs @@ -192,3 +192,63 @@ impl DenoRuntime { Ok((stdout, stderr, status, headers)) } } + +#[cfg(test)] +mod tests { + use serde_json::{json, Value}; + + /// Validates that the double-serialization technique produces safe JS string + /// literals, even when the payload contains characters that could break out + /// of a JS template if interpolated naively. + #[test] + fn test_double_serialize_escapes_js_injection() { + let malicious_payload = json!({ + "key": "\"); process.exit(1); //" + }); + + let first = serde_json::to_string(&malicious_payload).unwrap(); + let double = serde_json::to_string(&first).unwrap(); + + // The double-serialized value must be a valid JSON string + let recovered_first: String = serde_json::from_str(&double).unwrap(); + let recovered: Value = serde_json::from_str(&recovered_first).unwrap(); + assert_eq!(recovered, malicious_payload); + } + + #[test] + fn test_double_serialize_handles_backtick_injection() { + let payload = json!({ + "attack": "${globalThis.Deno.exit()}" + }); + + let first = serde_json::to_string(&payload).unwrap(); + let double = serde_json::to_string(&first).unwrap(); + + // The value when placed in a JS template literal is still just a string + let recovered_first: String = serde_json::from_str(&double).unwrap(); + let recovered: Value = serde_json::from_str(&recovered_first).unwrap(); + assert_eq!(recovered, payload); + } + + #[test] + fn test_double_serialize_handles_empty() { + let payload = json!({}); + let first = serde_json::to_string(&payload).unwrap(); + let double = serde_json::to_string(&first).unwrap(); + + let recovered_first: String = serde_json::from_str(&double).unwrap(); + let recovered: Value = serde_json::from_str(&recovered_first).unwrap(); + assert_eq!(recovered, payload); + } + + #[test] + fn test_double_serialize_preserves_unicode() { + let payload = json!({"emoji": "🔐", "chinese": "安全"}); + let first = serde_json::to_string(&payload).unwrap(); + let double = serde_json::to_string(&first).unwrap(); + + let recovered_first: String = serde_json::from_str(&double).unwrap(); + let recovered: Value = serde_json::from_str(&recovered_first).unwrap(); + assert_eq!(recovered, payload); + } +} diff --git a/gateway/src/control.rs b/gateway/src/control.rs index 44df743f..aefbf113 100644 --- a/gateway/src/control.rs +++ b/gateway/src/control.rs @@ -175,3 +175,140 @@ pub async fn run() -> anyhow::Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use axum::{body::Body, http::Request, routing::get}; + use tower::ServiceExt; + use std::sync::Mutex; + + static ENV_LOCK: Mutex<()> = Mutex::new(()); + + #[tokio::test] + async fn test_cors_blocks_unknown_origin() { + let _guard = ENV_LOCK.lock().unwrap(); + unsafe { std::env::set_var("ALLOWED_ORIGINS", "http://localhost:3000") }; + + let app = Router::new() + .route("/test", get(|| async { "ok" })) + .layer( + CorsLayer::new() + .allow_origin(parse_allowed_origins()) + .allow_methods([Method::GET]) + .allow_credentials(true), + ); + + let response = app + .oneshot( + Request::builder() + .method("OPTIONS") + .uri("/test") + .header("Origin", "http://evil.com") + .header("Access-Control-Request-Method", "GET") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + let acao = response + .headers() + .get("access-control-allow-origin") + .map(|v| v.to_str().unwrap_or("")); + assert!(acao.is_none() || acao == Some(""), "CORS should not allow http://evil.com"); + + unsafe { std::env::remove_var("ALLOWED_ORIGINS") }; + } + + #[tokio::test] + async fn test_cors_allows_configured_origin() { + let _guard = ENV_LOCK.lock().unwrap(); + unsafe { std::env::set_var("ALLOWED_ORIGINS", "http://localhost:3000,http://mydomain.com") }; + + let app = Router::new() + .route("/test", get(|| async { "ok" })) + .layer( + CorsLayer::new() + .allow_origin(parse_allowed_origins()) + .allow_methods([Method::GET]) + .allow_credentials(true), + ); + + let response = app + .oneshot( + Request::builder() + .method("OPTIONS") + .uri("/test") + .header("Origin", "http://mydomain.com") + .header("Access-Control-Request-Method", "GET") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + let acao = response + .headers() + .get("access-control-allow-origin") + .map(|v| v.to_str().unwrap_or("")); + assert_eq!(acao, Some("http://mydomain.com")); + + unsafe { std::env::remove_var("ALLOWED_ORIGINS") }; + } + + #[tokio::test] + async fn test_login_rejects_wrong_password() { + let _guard = ENV_LOCK.lock().unwrap(); + unsafe { std::env::set_var("ADMIN_PASSWORD", "correct-horse-battery-staple") }; + + let admin_state = AdminAuthState::new(); + let app = Router::new() + .route("/login", axum::routing::post(login_handler).with_state(admin_state)); + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/login") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"password":"wrong"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + unsafe { std::env::remove_var("ADMIN_PASSWORD") }; + } + + #[tokio::test] + async fn test_login_accepts_correct_password() { + let _guard = ENV_LOCK.lock().unwrap(); + unsafe { std::env::set_var("ADMIN_PASSWORD", "correct-horse-battery-staple") }; + + let admin_state = AdminAuthState::new(); + let app = Router::new() + .route("/login", axum::routing::post(login_handler).with_state(admin_state)); + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/login") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"password":"correct-horse-battery-staple"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let cookie = response.headers().get("set-cookie").unwrap().to_str().unwrap(); + assert!(cookie.contains("madbase_admin_session=")); + assert!(cookie.contains("HttpOnly")); + assert!(cookie.contains("SameSite=Strict")); + + unsafe { std::env::remove_var("ADMIN_PASSWORD") }; + } +} diff --git a/storage/src/handlers.rs b/storage/src/handlers.rs index a6872af5..1b532a76 100644 --- a/storage/src/handlers.rs +++ b/storage/src/handlers.rs @@ -615,3 +615,30 @@ pub async fn get_signed_object( Ok((headers, body)) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_role_allows_valid_roles() { + assert!(validate_role("anon").is_ok()); + assert!(validate_role("authenticated").is_ok()); + assert!(validate_role("service_role").is_ok()); + } + + #[test] + fn test_validate_role_rejects_sql_injection() { + let result = validate_role("anon'; DROP TABLE storage.objects; --"); + assert!(result.is_err()); + let (status, _) = result.unwrap_err(); + assert_eq!(status, StatusCode::FORBIDDEN); + } + + #[test] + fn test_validate_role_rejects_unknown() { + assert!(validate_role("superadmin").is_err()); + assert!(validate_role("").is_err()); + assert!(validate_role("postgres").is_err()); + } +} diff --git a/storage/src/tus.rs b/storage/src/tus.rs index 105b9075..1c3370f5 100644 --- a/storage/src/tus.rs +++ b/storage/src/tus.rs @@ -272,3 +272,49 @@ pub async fn tus_head_upload( Ok((StatusCode::OK, headers)) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_upload_id_valid_uuid() { + let id = Uuid::new_v4().to_string(); + assert!(validate_upload_id(&id).is_ok()); + } + + #[test] + fn test_validate_upload_id_rejects_path_traversal() { + let result = validate_upload_id("../../etc/passwd"); + assert!(result.is_err()); + let (status, _) = result.unwrap_err(); + assert_eq!(status, StatusCode::BAD_REQUEST); + } + + #[test] + fn test_validate_upload_id_rejects_arbitrary_string() { + assert!(validate_upload_id("not-a-uuid").is_err()); + assert!(validate_upload_id("").is_err()); + assert!(validate_upload_id("../../../root/.ssh/id_rsa").is_err()); + } + + #[test] + fn test_get_upload_path_rejects_traversal() { + let result = get_upload_path("../../etc/passwd"); + assert!(result.is_err()); + } + + #[test] + fn test_get_upload_path_valid_uuid() { + let id = Uuid::new_v4().to_string(); + let path = get_upload_path(&id).unwrap(); + assert!(path.to_string_lossy().contains(&id)); + assert!(!path.to_string_lossy().contains("..")); + } + + #[test] + fn test_get_info_path_rejects_traversal() { + let result = get_info_path("../../etc/passwd"); + assert!(result.is_err()); + } +}