use axum::{ extract::State, http::StatusCode, response::{IntoResponse, Json}, Extension, }; use common::ProjectContext; use serde::{Deserialize, Serialize}; use sqlx::Row; use totp_rs::{Algorithm, Secret, TOTP}; use uuid::Uuid; use crate::middleware::AuthContext; use crate::handlers::AuthState; use crate::utils::{generate_token_with_aal, issue_refresh_token}; use crate::models::{User, AmrEntry}; #[derive(Serialize)] pub struct EnrollResponse { pub id: Uuid, pub type_: String, pub totp: TotpResponse, } #[derive(Serialize)] pub struct TotpResponse { pub qr_code: String, pub secret: String, pub uri: String, } #[derive(Deserialize)] pub struct MfaVerifyRequest { pub factor_id: Uuid, pub code: String, pub challenge_id: Option, } #[derive(Serialize)] pub struct VerifyResponse { pub access_token: String, pub token_type: String, pub expires_in: i64, pub refresh_token: String, pub user: User, } #[derive(Serialize)] pub struct ChallengeResponse { pub challenge_id: Uuid, pub expires_at: i64, } pub async fn enroll( State(state): State, Extension(auth_ctx): Extension, Extension(project_ctx): Extension, ) -> Result { let user_id = auth_ctx.claims.as_ref() .and_then(|c| Uuid::parse_str(&c.sub).ok()) .ok_or((StatusCode::UNAUTHORIZED, "Invalid user".to_string()))?; let secret = Secret::generate_secret(); let totp = TOTP::new( Algorithm::SHA1, 6, 1, 30, secret.to_bytes().unwrap(), Some(project_ctx.project_ref.clone()), auth_ctx.claims.as_ref().and_then(|c| c.email.clone()).unwrap_or("user".to_string()), ).unwrap(); let secret_str = totp.get_secret_base32(); let qr_code = totp.get_qr_base64().map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; let uri = totp.get_url(); let row = sqlx::query( "INSERT INTO auth.mfa_factors (user_id, factor_type, secret, status) VALUES ($1, 'totp', $2, 'unverified') RETURNING id" ) .bind(user_id) .bind(&secret_str) .fetch_one(&state.db) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; let factor_id: Uuid = row.get("id"); Ok(Json(EnrollResponse { id: factor_id, type_: "totp".to_string(), totp: TotpResponse { qr_code, secret: secret_str, uri, } })) } pub async fn verify( State(state): State, Extension(auth_ctx): Extension, Extension(project_ctx): Extension, Json(payload): Json, ) -> Result { let user_id = auth_ctx.claims.as_ref() .and_then(|c| Uuid::parse_str(&c.sub).ok()) .ok_or((StatusCode::UNAUTHORIZED, "Invalid user".to_string()))?; let row = sqlx::query( "SELECT secret, status FROM auth.mfa_factors WHERE id = $1 AND user_id = $2" ) .bind(payload.factor_id) .bind(user_id) .fetch_optional(&state.db) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? .ok_or((StatusCode::NOT_FOUND, "Factor not found".to_string()))?; let secret_str: String = row.get("secret"); let status: String = row.get("status"); let secret_bytes = base32::decode(base32::Alphabet::RFC4648 { padding: false }, &secret_str) .ok_or((StatusCode::INTERNAL_SERVER_ERROR, "Invalid secret format".to_string()))?; let totp = TOTP::new( Algorithm::SHA1, 6, 1, 30, secret_bytes, None, "".to_string(), ).unwrap(); let is_valid = totp.check_current(&payload.code).unwrap_or(false); if !is_valid { return Err((StatusCode::BAD_REQUEST, "Invalid code".to_string())); } if status == "unverified" { sqlx::query("UPDATE auth.mfa_factors SET status = 'verified', updated_at = now() WHERE id = $1") .bind(payload.factor_id) .execute(&state.db) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; } let _challenge_id = if let Some(cid) = payload.challenge_id { let challenge_row = sqlx::query( "SELECT created_at FROM auth.mfa_challenges WHERE id = $1 AND factor_id = $2" ) .bind(cid) .bind(payload.factor_id) .fetch_optional(&state.db) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? .ok_or((StatusCode::BAD_REQUEST, "Invalid challenge".to_string()))?; let created_at: chrono::DateTime = challenge_row.get("created_at"); let elapsed = chrono::Utc::now() - created_at; if elapsed.num_seconds() > 300 { return Err((StatusCode::BAD_REQUEST, "Challenge expired".to_string())); } sqlx::query("UPDATE auth.mfa_challenges SET verified_at = now() WHERE id = $1") .bind(cid) .execute(&state.db) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; cid } else { Uuid::new_v4() }; let jwt_secret = project_ctx.jwt_secret.as_str(); let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1") .bind(user_id) .fetch_optional(&state.db) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? .ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?; let amr = vec![ AmrEntry { method: "password".to_string(), timestamp: chrono::Utc::now().timestamp() as usize, }, AmrEntry { method: "totp".to_string(), timestamp: chrono::Utc::now().timestamp() as usize, }, ]; let (token, expires_in, _) = generate_token_with_aal( user_id, &user.email, "authenticated", jwt_secret, "aal2", Some(amr) ).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; let refresh_token = issue_refresh_token(&state.db, user_id, Uuid::new_v4(), None).await .map_err(|(code, msg)| (StatusCode::from_u16(code.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), msg))?; Ok(Json(VerifyResponse { access_token: token, token_type: "bearer".to_string(), expires_in, refresh_token, user, })) } pub async fn challenge( State(state): State, Extension(auth_ctx): Extension, Json(payload): Json, ) -> Result { let user_id = auth_ctx.claims.as_ref() .and_then(|c| Uuid::parse_str(&c.sub).ok()) .ok_or((StatusCode::UNAUTHORIZED, "Invalid user".to_string()))?; let _row = sqlx::query( "SELECT id FROM auth.mfa_factors WHERE id = $1 AND user_id = $2 AND status = 'verified'" ) .bind(payload.factor_id) .bind(user_id) .fetch_optional(&state.db) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? .ok_or((StatusCode::BAD_REQUEST, "Factor not found or not verified".to_string()))?; let challenge_id = Uuid::new_v4(); sqlx::query( "INSERT INTO auth.mfa_challenges (id, factor_id, created_at) VALUES ($1, $2, now())" ) .bind(challenge_id) .bind(payload.factor_id) .execute(&state.db) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; let expires_at = chrono::Utc::now() + chrono::Duration::seconds(300); Ok(Json(ChallengeResponse { challenge_id, expires_at: expires_at.timestamp(), })) } #[cfg(test)] mod tests { use super::*; #[test] fn test_verify_response_structure() { let response = VerifyResponse { access_token: "test_token".to_string(), token_type: "bearer".to_string(), expires_in: 3600, refresh_token: "refresh".to_string(), user: User { id: Uuid::new_v4(), email: "test@example.com".to_string(), encrypted_password: "hash".to_string(), created_at: chrono::Utc::now(), updated_at: chrono::Utc::now(), last_sign_in_at: None, raw_app_meta_data: serde_json::json!({}), raw_user_meta_data: serde_json::json!({}), is_super_admin: None, confirmed_at: None, email_confirmed_at: None, phone: None, phone_confirmed_at: None, confirmation_token: None, recovery_token: None, email_change_token_new: None, email_change: None, deleted_at: None, }, }; assert_eq!(response.token_type, "bearer"); assert!(response.expires_in > 0); } #[test] fn test_challenge_response_structure() { let response = ChallengeResponse { challenge_id: Uuid::new_v4(), expires_at: 1234567890, }; assert!(response.expires_at > 0); } }