use std::sync::Arc; use std::sync::OnceLock; use std::time::Duration; use std::time::Instant; use std::time::SystemTime; use std::time::UNIX_EPOCH; use std::collections::HashMap; use argon2::password_hash::PasswordHash; use argon2::password_hash::PasswordHasher; use argon2::password_hash::PasswordVerifier; use argon2::password_hash::SaltString; use argon2::Argon2; use axum::extract::Query; use axum::extract::State; use axum::http::StatusCode; use axum::response::IntoResponse; use axum::routing::get; use axum::routing::post; use axum::Json; use chrono::Utc; use hmac::Hmac; use hmac::Mac; use serde::Deserialize; use serde::Serialize; use sha1::Sha1; use sha2::Digest; use subtle::ConstantTimeEq; use thiserror::Error; use tokio::sync::Mutex; use crate::storage::GatewayStorage; use crate::storage::StorageError; use crate::AppState; pub fn router() -> axum::Router { axum::Router::new() .route("/signup", post(signup)) .route("/signin", post(signin)) .route("/service/signin", post(service_signin)) .route("/signout", post(signout)) .route("/refresh", post(refresh)) .route("/forgot", post(forgot)) .route("/reset", post(reset)) .route("/oidc/google/start", post(oidc_google_start)) .route("/oidc/google/callback", get(oidc_google_callback)) .route("/mfa/enroll/start", post(mfa_enroll_start)) .route("/mfa/enroll/confirm", post(mfa_enroll_confirm)) .route("/mfa/challenge", post(mfa_challenge)) } #[derive(Clone)] pub struct AuthnConfig { jwt_secrets: Arc>>, access_ttl: Duration, refresh_ttl: Duration, reset_ttl: Duration, } impl std::fmt::Debug for AuthnConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AuthnConfig").finish_non_exhaustive() } } impl AuthnConfig { pub fn from_env() -> Self { let jwt_secrets = read_jwt_secrets_from_env() .unwrap_or_else(|| vec![uuid::Uuid::new_v4().as_bytes().to_vec()]); let access_ttl = std::env::var("GATEWAY_ACCESS_TTL_SECS") .ok() .and_then(|v| v.parse::().ok()) .map(Duration::from_secs) 
.unwrap_or(Duration::from_secs(300)); let refresh_ttl = std::env::var("GATEWAY_REFRESH_TTL_SECS") .ok() .and_then(|v| v.parse::().ok()) .map(Duration::from_secs) .unwrap_or(Duration::from_secs(60 * 60 * 24 * 30)); let reset_ttl = std::env::var("GATEWAY_RESET_TTL_SECS") .ok() .and_then(|v| v.parse::().ok()) .map(Duration::from_secs) .unwrap_or(Duration::from_secs(60 * 15)); Self { jwt_secrets: Arc::new(jwt_secrets), access_ttl, refresh_ttl, reset_ttl, } } pub fn for_tests() -> Self { Self { jwt_secrets: Arc::new(vec![b"test-secret".to_vec()]), access_ttl: Duration::from_secs(300), refresh_ttl: Duration::from_secs(3600), reset_ttl: Duration::from_secs(60), } } pub fn verify_access_token(&self, token: &str) -> Result { let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256); validation.validate_exp = true; validation.validate_nbf = false; validation.leeway = 0; for secret in self.jwt_secrets.iter() { if let Ok(data) = jsonwebtoken::decode::( token, &jsonwebtoken::DecodingKey::from_secret(secret), &validation, ) { return Ok(data.claims); } } Err(VerifyError::InvalidToken) } } #[derive(Debug, Error, Clone, PartialEq, Eq)] pub enum VerifyError { #[error("invalid token")] InvalidToken, } #[derive(Debug, Deserialize)] pub struct SignupRequest { pub email: String, pub password: String, } #[derive(Debug, Deserialize)] pub struct SigninRequest { pub email: String, pub password: String, } #[derive(Debug, Deserialize)] pub struct ServiceSigninRequest { pub service_account_id: String, pub token: String, } #[derive(Debug, Deserialize)] pub struct SignoutRequest { pub session_id: String, } #[derive(Debug, Deserialize, Serialize)] pub struct RefreshRequest { pub session_id: String, pub refresh_token: String, } #[derive(Debug, Deserialize)] pub struct ForgotRequest { pub email: String, } #[derive(Debug, Deserialize)] pub struct ResetRequest { pub reset_token: String, pub new_password: String, } #[derive(Debug, Deserialize)] pub struct MfaEnrollStartRequest 
{ pub user_id: String, } #[derive(Debug, Deserialize)] pub struct MfaEnrollConfirmRequest { pub user_id: String, pub code: String, } #[derive(Debug, Serialize, Deserialize)] pub struct AuthResponse { pub access_token: String, pub session_id: String, pub refresh_token: String, } #[derive(Debug, Serialize)] pub struct ForgotResponse { pub status: &'static str, } #[derive(Debug, Serialize)] pub struct OidcStartResponse { pub url: String, } #[derive(Debug, Serialize)] pub struct MfaEnrollStartResponse { pub secret_base32: String, } #[derive(Debug, Serialize)] pub struct MfaEnrollConfirmResponse { pub status: &'static str, pub recovery_codes: Vec, } #[derive(Debug, Deserialize)] pub struct MfaChallengeRequest { pub code: String, } #[derive(Debug, Serialize)] pub struct MfaChallengeResponse { pub status: &'static str, } #[derive(Debug, Serialize)] pub struct ResetResponse { pub status: &'static str, } #[derive(Debug, Serialize, Deserialize)] struct Stored { v: u32, data: T, } #[derive(Debug, Clone, Serialize, Deserialize)] struct UserRecord { user_id: String, email: String, enabled: bool, created_at_ms: i64, } #[derive(Debug, Clone, Serialize, Deserialize)] struct PasswordCredentialRecord { user_id: String, password_hash: String, updated_at_ms: i64, } #[derive(Debug, Clone, Serialize, Deserialize)] struct PasswordResetRecord { user_id: String, token_hash: String, created_at_ms: i64, expires_at_ms: i64, used_at_ms: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] struct OidcStateRecord { nonce: String, created_at_ms: i64, expires_at_ms: i64, redirect_uri: String, } #[derive(Debug, Clone, Serialize, Deserialize)] struct ServiceTokenRecord { user_id: String, token_hash: String, created_at_ms: i64, rotated_at_ms: Option, enabled: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] struct TotpEnrollmentRecord { user_id: String, secret_base32: String, enabled: bool, created_at_ms: i64, recovery_hashes: Vec, } #[derive(Debug, Error)] enum AuthnError { 
#[error("invalid input")] InvalidInput, #[error("user already exists")] UserExists, #[error("invalid credentials")] InvalidCredentials, #[error("refresh token invalid")] InvalidRefresh, #[error("reset token invalid")] InvalidReset, #[error("reset token expired")] ResetExpired, #[error("reset token already used")] ResetUsed, #[error("mfa code invalid")] MfaInvalid, #[error("too many requests")] TooManyRequests, #[error("oidc not configured")] OidcNotConfigured, #[error("storage error: {0}")] Storage(String), } impl From for AuthnError { fn from(value: StorageError) -> Self { match value { StorageError::AlreadyExists => AuthnError::UserExists, StorageError::RefreshTokenInvalid => AuthnError::InvalidRefresh, StorageError::RefreshSessionExpired | StorageError::RefreshSessionRevoked => { AuthnError::InvalidRefresh } other => AuthnError::Storage(other.to_string()), } } } impl IntoResponse for AuthnError { fn into_response(self) -> axum::response::Response { let (status, msg) = match &self { AuthnError::InvalidInput => (StatusCode::BAD_REQUEST, self.to_string()), AuthnError::UserExists => (StatusCode::CONFLICT, self.to_string()), AuthnError::InvalidCredentials => (StatusCode::UNAUTHORIZED, self.to_string()), AuthnError::InvalidRefresh => (StatusCode::UNAUTHORIZED, self.to_string()), AuthnError::InvalidReset | AuthnError::ResetExpired | AuthnError::ResetUsed => { (StatusCode::BAD_REQUEST, self.to_string()) } AuthnError::MfaInvalid => (StatusCode::BAD_REQUEST, self.to_string()), AuthnError::TooManyRequests => (StatusCode::TOO_MANY_REQUESTS, self.to_string()), AuthnError::OidcNotConfigured => (StatusCode::NOT_IMPLEMENTED, self.to_string()), AuthnError::Storage(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()), }; (status, msg).into_response() } } static RATE_LIMITER: OnceLock>>> = OnceLock::new(); async fn check_rate_limit(key: &str, max: usize, window: Duration) -> Result<(), AuthnError> { let limiter = RATE_LIMITER.get_or_init(|| Mutex::new(HashMap::new())); let 
mut guard = limiter.lock().await; let now = Instant::now(); let bucket = guard.entry(key.to_string()).or_default(); bucket.retain(|t| now.duration_since(*t) < window); if bucket.len() >= max { return Err(AuthnError::TooManyRequests); } bucket.push(now); Ok(()) } fn read_jwt_secrets_from_env() -> Option>> { if let Ok(path) = std::env::var("GATEWAY_JWT_SECRETS_FILE") { if let Some(raw) = read_secret_file(&path) { return split_secrets(&raw); } } if let Ok(v) = std::env::var("GATEWAY_JWT_SECRETS") { if let Some(secrets) = split_secrets(&v) { return Some(secrets); } } if let Ok(path) = std::env::var("GATEWAY_JWT_SECRET_FILE") { if let Some(raw) = read_secret_file(&path) { return Some(vec![raw.into_bytes()]); } } std::env::var("GATEWAY_JWT_SECRET") .ok() .map(|s| vec![s.into_bytes()]) } fn split_secrets(raw: &str) -> Option>> { let normalized = raw.replace('\n', ","); let secrets: Vec> = normalized .split(',') .map(|s| s.trim()) .filter(|s| !s.is_empty()) .map(|s| s.as_bytes().to_vec()) .collect(); if secrets.is_empty() { None } else { Some(secrets) } } fn read_secret_file(path: &str) -> Option { std::fs::read_to_string(path) .ok() .map(|s| s.trim().to_string()) .filter(|s| !s.is_empty()) } fn env_or_file(env_key: &str, file_env_key: &str) -> Option { if let Ok(path) = std::env::var(file_env_key) { if let Some(value) = read_secret_file(&path) { return Some(value); } } std::env::var(env_key).ok().filter(|s| !s.trim().is_empty()) } async fn signup( State(state): State, Json(req): Json, ) -> Result, AuthnError> { let email = normalize_email(&req.email).ok_or(AuthnError::InvalidInput)?; validate_password(&req.password)?; let user_id = uuid::Uuid::new_v4().to_string(); let now_ms = unix_ms(); let user = UserRecord { user_id: user_id.clone(), email: email.clone(), enabled: true, created_at_ms: now_ms, }; let email_hash = hash_stable("email", &email); state .storage .users .create( &user_by_email_hash_key(&email_hash), user_id.as_bytes().to_vec(), ) .await?; state .storage 
.users .create(&user_key(&user_id), encode_stored(&user)?) .await?; let password_hash = hash_password(&req.password)?; let cred = PasswordCredentialRecord { user_id: user_id.clone(), password_hash, updated_at_ms: now_ms, }; state .storage .password_credentials .create(&password_key(&user_id), encode_stored(&cred)?) .await?; let session = state .storage .create_refresh_session(&user_id, state.authn.refresh_ttl) .await?; let access_token = issue_access_token(&state.authn, &user_id, &session.session_id)?; Ok(Json(AuthResponse { access_token, session_id: session.session_id, refresh_token: session.refresh_token, })) } async fn signin( State(state): State, Json(req): Json, ) -> Result, AuthnError> { let email = normalize_email(&req.email).ok_or(AuthnError::InvalidInput)?; check_rate_limit(&format!("signin:{email}"), 10, Duration::from_secs(60)).await?; let email_hash = hash_stable("email", &email); let user_id_bytes = state .storage .users .get(&user_by_email_hash_key(&email_hash)) .await? .ok_or(AuthnError::InvalidCredentials)? .value; let user_id = String::from_utf8(user_id_bytes).map_err(|_| AuthnError::InvalidCredentials)?; let user_entry = state .storage .users .get(&user_key(&user_id)) .await? .ok_or(AuthnError::InvalidCredentials)?; let stored_user: Stored = serde_json::from_slice(&user_entry.value) .map_err(|e| AuthnError::Storage(e.to_string()))?; if !stored_user.data.enabled { return Err(AuthnError::InvalidCredentials); } let cred_entry = state .storage .password_credentials .get(&password_key(&user_id)) .await? .ok_or(AuthnError::InvalidCredentials)?; let stored_cred: Stored = serde_json::from_slice(&cred_entry.value) .map_err(|e| AuthnError::Storage(e.to_string()))?; if !verify_password(&req.password, &stored_cred.data.password_hash)? 
{ return Err(AuthnError::InvalidCredentials); } let session = state .storage .create_refresh_session(&user_id, state.authn.refresh_ttl) .await?; let access_token = issue_access_token(&state.authn, &user_id, &session.session_id)?; Ok(Json(AuthResponse { access_token, session_id: session.session_id, refresh_token: session.refresh_token, })) } async fn service_signin( State(state): State, Json(req): Json, ) -> Result, AuthnError> { if req.service_account_id.trim().is_empty() || req.token.trim().is_empty() { return Err(AuthnError::InvalidInput); } check_rate_limit( &format!("service_signin:{}", req.service_account_id), 30, Duration::from_secs(60), ) .await?; let key = service_token_key(&req.service_account_id); let entry = state .storage .service_tokens .get(&key) .await? .ok_or(AuthnError::InvalidCredentials)?; let stored: Stored = serde_json::from_slice(&entry.value).map_err(|e| AuthnError::Storage(e.to_string()))?; if !stored.data.enabled { return Err(AuthnError::InvalidCredentials); } let presented_hash = hash_stable("service_token", &req.token); if !bool::from( presented_hash .as_bytes() .ct_eq(stored.data.token_hash.as_bytes()), ) { return Err(AuthnError::InvalidCredentials); } let user_entry = state .storage .users .get(&user_key(&req.service_account_id)) .await? 
.ok_or(AuthnError::InvalidCredentials)?; let stored_user: Stored = serde_json::from_slice(&user_entry.value) .map_err(|e| AuthnError::Storage(e.to_string()))?; if !stored_user.data.enabled { return Err(AuthnError::InvalidCredentials); } let session = state .storage .create_refresh_session(&req.service_account_id, state.authn.refresh_ttl) .await?; let access_token = issue_access_token(&state.authn, &req.service_account_id, &session.session_id)?; Ok(Json(AuthResponse { access_token, session_id: session.session_id, refresh_token: session.refresh_token, })) } async fn refresh( State(state): State, Json(req): Json, ) -> Result, AuthnError> { let new_refresh_token = state .storage .rotate_refresh_token(&req.session_id, &req.refresh_token) .await?; let user_id = session_user_id(&state.storage, &req.session_id).await?; let access_token = issue_access_token(&state.authn, &user_id, &req.session_id)?; Ok(Json(AuthResponse { access_token, session_id: req.session_id, refresh_token: new_refresh_token, })) } async fn signout( State(state): State, Json(req): Json, ) -> Result { state .storage .revoke_refresh_session(&req.session_id) .await?; Ok(StatusCode::NO_CONTENT) } async fn forgot( State(state): State, Json(req): Json, ) -> Result, AuthnError> { let email = normalize_email(&req.email).ok_or(AuthnError::InvalidInput)?; check_rate_limit(&format!("forgot:{email}"), 5, Duration::from_secs(60)).await?; let email_hash = hash_stable("email", &email); let user_id = state .storage .users .get(&user_by_email_hash_key(&email_hash)) .await? 
.and_then(|e| String::from_utf8(e.value).ok()); if let Some(user_id) = user_id { let _ = issue_password_reset_token(&state.storage, &user_id, state.authn.reset_ttl).await; } Ok(Json(ForgotResponse { status: "ok" })) } async fn reset( State(state): State, Json(req): Json, ) -> Result, AuthnError> { check_rate_limit("reset", 30, Duration::from_secs(60)).await?; validate_password(&req.new_password)?; reset_password(&state.storage, &req.reset_token, &req.new_password).await?; Ok(Json(ResetResponse { status: "ok" })) } async fn oidc_google_start( State(state): State, ) -> Result, AuthnError> { let client_id = env_or_file("GOOGLE_OIDC_CLIENT_ID", "GOOGLE_OIDC_CLIENT_ID_FILE") .ok_or(AuthnError::OidcNotConfigured)?; let redirect_uri = env_or_file("GOOGLE_OIDC_REDIRECT_URI", "GOOGLE_OIDC_REDIRECT_URI_FILE") .ok_or(AuthnError::OidcNotConfigured)?; let state_value = uuid::Uuid::new_v4().to_string(); let nonce = uuid::Uuid::new_v4().to_string(); let now_ms = unix_ms(); let expires_at_ms = now_ms + 10 * 60 * 1000; let record = OidcStateRecord { nonce: nonce.clone(), created_at_ms: now_ms, expires_at_ms, redirect_uri: redirect_uri.clone(), }; let key = oidc_state_key(&state_value); state .storage .identities .create(&key, encode_stored(&record)?) 
.await?; let url = format!( "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile&state={}&nonce={}&access_type=offline&prompt=consent", urlencoding::encode(&client_id), urlencoding::encode(&redirect_uri), urlencoding::encode(&state_value), urlencoding::encode(&nonce), ); Ok(Json(OidcStartResponse { url })) } #[derive(Debug, Deserialize)] struct OidcCallbackQuery { code: String, state: String, } async fn oidc_google_callback( State(state): State, Query(q): Query, ) -> Result, AuthnError> { let client_id = env_or_file("GOOGLE_OIDC_CLIENT_ID", "GOOGLE_OIDC_CLIENT_ID_FILE") .ok_or(AuthnError::OidcNotConfigured)?; let client_secret = env_or_file( "GOOGLE_OIDC_CLIENT_SECRET", "GOOGLE_OIDC_CLIENT_SECRET_FILE", ) .ok_or(AuthnError::OidcNotConfigured)?; let state_key = oidc_state_key(&q.state); let entry = state .storage .identities .get(&state_key) .await? .ok_or(AuthnError::InvalidInput)?; let state_record: Stored = serde_json::from_slice(&entry.value).map_err(|e| AuthnError::Storage(e.to_string()))?; let now_ms = unix_ms(); if now_ms >= state_record.data.expires_at_ms { let _ = state.storage.identities.delete(&state_key).await; return Err(AuthnError::InvalidInput); } let _ = state.storage.identities.delete(&state_key).await; #[derive(Deserialize)] struct TokenResponse { id_token: String, } let client = reqwest::Client::new(); let token_resp = client .post("https://oauth2.googleapis.com/token") .form(&[ ("grant_type", "authorization_code"), ("code", q.code.as_str()), ("client_id", client_id.as_str()), ("client_secret", client_secret.as_str()), ("redirect_uri", state_record.data.redirect_uri.as_str()), ]) .send() .await .map_err(|e| AuthnError::Storage(e.to_string()))?; if !token_resp.status().is_success() { return Err(AuthnError::InvalidInput); } let token_body = token_resp .json::() .await .map_err(|e| AuthnError::Storage(e.to_string()))?; let claims = verify_google_id_token(&token_body.id_token, 
&client_id, &state_record.data.nonce).await?; let user_id = upsert_google_identity(&state.storage, &claims).await?; let session = state .storage .create_refresh_session(&user_id, state.authn.refresh_ttl) .await?; let access_token = issue_access_token(&state.authn, &user_id, &session.session_id)?; Ok(Json(AuthResponse { access_token, session_id: session.session_id, refresh_token: session.refresh_token, })) } #[derive(Debug, Deserialize)] struct GoogleIdClaims { sub: String, email: Option, email_verified: Option, nonce: Option, } async fn verify_google_id_token( id_token: &str, client_id: &str, expected_nonce: &str, ) -> Result { let header = jsonwebtoken::decode_header(id_token).map_err(|_| AuthnError::InvalidInput)?; let kid = header.kid.ok_or(AuthnError::InvalidInput)?; let jwks = reqwest::get("https://www.googleapis.com/oauth2/v3/certs") .await .map_err(|e| AuthnError::Storage(e.to_string()))? .json::() .await .map_err(|e| AuthnError::Storage(e.to_string()))?; let keys = jwks .get("keys") .and_then(|v| v.as_array()) .ok_or(AuthnError::InvalidInput)?; let mut n = None; let mut e = None; for k in keys { if k.get("kid").and_then(|v| v.as_str()) == Some(kid.as_str()) { n = k.get("n").and_then(|v| v.as_str()).map(|s| s.to_string()); e = k.get("e").and_then(|v| v.as_str()).map(|s| s.to_string()); break; } } let n = n.ok_or(AuthnError::InvalidInput)?; let e = e.ok_or(AuthnError::InvalidInput)?; let decoding_key = jsonwebtoken::DecodingKey::from_rsa_components(&n, &e) .map_err(|e| AuthnError::Storage(e.to_string()))?; let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::RS256); validation.set_audience(&[client_id]); validation.set_issuer(&["https://accounts.google.com", "accounts.google.com"]); let token_data = jsonwebtoken::decode::(id_token, &decoding_key, &validation) .map_err(|_| AuthnError::InvalidInput)?; let claims = token_data.claims; if claims.nonce.as_deref() != Some(expected_nonce) { return Err(AuthnError::InvalidInput); } Ok(claims) } 
async fn upsert_google_identity( storage: &GatewayStorage, claims: &GoogleIdClaims, ) -> Result { let identity_key = format!("v1/identities/google/{}", claims.sub); if let Some(entry) = storage.identities.get(&identity_key).await? { let user_id = String::from_utf8(entry.value).map_err(|_| AuthnError::InvalidInput)?; return Ok(user_id); } let email = claims.email.clone().ok_or(AuthnError::InvalidInput)?; if claims.email_verified != Some(true) { return Err(AuthnError::InvalidInput); } let email_hash = hash_stable("email", &email); let existing = storage .users .get(&user_by_email_hash_key(&email_hash)) .await? .and_then(|e| String::from_utf8(e.value).ok()); let user_id = if let Some(user_id) = existing { user_id } else { let user_id = uuid::Uuid::new_v4().to_string(); let now_ms = unix_ms(); let user = UserRecord { user_id: user_id.clone(), email: email.clone(), enabled: true, created_at_ms: now_ms, }; storage .users .create( &user_by_email_hash_key(&email_hash), user_id.as_bytes().to_vec(), ) .await?; storage .users .create(&user_key(&user_id), encode_stored(&user)?) .await?; user_id }; let _ = storage .identities .create(&identity_key, user_id.as_bytes().to_vec()) .await; Ok(user_id) } async fn mfa_enroll_start( State(state): State, principal: crate::authz::Principal, Json(req): Json, ) -> Result, AuthnError> { if req.user_id != principal.user_id { return Err(AuthnError::InvalidInput); } let secret_base32 = generate_totp_secret_base32(); let now_ms = unix_ms(); let record = TotpEnrollmentRecord { user_id: req.user_id.clone(), secret_base32: secret_base32.clone(), enabled: false, created_at_ms: now_ms, recovery_hashes: Vec::new(), }; state .storage .mfa .put(&totp_key(&principal.user_id), encode_stored(&record)?) 
.await?; Ok(Json(MfaEnrollStartResponse { secret_base32 })) } async fn mfa_enroll_confirm( State(state): State, principal: crate::authz::Principal, Json(req): Json, ) -> Result, AuthnError> { if req.user_id != principal.user_id { return Err(AuthnError::InvalidInput); } let entry = state .storage .mfa .get(&totp_key(&principal.user_id)) .await? .ok_or(AuthnError::InvalidInput)?; let mut stored: Stored = serde_json::from_slice(&entry.value).map_err(|e| AuthnError::Storage(e.to_string()))?; if stored.data.enabled { return Ok(Json(MfaEnrollConfirmResponse { status: "ok", recovery_codes: Vec::new(), })); } let now = SystemTime::now(); if !verify_totp_code(&stored.data.secret_base32, &req.code, now)? { return Err(AuthnError::MfaInvalid); } stored.data.enabled = true; let mut recovery_codes: Vec = Vec::new(); if stored.data.recovery_hashes.is_empty() { recovery_codes = generate_recovery_codes(10); stored.data.recovery_hashes = recovery_codes .iter() .map(|c| hash_stable("recovery", c)) .collect(); } let payload = serde_json::to_vec(&stored).map_err(|e| AuthnError::Storage(e.to_string()))?; state .storage .mfa .update(&totp_key(&principal.user_id), entry.revision, payload) .await?; Ok(Json(MfaEnrollConfirmResponse { status: "ok", recovery_codes, })) } async fn mfa_challenge( State(state): State, principal: crate::authz::Principal, Json(req): Json, ) -> Result, AuthnError> { let key = totp_key(&principal.user_id); for _ in 0..10 { let entry = state .storage .mfa .get(&key) .await? .ok_or(AuthnError::InvalidInput)?; let mut stored: Stored = serde_json::from_slice(&entry.value).map_err(|e| AuthnError::Storage(e.to_string()))?; if !stored.data.enabled { return Err(AuthnError::InvalidInput); } let now = SystemTime::now(); if req.code.chars().all(|c| c.is_ascii_digit()) && req.code.len() == 6 { if verify_totp_code(&stored.data.secret_base32, &req.code, now)? 
{ return Ok(Json(MfaChallengeResponse { status: "ok" })); } metrics::counter!("gateway_mfa_fail_total", "kind" => "totp").increment(1); return Err(AuthnError::MfaInvalid); } let presented_hash = hash_stable("recovery", &req.code); if let Some(pos) = stored .data .recovery_hashes .iter() .position(|h| h.as_bytes().ct_eq(presented_hash.as_bytes()).into()) { stored.data.recovery_hashes.remove(pos); let payload = serde_json::to_vec(&stored).map_err(|e| AuthnError::Storage(e.to_string()))?; match state .storage .mfa .update(&key, entry.revision, payload) .await { Ok(_) => { return Ok(Json(MfaChallengeResponse { status: "ok" })); } Err(StorageError::CasMismatch) => continue, Err(e) => return Err(e.into()), } } metrics::counter!("gateway_mfa_fail_total", "kind" => "recovery").increment(1); return Err(AuthnError::MfaInvalid); } Err(AuthnError::Storage("mfa cas failed".to_string())) } async fn session_user_id(storage: &GatewayStorage, session_id: &str) -> Result { let entry = storage .refresh_sessions .get(&format!("v1/sessions/{session_id}")) .await? .ok_or(AuthnError::InvalidRefresh)?; let stored: Stored = serde_json::from_slice(&entry.value).map_err(|e| AuthnError::Storage(e.to_string()))?; Ok(stored.data.user_id) } async fn issue_password_reset_token( storage: &GatewayStorage, user_id: &str, ttl: Duration, ) -> Result { let token = uuid::Uuid::new_v4().to_string(); let token_hash = hash_stable("reset", &token); let now_ms = unix_ms(); let record = PasswordResetRecord { user_id: user_id.to_string(), token_hash: token_hash.clone(), created_at_ms: now_ms, expires_at_ms: now_ms + ttl.as_millis() as i64, used_at_ms: None, }; storage .password_resets .create(&reset_key(&token_hash), encode_stored(&record)?) 
.await?; Ok(token) } async fn reset_password( storage: &GatewayStorage, presented_token: &str, new_password: &str, ) -> Result<(), AuthnError> { let token_hash = hash_stable("reset", presented_token); let key = reset_key(&token_hash); for _ in 0..10 { let entry = storage .password_resets .get(&key) .await? .ok_or(AuthnError::InvalidReset)?; let mut stored: Stored = serde_json::from_slice(&entry.value).map_err(|e| AuthnError::Storage(e.to_string()))?; let now_ms = unix_ms(); if stored.data.used_at_ms.is_some() { return Err(AuthnError::ResetUsed); } if now_ms >= stored.data.expires_at_ms { return Err(AuthnError::ResetExpired); } stored.data.used_at_ms = Some(now_ms); let payload = serde_json::to_vec(&stored).map_err(|e| AuthnError::Storage(e.to_string()))?; match storage .password_resets .update(&key, entry.revision, payload) .await { Ok(_) => { let password_hash = hash_password(new_password)?; let cred_key = password_key(&stored.data.user_id); let cred_entry = storage .password_credentials .get(&cred_key) .await? 
.ok_or(AuthnError::InvalidReset)?; let mut stored_cred: Stored = serde_json::from_slice(&cred_entry.value) .map_err(|e| AuthnError::Storage(e.to_string()))?; stored_cred.data.password_hash = password_hash; stored_cred.data.updated_at_ms = now_ms; let cred_payload = serde_json::to_vec(&stored_cred) .map_err(|e| AuthnError::Storage(e.to_string()))?; storage .password_credentials .update(&cred_key, cred_entry.revision, cred_payload) .await?; revoke_all_refresh_sessions_for_user(storage, &stored.data.user_id).await?; return Ok(()); } Err(StorageError::CasMismatch) => continue, Err(e) => return Err(e.into()), } } Err(AuthnError::Storage("reset cas failed".to_string())) } async fn revoke_all_refresh_sessions_for_user( storage: &GatewayStorage, user_id: &str, ) -> Result<(), AuthnError> { let keys = storage.refresh_sessions.list_keys("v1/sessions/").await?; for key in keys { if let Some(session_id) = key.strip_prefix("v1/sessions/") { let entry = storage.refresh_sessions.get(&key).await?; let Some(entry) = entry else { continue; }; let stored: Stored = serde_json::from_slice(&entry.value) .map_err(|e| AuthnError::Storage(e.to_string()))?; if stored.data.user_id == user_id { let _ = storage.revoke_refresh_session(session_id).await; } } } Ok(()) } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct AccessClaims { pub sub: String, pub session_id: String, pub iat: i64, pub exp: i64, } fn issue_access_token( cfg: &AuthnConfig, user_id: &str, session_id: &str, ) -> Result { let now = Utc::now().timestamp(); let exp = now + cfg.access_ttl.as_secs() as i64; let claims = AccessClaims { sub: user_id.to_string(), session_id: session_id.to_string(), iat: now, exp, }; let token = jsonwebtoken::encode( &jsonwebtoken::Header::default(), &claims, &jsonwebtoken::EncodingKey::from_secret( cfg.jwt_secrets .first() .ok_or_else(|| AuthnError::Storage("no jwt secrets configured".to_string()))?, ), ) .map_err(|e| AuthnError::Storage(e.to_string()))?; Ok(token) } fn 
validate_password(password: &str) -> Result<(), AuthnError> { if password.len() < 8 { return Err(AuthnError::InvalidInput); } Ok(()) } fn normalize_email(email: &str) -> Option { let e = email.trim().to_ascii_lowercase(); if e.is_empty() || !e.contains('@') { return None; } Some(e) } fn unix_ms() -> i64 { SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap_or_default() .as_millis() as i64 } #[derive(Debug, Serialize)] struct StoredRef<'a, T> { v: u32, data: &'a T, } fn encode_stored(value: &T) -> Result, AuthnError> { serde_json::to_vec(&StoredRef { v: crate::storage::SCHEMA_VERSION, data: value, }) .map_err(|e| AuthnError::Storage(e.to_string())) } fn hash_stable(domain: &str, value: &str) -> String { let mut hasher = sha2::Sha256::new(); hasher.update(domain.as_bytes()); hasher.update([0u8]); hasher.update(value.as_bytes()); hex::encode(hasher.finalize()) } fn user_key(user_id: &str) -> String { format!("v1/users/{user_id}") } fn user_by_email_hash_key(email_hash: &str) -> String { format!("v1/users/by_email_hash/{email_hash}") } fn password_key(user_id: &str) -> String { format!("v1/password/{user_id}") } fn reset_key(token_hash: &str) -> String { format!("v1/resets/{token_hash}") } fn oidc_state_key(state: &str) -> String { let h = hash_stable("oidc_state", state); format!("v1/oidc/google/state/{h}") } fn service_token_key(user_id: &str) -> String { format!("v1/service_tokens/{user_id}") } fn totp_key(user_id: &str) -> String { format!("v1/totp/{user_id}") } fn hash_password(password: &str) -> Result { let salt = SaltString::generate(&mut rand_core::OsRng); let argon2 = Argon2::default(); let hash = argon2 .hash_password(password.as_bytes(), &salt) .map_err(|e| AuthnError::Storage(e.to_string()))? 
.to_string(); Ok(hash) } fn verify_password(password: &str, hash: &str) -> Result { let parsed = PasswordHash::new(hash).map_err(|e| AuthnError::Storage(e.to_string()))?; let argon2 = Argon2::default(); Ok(argon2.verify_password(password.as_bytes(), &parsed).is_ok()) } fn generate_totp_secret_base32() -> String { let raw = uuid::Uuid::new_v4().as_bytes().to_vec(); base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &raw) } fn generate_recovery_codes(count: usize) -> Vec { let mut out = Vec::with_capacity(count); let mut seen = std::collections::HashSet::new(); while out.len() < count { let code = uuid::Uuid::new_v4() .simple() .to_string() .chars() .take(8) .collect::() .to_ascii_uppercase(); if seen.insert(code.clone()) { out.push(code); } } out } fn verify_totp_code(secret_base32: &str, code: &str, now: SystemTime) -> Result { let secret = base32::decode(base32::Alphabet::Rfc4648 { padding: false }, secret_base32) .ok_or(AuthnError::InvalidInput)?; let now_secs = now.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs(); let step = 30u64; let counter = now_secs / step; for offset in [0i64, -1, 1] { let c = if offset.is_negative() { counter.saturating_sub(offset.unsigned_abs()) } else { counter.saturating_add(offset as u64) }; let expected = totp_code(&secret, c, 6)?; if expected.as_bytes().ct_eq(code.as_bytes()).into() { return Ok(true); } } Ok(false) } fn totp_code(secret: &[u8], counter: u64, digits: u32) -> Result { let mut msg = [0u8; 8]; msg.copy_from_slice(&counter.to_be_bytes()); let mut mac = Hmac::::new_from_slice(secret).map_err(|_| AuthnError::InvalidInput)?; mac.update(&msg); let result = mac.finalize().into_bytes(); let offset = (result[19] & 0x0f) as usize; let binary = ((result[offset] as u32 & 0x7f) << 24) | ((result[offset + 1] as u32) << 16) | ((result[offset + 2] as u32) << 8) | (result[offset + 3] as u32); let modulus = 10u32.pow(digits); let code = binary % modulus; Ok(format!("{:0width$}", code, width = digits as usize)) } 
#[cfg(test)]
mod tests {
    use super::*;
    use tower::util::ServiceExt;

    /// Builds an `AppState` for integration tests.
    ///
    /// It uses test metrics, an empty fixed routing config, in-memory storage,
    /// and the supplied authn config. The caller keeps ownership of the config
    /// (clone it first if the test still needs it).
    async fn test_state(authn: AuthnConfig) -> crate::AppState {
        let metrics = crate::observability::init_metrics_for_tests();
        let routing = crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(
            crate::routing::RoutingConfig::empty(),
        )))
        .await
        .unwrap();
        let storage = crate::storage::GatewayStorage::new_in_memory();
        crate::AppState {
            metrics,
            routing,
            storage,
            authn,
        }
    }

    /// Builds a JSON `POST` request to `uri`.
    ///
    /// When `bearer` is `Some`, an `authorization: Bearer <token>` header is
    /// added before the `content-type` header.
    fn post_json(
        uri: &str,
        bearer: Option<&str>,
        body: impl Into<axum::body::Body>,
    ) -> axum::http::Request<axum::body::Body> {
        let mut builder = axum::http::Request::builder().method("POST").uri(uri);
        if let Some(token) = bearer {
            builder = builder.header("authorization", format!("Bearer {token}"));
        }
        builder
            .header("content-type", "application/json")
            .body(body.into())
            .unwrap()
    }

    /// Drains a response body and deserializes it as JSON, panicking on failure.
    async fn read_json<T: serde::de::DeserializeOwned>(resp: axum::response::Response) -> T {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[test]
    fn password_hashing_and_verification_works() {
        let hash = hash_password("correct horse battery staple").unwrap();
        assert!(hash.contains("$argon2id$"));
        assert!(verify_password("correct horse battery staple", &hash).unwrap());
        assert!(!verify_password("wrong", &hash).unwrap());
    }

    #[tokio::test]
    async fn refresh_rotation_invalidates_old_token() {
        let app = crate::app(test_state(AuthnConfig::for_tests()).await);

        let signup_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/signup",
                None,
                r#"{"email":"a@b.com","password":"password123"}"#,
            ))
            .await
            .unwrap();
        assert_eq!(signup_res.status(), StatusCode::OK);
        let created: AuthResponse = read_json(signup_res).await;

        // The same body (original session + original refresh token) is sent twice.
        let refresh_body = serde_json::to_vec(&RefreshRequest {
            session_id: created.session_id.clone(),
            refresh_token: created.refresh_token.clone(),
        })
        .unwrap();

        let refresh_res = app
            .clone()
            .oneshot(post_json("/v1/auth/refresh", None, refresh_body.clone()))
            .await
            .unwrap();
        assert_eq!(refresh_res.status(), StatusCode::OK);
        let refreshed: AuthResponse = read_json(refresh_res).await;

        // Replaying the original refresh token after rotation must be rejected.
        let refresh_again = app
            .oneshot(post_json("/v1/auth/refresh", None, refresh_body))
            .await
            .unwrap();
        assert_eq!(refresh_again.status(), StatusCode::UNAUTHORIZED);
        assert_ne!(refreshed.refresh_token, created.refresh_token);
    }

    #[tokio::test]
    async fn forgot_reset_token_is_one_time_and_expires() {
        let storage = GatewayStorage::new_in_memory();
        let user_id = "u1";

        // A token issued with a 1 ms TTL is used only after it has lapsed.
        let token = issue_password_reset_token(&storage, user_id, Duration::from_millis(1))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(5)).await;
        let res = reset_password(&storage, &token, "password123").await;
        assert!(matches!(res, Err(AuthnError::ResetExpired)));

        // A fresh token, after seeding a password credential, works once and is then spent.
        let token2 = issue_password_reset_token(&storage, user_id, Duration::from_secs(60))
            .await
            .unwrap();
        let cred = PasswordCredentialRecord {
            user_id: user_id.to_string(),
            password_hash: hash_password("password123").unwrap(),
            updated_at_ms: unix_ms(),
        };
        storage
            .password_credentials
            .create(
                &password_key(user_id),
                serde_json::to_vec(&Stored { v: 1, data: cred }).unwrap(),
            )
            .await
            .unwrap();
        reset_password(&storage, &token2, "newpassword123")
            .await
            .unwrap();
        let again = reset_password(&storage, &token2, "anotherpassword123").await;
        assert!(matches!(again, Err(AuthnError::ResetUsed)));
    }

    #[test]
    fn totp_verification_accepts_valid_code() {
        let secret = generate_totp_secret_base32();
        let now = SystemTime::now();
        let now_secs = now.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs();
        let counter = now_secs / 30;
        let secret_bytes =
            base32::decode(base32::Alphabet::Rfc4648 { padding: false }, &secret).unwrap();
        let code = totp_code(&secret_bytes, counter, 6).unwrap();
        assert!(verify_totp_code(&secret, &code, now).unwrap());

        // A hard-coded "000000" is itself a valid code for roughly one secret in a
        // million (more if neighbouring time steps are accepted), which made this
        // negative assertion flaky. Pick the first repeated-digit candidate that
        // differs from every code within two steps of `counter`.
        let nearby: Vec<String> = (counter.saturating_sub(2)..=counter + 2)
            .map(|c| totp_code(&secret_bytes, c, 6).unwrap().to_string())
            .collect();
        let wrong = (0..10u8)
            .map(|digit| digit.to_string().repeat(6))
            .find(|candidate| !nearby.contains(candidate))
            .expect("five nearby codes cannot exclude all ten candidates");
        assert!(!verify_totp_code(&secret, &wrong, now).unwrap());
    }

    #[tokio::test]
    async fn mfa_enrollment_and_challenge_work_with_totp_and_recovery_codes() {
        let authn = AuthnConfig::for_tests();
        let app = crate::app(test_state(authn.clone()).await);

        let signup_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/signup",
                None,
                r#"{"email":"mfa@b.com","password":"password123"}"#,
            ))
            .await
            .unwrap();
        // Fail with a clear status mismatch rather than an opaque JSON-parse panic.
        assert_eq!(signup_res.status(), StatusCode::OK);
        let created: AuthResponse = read_json(signup_res).await;
        let token: &str = &created.access_token;
        let claims = authn.verify_access_token(token).unwrap();

        let enroll_start_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/mfa/enroll/start",
                Some(token),
                format!(r#"{{"user_id":"{}"}}"#, claims.sub),
            ))
            .await
            .unwrap();
        assert_eq!(enroll_start_res.status(), StatusCode::OK);
        let enroll_start: serde_json::Value = read_json(enroll_start_res).await;
        let secret_base32 = enroll_start
            .get("secret_base32")
            .and_then(|v| v.as_str())
            .unwrap()
            .to_string();

        // Compute the current TOTP code locally from the returned secret.
        let now = SystemTime::now();
        let now_secs = now.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs();
        let counter = now_secs / 30;
        let secret_bytes =
            base32::decode(base32::Alphabet::Rfc4648 { padding: false }, &secret_base32).unwrap();
        let code = totp_code(&secret_bytes, counter, 6).unwrap();

        let enroll_confirm_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/mfa/enroll/confirm",
                Some(token),
                format!(r#"{{"user_id":"{}","code":"{}"}}"#, claims.sub, code),
            ))
            .await
            .unwrap();
        assert_eq!(enroll_confirm_res.status(), StatusCode::OK);
        let enroll_confirm: serde_json::Value = read_json(enroll_confirm_res).await;
        let recovery_codes: Vec<String> = enroll_confirm
            .get("recovery_codes")
            .and_then(|v| v.as_array())
            .unwrap()
            .iter()
            .filter_map(|v| v.as_str().map(str::to_string))
            .collect();
        assert_eq!(recovery_codes.len(), 10);

        let challenge_totp_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/mfa/challenge",
                Some(token),
                format!(r#"{{"code":"{}"}}"#, code),
            ))
            .await
            .unwrap();
        assert_eq!(challenge_totp_res.status(), StatusCode::OK);

        let recovery = &recovery_codes[0];
        let challenge_recovery_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/mfa/challenge",
                Some(token),
                format!(r#"{{"code":"{}"}}"#, recovery),
            ))
            .await
            .unwrap();
        assert_eq!(challenge_recovery_res.status(), StatusCode::OK);

        // Presenting the same recovery code a second time must fail.
        let challenge_recovery_again = app
            .oneshot(post_json(
                "/v1/auth/mfa/challenge",
                Some(token),
                format!(r#"{{"code":"{}"}}"#, recovery),
            ))
            .await
            .unwrap();
        assert_eq!(challenge_recovery_again.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn rate_limits_trigger_for_signin_and_errors_do_not_echo_secrets() {
        // RATE_LIMITER is a process-wide static; clear any state left by earlier runs.
        if let Some(lock) = RATE_LIMITER.get() {
            lock.lock().await.clear();
        }
        let app = crate::app(test_state(AuthnConfig::for_tests()).await);

        let signup_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/signup",
                None,
                r#"{"email":"rl@b.com","password":"password123"}"#,
            ))
            .await
            .unwrap();
        assert_eq!(signup_res.status(), StatusCode::OK);

        // Ten wrong-password attempts are rejected as unauthorized; the eleventh
        // is throttled. No response body may echo the submitted password.
        for attempt in 0..11 {
            let resp = app
                .clone()
                .oneshot(post_json(
                    "/v1/auth/signin",
                    None,
                    r#"{"email":"rl@b.com","password":"supersecret"}"#,
                ))
                .await
                .unwrap();
            let status = resp.status();
            let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
                .await
                .unwrap();
            assert!(!String::from_utf8_lossy(&body).contains("supersecret"));
            let expected = if attempt < 10 {
                StatusCode::UNAUTHORIZED
            } else {
                StatusCode::TOO_MANY_REQUESTS
            };
            assert_eq!(status, expected);
        }
    }
}