1708 lines
52 KiB
Rust
1708 lines
52 KiB
Rust
use std::sync::Arc;
|
|
use std::sync::OnceLock;
|
|
use std::time::Duration;
|
|
use std::time::Instant;
|
|
use std::time::SystemTime;
|
|
use std::time::UNIX_EPOCH;
|
|
|
|
use std::collections::HashMap;
|
|
|
|
use argon2::password_hash::PasswordHash;
|
|
use argon2::password_hash::PasswordHasher;
|
|
use argon2::password_hash::PasswordVerifier;
|
|
use argon2::password_hash::SaltString;
|
|
use argon2::Argon2;
|
|
use axum::extract::Query;
|
|
use axum::extract::State;
|
|
use axum::http::StatusCode;
|
|
use axum::response::IntoResponse;
|
|
use axum::routing::get;
|
|
use axum::routing::post;
|
|
use axum::Json;
|
|
use chrono::Utc;
|
|
use hmac::Hmac;
|
|
use hmac::Mac;
|
|
use serde::Deserialize;
|
|
use serde::Serialize;
|
|
use sha1::Sha1;
|
|
use sha2::Digest;
|
|
use subtle::ConstantTimeEq;
|
|
use thiserror::Error;
|
|
use tokio::sync::Mutex;
|
|
|
|
use crate::storage::GatewayStorage;
|
|
use crate::storage::StorageError;
|
|
use crate::AppState;
|
|
|
|
pub fn router() -> axum::Router<AppState> {
|
|
axum::Router::new()
|
|
.route("/signup", post(signup))
|
|
.route("/signin", post(signin))
|
|
.route("/service/signin", post(service_signin))
|
|
.route("/signout", post(signout))
|
|
.route("/refresh", post(refresh))
|
|
.route("/forgot", post(forgot))
|
|
.route("/reset", post(reset))
|
|
.route("/oidc/google/start", post(oidc_google_start))
|
|
.route("/oidc/google/callback", get(oidc_google_callback))
|
|
.route("/mfa/enroll/start", post(mfa_enroll_start))
|
|
.route("/mfa/enroll/confirm", post(mfa_enroll_confirm))
|
|
.route("/mfa/challenge", post(mfa_challenge))
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct AuthnConfig {
|
|
jwt_secrets: Arc<Vec<Vec<u8>>>,
|
|
access_ttl: Duration,
|
|
refresh_ttl: Duration,
|
|
reset_ttl: Duration,
|
|
}
|
|
|
|
impl std::fmt::Debug for AuthnConfig {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("AuthnConfig").finish_non_exhaustive()
|
|
}
|
|
}
|
|
|
|
impl AuthnConfig {
|
|
pub fn from_env() -> Self {
|
|
let jwt_secrets = read_jwt_secrets_from_env()
|
|
.unwrap_or_else(|| vec![uuid::Uuid::new_v4().as_bytes().to_vec()]);
|
|
|
|
let access_ttl = std::env::var("GATEWAY_ACCESS_TTL_SECS")
|
|
.ok()
|
|
.and_then(|v| v.parse::<u64>().ok())
|
|
.map(Duration::from_secs)
|
|
.unwrap_or(Duration::from_secs(300));
|
|
|
|
let refresh_ttl = std::env::var("GATEWAY_REFRESH_TTL_SECS")
|
|
.ok()
|
|
.and_then(|v| v.parse::<u64>().ok())
|
|
.map(Duration::from_secs)
|
|
.unwrap_or(Duration::from_secs(60 * 60 * 24 * 30));
|
|
|
|
let reset_ttl = std::env::var("GATEWAY_RESET_TTL_SECS")
|
|
.ok()
|
|
.and_then(|v| v.parse::<u64>().ok())
|
|
.map(Duration::from_secs)
|
|
.unwrap_or(Duration::from_secs(60 * 15));
|
|
|
|
Self {
|
|
jwt_secrets: Arc::new(jwt_secrets),
|
|
access_ttl,
|
|
refresh_ttl,
|
|
reset_ttl,
|
|
}
|
|
}
|
|
|
|
pub fn for_tests() -> Self {
|
|
Self {
|
|
jwt_secrets: Arc::new(vec![b"test-secret".to_vec()]),
|
|
access_ttl: Duration::from_secs(300),
|
|
refresh_ttl: Duration::from_secs(3600),
|
|
reset_ttl: Duration::from_secs(60),
|
|
}
|
|
}
|
|
|
|
pub fn verify_access_token(&self, token: &str) -> Result<AccessClaims, VerifyError> {
|
|
let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256);
|
|
validation.validate_exp = true;
|
|
validation.validate_nbf = false;
|
|
validation.leeway = 0;
|
|
|
|
for secret in self.jwt_secrets.iter() {
|
|
if let Ok(data) = jsonwebtoken::decode::<AccessClaims>(
|
|
token,
|
|
&jsonwebtoken::DecodingKey::from_secret(secret),
|
|
&validation,
|
|
) {
|
|
return Ok(data.claims);
|
|
}
|
|
}
|
|
|
|
Err(VerifyError::InvalidToken)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Error, Clone, PartialEq, Eq)]
|
|
pub enum VerifyError {
|
|
#[error("invalid token")]
|
|
InvalidToken,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct SignupRequest {
|
|
pub email: String,
|
|
pub password: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct SigninRequest {
|
|
pub email: String,
|
|
pub password: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct ServiceSigninRequest {
|
|
pub service_account_id: String,
|
|
pub token: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct SignoutRequest {
|
|
pub session_id: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize)]
|
|
pub struct RefreshRequest {
|
|
pub session_id: String,
|
|
pub refresh_token: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct ForgotRequest {
|
|
pub email: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct ResetRequest {
|
|
pub reset_token: String,
|
|
pub new_password: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct MfaEnrollStartRequest {
|
|
pub user_id: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct MfaEnrollConfirmRequest {
|
|
pub user_id: String,
|
|
pub code: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct AuthResponse {
|
|
pub access_token: String,
|
|
pub session_id: String,
|
|
pub refresh_token: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
pub struct ForgotResponse {
|
|
pub status: &'static str,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
pub struct OidcStartResponse {
|
|
pub url: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
pub struct MfaEnrollStartResponse {
|
|
pub secret_base32: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
pub struct MfaEnrollConfirmResponse {
|
|
pub status: &'static str,
|
|
pub recovery_codes: Vec<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct MfaChallengeRequest {
|
|
pub code: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
pub struct MfaChallengeResponse {
|
|
pub status: &'static str,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
pub struct ResetResponse {
|
|
pub status: &'static str,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
struct Stored<T> {
|
|
v: u32,
|
|
data: T,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct UserRecord {
|
|
user_id: String,
|
|
email: String,
|
|
enabled: bool,
|
|
created_at_ms: i64,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct PasswordCredentialRecord {
|
|
user_id: String,
|
|
password_hash: String,
|
|
updated_at_ms: i64,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct PasswordResetRecord {
|
|
user_id: String,
|
|
token_hash: String,
|
|
created_at_ms: i64,
|
|
expires_at_ms: i64,
|
|
used_at_ms: Option<i64>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct OidcStateRecord {
|
|
nonce: String,
|
|
created_at_ms: i64,
|
|
expires_at_ms: i64,
|
|
redirect_uri: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct ServiceTokenRecord {
|
|
user_id: String,
|
|
token_hash: String,
|
|
created_at_ms: i64,
|
|
rotated_at_ms: Option<i64>,
|
|
enabled: bool,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct TotpEnrollmentRecord {
|
|
user_id: String,
|
|
secret_base32: String,
|
|
enabled: bool,
|
|
created_at_ms: i64,
|
|
recovery_hashes: Vec<String>,
|
|
}
|
|
|
|
#[derive(Debug, Error)]
|
|
enum AuthnError {
|
|
#[error("invalid input")]
|
|
InvalidInput,
|
|
#[error("user already exists")]
|
|
UserExists,
|
|
#[error("invalid credentials")]
|
|
InvalidCredentials,
|
|
#[error("refresh token invalid")]
|
|
InvalidRefresh,
|
|
#[error("reset token invalid")]
|
|
InvalidReset,
|
|
#[error("reset token expired")]
|
|
ResetExpired,
|
|
#[error("reset token already used")]
|
|
ResetUsed,
|
|
#[error("mfa code invalid")]
|
|
MfaInvalid,
|
|
#[error("too many requests")]
|
|
TooManyRequests,
|
|
#[error("oidc not configured")]
|
|
OidcNotConfigured,
|
|
#[error("storage error: {0}")]
|
|
Storage(String),
|
|
}
|
|
|
|
impl From<StorageError> for AuthnError {
|
|
fn from(value: StorageError) -> Self {
|
|
match value {
|
|
StorageError::AlreadyExists => AuthnError::UserExists,
|
|
StorageError::RefreshTokenInvalid => AuthnError::InvalidRefresh,
|
|
StorageError::RefreshSessionExpired | StorageError::RefreshSessionRevoked => {
|
|
AuthnError::InvalidRefresh
|
|
}
|
|
other => AuthnError::Storage(other.to_string()),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl IntoResponse for AuthnError {
|
|
fn into_response(self) -> axum::response::Response {
|
|
let (status, msg) = match &self {
|
|
AuthnError::InvalidInput => (StatusCode::BAD_REQUEST, self.to_string()),
|
|
AuthnError::UserExists => (StatusCode::CONFLICT, self.to_string()),
|
|
AuthnError::InvalidCredentials => (StatusCode::UNAUTHORIZED, self.to_string()),
|
|
AuthnError::InvalidRefresh => (StatusCode::UNAUTHORIZED, self.to_string()),
|
|
AuthnError::InvalidReset | AuthnError::ResetExpired | AuthnError::ResetUsed => {
|
|
(StatusCode::BAD_REQUEST, self.to_string())
|
|
}
|
|
AuthnError::MfaInvalid => (StatusCode::BAD_REQUEST, self.to_string()),
|
|
AuthnError::TooManyRequests => (StatusCode::TOO_MANY_REQUESTS, self.to_string()),
|
|
AuthnError::OidcNotConfigured => (StatusCode::NOT_IMPLEMENTED, self.to_string()),
|
|
AuthnError::Storage(_) => (StatusCode::INTERNAL_SERVER_ERROR, self.to_string()),
|
|
};
|
|
(status, msg).into_response()
|
|
}
|
|
}
|
|
|
|
static RATE_LIMITER: OnceLock<Mutex<HashMap<String, Vec<Instant>>>> = OnceLock::new();
|
|
|
|
async fn check_rate_limit(key: &str, max: usize, window: Duration) -> Result<(), AuthnError> {
|
|
let limiter = RATE_LIMITER.get_or_init(|| Mutex::new(HashMap::new()));
|
|
let mut guard = limiter.lock().await;
|
|
let now = Instant::now();
|
|
|
|
let bucket = guard.entry(key.to_string()).or_default();
|
|
bucket.retain(|t| now.duration_since(*t) < window);
|
|
if bucket.len() >= max {
|
|
return Err(AuthnError::TooManyRequests);
|
|
}
|
|
bucket.push(now);
|
|
Ok(())
|
|
}
|
|
|
|
fn read_jwt_secrets_from_env() -> Option<Vec<Vec<u8>>> {
|
|
if let Ok(path) = std::env::var("GATEWAY_JWT_SECRETS_FILE") {
|
|
if let Some(raw) = read_secret_file(&path) {
|
|
return split_secrets(&raw);
|
|
}
|
|
}
|
|
|
|
if let Ok(v) = std::env::var("GATEWAY_JWT_SECRETS") {
|
|
if let Some(secrets) = split_secrets(&v) {
|
|
return Some(secrets);
|
|
}
|
|
}
|
|
|
|
if let Ok(path) = std::env::var("GATEWAY_JWT_SECRET_FILE") {
|
|
if let Some(raw) = read_secret_file(&path) {
|
|
return Some(vec![raw.into_bytes()]);
|
|
}
|
|
}
|
|
|
|
std::env::var("GATEWAY_JWT_SECRET")
|
|
.ok()
|
|
.map(|s| vec![s.into_bytes()])
|
|
}
|
|
|
|
/// Parses a list of secrets separated by commas and/or newlines.
///
/// Each entry is whitespace-trimmed, which also strips the `\r` of CRLF line endings,
/// and empty entries are skipped. Returns `None` when no entry remains.
fn split_secrets(raw: &str) -> Option<Vec<Vec<u8>>> {
    let mut secrets = Vec::new();
    for piece in raw.split(|c: char| c == ',' || c == '\n') {
        let piece = piece.trim();
        if !piece.is_empty() {
            secrets.push(piece.as_bytes().to_vec());
        }
    }
    (!secrets.is_empty()).then_some(secrets)
}
|
|
|
|
/// Reads a secret from the file at `path`, trimming surrounding whitespace (such as
/// a trailing newline).
///
/// Returns `None` if the file cannot be read as UTF-8 text or contains only
/// whitespace.
fn read_secret_file(path: &str) -> Option<String> {
    let contents = match std::fs::read_to_string(path) {
        Ok(contents) => contents,
        Err(_) => return None,
    };
    let trimmed = contents.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
|
|
|
|
fn env_or_file(env_key: &str, file_env_key: &str) -> Option<String> {
|
|
if let Ok(path) = std::env::var(file_env_key) {
|
|
if let Some(value) = read_secret_file(&path) {
|
|
return Some(value);
|
|
}
|
|
}
|
|
std::env::var(env_key).ok().filter(|s| !s.trim().is_empty())
|
|
}
|
|
|
|
async fn signup(
|
|
State(state): State<AppState>,
|
|
Json(req): Json<SignupRequest>,
|
|
) -> Result<Json<AuthResponse>, AuthnError> {
|
|
let email = normalize_email(&req.email).ok_or(AuthnError::InvalidInput)?;
|
|
validate_password(&req.password)?;
|
|
|
|
let user_id = uuid::Uuid::new_v4().to_string();
|
|
let now_ms = unix_ms();
|
|
|
|
let user = UserRecord {
|
|
user_id: user_id.clone(),
|
|
email: email.clone(),
|
|
enabled: true,
|
|
created_at_ms: now_ms,
|
|
};
|
|
|
|
let email_hash = hash_stable("email", &email);
|
|
state
|
|
.storage
|
|
.users
|
|
.create(
|
|
&user_by_email_hash_key(&email_hash),
|
|
user_id.as_bytes().to_vec(),
|
|
)
|
|
.await?;
|
|
state
|
|
.storage
|
|
.users
|
|
.create(&user_key(&user_id), encode_stored(&user)?)
|
|
.await?;
|
|
|
|
let password_hash = hash_password(&req.password)?;
|
|
let cred = PasswordCredentialRecord {
|
|
user_id: user_id.clone(),
|
|
password_hash,
|
|
updated_at_ms: now_ms,
|
|
};
|
|
state
|
|
.storage
|
|
.password_credentials
|
|
.create(&password_key(&user_id), encode_stored(&cred)?)
|
|
.await?;
|
|
|
|
let session = state
|
|
.storage
|
|
.create_refresh_session(&user_id, state.authn.refresh_ttl)
|
|
.await?;
|
|
let access_token = issue_access_token(&state.authn, &user_id, &session.session_id)?;
|
|
|
|
Ok(Json(AuthResponse {
|
|
access_token,
|
|
session_id: session.session_id,
|
|
refresh_token: session.refresh_token,
|
|
}))
|
|
}
|
|
|
|
async fn signin(
|
|
State(state): State<AppState>,
|
|
Json(req): Json<SigninRequest>,
|
|
) -> Result<Json<AuthResponse>, AuthnError> {
|
|
let email = normalize_email(&req.email).ok_or(AuthnError::InvalidInput)?;
|
|
check_rate_limit(&format!("signin:{email}"), 10, Duration::from_secs(60)).await?;
|
|
let email_hash = hash_stable("email", &email);
|
|
|
|
let user_id_bytes = state
|
|
.storage
|
|
.users
|
|
.get(&user_by_email_hash_key(&email_hash))
|
|
.await?
|
|
.ok_or(AuthnError::InvalidCredentials)?
|
|
.value;
|
|
let user_id = String::from_utf8(user_id_bytes).map_err(|_| AuthnError::InvalidCredentials)?;
|
|
|
|
let user_entry = state
|
|
.storage
|
|
.users
|
|
.get(&user_key(&user_id))
|
|
.await?
|
|
.ok_or(AuthnError::InvalidCredentials)?;
|
|
let stored_user: Stored<UserRecord> = serde_json::from_slice(&user_entry.value)
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
if !stored_user.data.enabled {
|
|
return Err(AuthnError::InvalidCredentials);
|
|
}
|
|
|
|
let cred_entry = state
|
|
.storage
|
|
.password_credentials
|
|
.get(&password_key(&user_id))
|
|
.await?
|
|
.ok_or(AuthnError::InvalidCredentials)?;
|
|
let stored_cred: Stored<PasswordCredentialRecord> =
|
|
serde_json::from_slice(&cred_entry.value)
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
|
|
if !verify_password(&req.password, &stored_cred.data.password_hash)? {
|
|
return Err(AuthnError::InvalidCredentials);
|
|
}
|
|
|
|
let session = state
|
|
.storage
|
|
.create_refresh_session(&user_id, state.authn.refresh_ttl)
|
|
.await?;
|
|
let access_token = issue_access_token(&state.authn, &user_id, &session.session_id)?;
|
|
|
|
Ok(Json(AuthResponse {
|
|
access_token,
|
|
session_id: session.session_id,
|
|
refresh_token: session.refresh_token,
|
|
}))
|
|
}
|
|
|
|
async fn service_signin(
|
|
State(state): State<AppState>,
|
|
Json(req): Json<ServiceSigninRequest>,
|
|
) -> Result<Json<AuthResponse>, AuthnError> {
|
|
if req.service_account_id.trim().is_empty() || req.token.trim().is_empty() {
|
|
return Err(AuthnError::InvalidInput);
|
|
}
|
|
|
|
check_rate_limit(
|
|
&format!("service_signin:{}", req.service_account_id),
|
|
30,
|
|
Duration::from_secs(60),
|
|
)
|
|
.await?;
|
|
|
|
let key = service_token_key(&req.service_account_id);
|
|
let entry = state
|
|
.storage
|
|
.service_tokens
|
|
.get(&key)
|
|
.await?
|
|
.ok_or(AuthnError::InvalidCredentials)?;
|
|
let stored: Stored<ServiceTokenRecord> =
|
|
serde_json::from_slice(&entry.value).map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
if !stored.data.enabled {
|
|
return Err(AuthnError::InvalidCredentials);
|
|
}
|
|
|
|
let presented_hash = hash_stable("service_token", &req.token);
|
|
if !bool::from(
|
|
presented_hash
|
|
.as_bytes()
|
|
.ct_eq(stored.data.token_hash.as_bytes()),
|
|
) {
|
|
return Err(AuthnError::InvalidCredentials);
|
|
}
|
|
|
|
let user_entry = state
|
|
.storage
|
|
.users
|
|
.get(&user_key(&req.service_account_id))
|
|
.await?
|
|
.ok_or(AuthnError::InvalidCredentials)?;
|
|
let stored_user: Stored<UserRecord> = serde_json::from_slice(&user_entry.value)
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
if !stored_user.data.enabled {
|
|
return Err(AuthnError::InvalidCredentials);
|
|
}
|
|
|
|
let session = state
|
|
.storage
|
|
.create_refresh_session(&req.service_account_id, state.authn.refresh_ttl)
|
|
.await?;
|
|
let access_token =
|
|
issue_access_token(&state.authn, &req.service_account_id, &session.session_id)?;
|
|
|
|
Ok(Json(AuthResponse {
|
|
access_token,
|
|
session_id: session.session_id,
|
|
refresh_token: session.refresh_token,
|
|
}))
|
|
}
|
|
|
|
async fn refresh(
|
|
State(state): State<AppState>,
|
|
Json(req): Json<RefreshRequest>,
|
|
) -> Result<Json<AuthResponse>, AuthnError> {
|
|
let new_refresh_token = state
|
|
.storage
|
|
.rotate_refresh_token(&req.session_id, &req.refresh_token)
|
|
.await?;
|
|
|
|
let user_id = session_user_id(&state.storage, &req.session_id).await?;
|
|
let access_token = issue_access_token(&state.authn, &user_id, &req.session_id)?;
|
|
|
|
Ok(Json(AuthResponse {
|
|
access_token,
|
|
session_id: req.session_id,
|
|
refresh_token: new_refresh_token,
|
|
}))
|
|
}
|
|
|
|
async fn signout(
|
|
State(state): State<AppState>,
|
|
Json(req): Json<SignoutRequest>,
|
|
) -> Result<StatusCode, AuthnError> {
|
|
state
|
|
.storage
|
|
.revoke_refresh_session(&req.session_id)
|
|
.await?;
|
|
Ok(StatusCode::NO_CONTENT)
|
|
}
|
|
|
|
async fn forgot(
|
|
State(state): State<AppState>,
|
|
Json(req): Json<ForgotRequest>,
|
|
) -> Result<Json<ForgotResponse>, AuthnError> {
|
|
let email = normalize_email(&req.email).ok_or(AuthnError::InvalidInput)?;
|
|
check_rate_limit(&format!("forgot:{email}"), 5, Duration::from_secs(60)).await?;
|
|
let email_hash = hash_stable("email", &email);
|
|
let user_id = state
|
|
.storage
|
|
.users
|
|
.get(&user_by_email_hash_key(&email_hash))
|
|
.await?
|
|
.and_then(|e| String::from_utf8(e.value).ok());
|
|
|
|
if let Some(user_id) = user_id {
|
|
let _ = issue_password_reset_token(&state.storage, &user_id, state.authn.reset_ttl).await;
|
|
}
|
|
|
|
Ok(Json(ForgotResponse { status: "ok" }))
|
|
}
|
|
|
|
async fn reset(
|
|
State(state): State<AppState>,
|
|
Json(req): Json<ResetRequest>,
|
|
) -> Result<Json<ResetResponse>, AuthnError> {
|
|
check_rate_limit("reset", 30, Duration::from_secs(60)).await?;
|
|
validate_password(&req.new_password)?;
|
|
reset_password(&state.storage, &req.reset_token, &req.new_password).await?;
|
|
Ok(Json(ResetResponse { status: "ok" }))
|
|
}
|
|
|
|
async fn oidc_google_start(
|
|
State(state): State<AppState>,
|
|
) -> Result<Json<OidcStartResponse>, AuthnError> {
|
|
let client_id = env_or_file("GOOGLE_OIDC_CLIENT_ID", "GOOGLE_OIDC_CLIENT_ID_FILE")
|
|
.ok_or(AuthnError::OidcNotConfigured)?;
|
|
let redirect_uri = env_or_file("GOOGLE_OIDC_REDIRECT_URI", "GOOGLE_OIDC_REDIRECT_URI_FILE")
|
|
.ok_or(AuthnError::OidcNotConfigured)?;
|
|
|
|
let state_value = uuid::Uuid::new_v4().to_string();
|
|
let nonce = uuid::Uuid::new_v4().to_string();
|
|
let now_ms = unix_ms();
|
|
let expires_at_ms = now_ms + 10 * 60 * 1000;
|
|
|
|
let record = OidcStateRecord {
|
|
nonce: nonce.clone(),
|
|
created_at_ms: now_ms,
|
|
expires_at_ms,
|
|
redirect_uri: redirect_uri.clone(),
|
|
};
|
|
|
|
let key = oidc_state_key(&state_value);
|
|
state
|
|
.storage
|
|
.identities
|
|
.create(&key, encode_stored(&record)?)
|
|
.await?;
|
|
|
|
let url = format!(
|
|
"https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile&state={}&nonce={}&access_type=offline&prompt=consent",
|
|
urlencoding::encode(&client_id),
|
|
urlencoding::encode(&redirect_uri),
|
|
urlencoding::encode(&state_value),
|
|
urlencoding::encode(&nonce),
|
|
);
|
|
|
|
Ok(Json(OidcStartResponse { url }))
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct OidcCallbackQuery {
|
|
code: String,
|
|
state: String,
|
|
}
|
|
|
|
async fn oidc_google_callback(
|
|
State(state): State<AppState>,
|
|
Query(q): Query<OidcCallbackQuery>,
|
|
) -> Result<Json<AuthResponse>, AuthnError> {
|
|
let client_id = env_or_file("GOOGLE_OIDC_CLIENT_ID", "GOOGLE_OIDC_CLIENT_ID_FILE")
|
|
.ok_or(AuthnError::OidcNotConfigured)?;
|
|
let client_secret = env_or_file(
|
|
"GOOGLE_OIDC_CLIENT_SECRET",
|
|
"GOOGLE_OIDC_CLIENT_SECRET_FILE",
|
|
)
|
|
.ok_or(AuthnError::OidcNotConfigured)?;
|
|
|
|
let state_key = oidc_state_key(&q.state);
|
|
let entry = state
|
|
.storage
|
|
.identities
|
|
.get(&state_key)
|
|
.await?
|
|
.ok_or(AuthnError::InvalidInput)?;
|
|
|
|
let state_record: Stored<OidcStateRecord> =
|
|
serde_json::from_slice(&entry.value).map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
let now_ms = unix_ms();
|
|
if now_ms >= state_record.data.expires_at_ms {
|
|
let _ = state.storage.identities.delete(&state_key).await;
|
|
return Err(AuthnError::InvalidInput);
|
|
}
|
|
|
|
let _ = state.storage.identities.delete(&state_key).await;
|
|
|
|
#[derive(Deserialize)]
|
|
struct TokenResponse {
|
|
id_token: String,
|
|
}
|
|
|
|
let client = reqwest::Client::new();
|
|
let token_resp = client
|
|
.post("https://oauth2.googleapis.com/token")
|
|
.form(&[
|
|
("grant_type", "authorization_code"),
|
|
("code", q.code.as_str()),
|
|
("client_id", client_id.as_str()),
|
|
("client_secret", client_secret.as_str()),
|
|
("redirect_uri", state_record.data.redirect_uri.as_str()),
|
|
])
|
|
.send()
|
|
.await
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
|
|
if !token_resp.status().is_success() {
|
|
return Err(AuthnError::InvalidInput);
|
|
}
|
|
|
|
let token_body = token_resp
|
|
.json::<TokenResponse>()
|
|
.await
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
|
|
let claims =
|
|
verify_google_id_token(&token_body.id_token, &client_id, &state_record.data.nonce).await?;
|
|
|
|
let user_id = upsert_google_identity(&state.storage, &claims).await?;
|
|
|
|
let session = state
|
|
.storage
|
|
.create_refresh_session(&user_id, state.authn.refresh_ttl)
|
|
.await?;
|
|
let access_token = issue_access_token(&state.authn, &user_id, &session.session_id)?;
|
|
|
|
Ok(Json(AuthResponse {
|
|
access_token,
|
|
session_id: session.session_id,
|
|
refresh_token: session.refresh_token,
|
|
}))
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct GoogleIdClaims {
|
|
sub: String,
|
|
email: Option<String>,
|
|
email_verified: Option<bool>,
|
|
nonce: Option<String>,
|
|
}
|
|
|
|
async fn verify_google_id_token(
|
|
id_token: &str,
|
|
client_id: &str,
|
|
expected_nonce: &str,
|
|
) -> Result<GoogleIdClaims, AuthnError> {
|
|
let header = jsonwebtoken::decode_header(id_token).map_err(|_| AuthnError::InvalidInput)?;
|
|
let kid = header.kid.ok_or(AuthnError::InvalidInput)?;
|
|
|
|
let jwks = reqwest::get("https://www.googleapis.com/oauth2/v3/certs")
|
|
.await
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?
|
|
.json::<serde_json::Value>()
|
|
.await
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
|
|
let keys = jwks
|
|
.get("keys")
|
|
.and_then(|v| v.as_array())
|
|
.ok_or(AuthnError::InvalidInput)?;
|
|
|
|
let mut n = None;
|
|
let mut e = None;
|
|
for k in keys {
|
|
if k.get("kid").and_then(|v| v.as_str()) == Some(kid.as_str()) {
|
|
n = k.get("n").and_then(|v| v.as_str()).map(|s| s.to_string());
|
|
e = k.get("e").and_then(|v| v.as_str()).map(|s| s.to_string());
|
|
break;
|
|
}
|
|
}
|
|
|
|
let n = n.ok_or(AuthnError::InvalidInput)?;
|
|
let e = e.ok_or(AuthnError::InvalidInput)?;
|
|
let decoding_key = jsonwebtoken::DecodingKey::from_rsa_components(&n, &e)
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
|
|
let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::RS256);
|
|
validation.set_audience(&[client_id]);
|
|
validation.set_issuer(&["https://accounts.google.com", "accounts.google.com"]);
|
|
|
|
let token_data = jsonwebtoken::decode::<GoogleIdClaims>(id_token, &decoding_key, &validation)
|
|
.map_err(|_| AuthnError::InvalidInput)?;
|
|
|
|
let claims = token_data.claims;
|
|
if claims.nonce.as_deref() != Some(expected_nonce) {
|
|
return Err(AuthnError::InvalidInput);
|
|
}
|
|
|
|
Ok(claims)
|
|
}
|
|
|
|
async fn upsert_google_identity(
|
|
storage: &GatewayStorage,
|
|
claims: &GoogleIdClaims,
|
|
) -> Result<String, AuthnError> {
|
|
let identity_key = format!("v1/identities/google/{}", claims.sub);
|
|
if let Some(entry) = storage.identities.get(&identity_key).await? {
|
|
let user_id = String::from_utf8(entry.value).map_err(|_| AuthnError::InvalidInput)?;
|
|
return Ok(user_id);
|
|
}
|
|
|
|
let email = claims.email.clone().ok_or(AuthnError::InvalidInput)?;
|
|
if claims.email_verified != Some(true) {
|
|
return Err(AuthnError::InvalidInput);
|
|
}
|
|
|
|
let email_hash = hash_stable("email", &email);
|
|
let existing = storage
|
|
.users
|
|
.get(&user_by_email_hash_key(&email_hash))
|
|
.await?
|
|
.and_then(|e| String::from_utf8(e.value).ok());
|
|
|
|
let user_id = if let Some(user_id) = existing {
|
|
user_id
|
|
} else {
|
|
let user_id = uuid::Uuid::new_v4().to_string();
|
|
let now_ms = unix_ms();
|
|
let user = UserRecord {
|
|
user_id: user_id.clone(),
|
|
email: email.clone(),
|
|
enabled: true,
|
|
created_at_ms: now_ms,
|
|
};
|
|
|
|
storage
|
|
.users
|
|
.create(
|
|
&user_by_email_hash_key(&email_hash),
|
|
user_id.as_bytes().to_vec(),
|
|
)
|
|
.await?;
|
|
storage
|
|
.users
|
|
.create(&user_key(&user_id), encode_stored(&user)?)
|
|
.await?;
|
|
|
|
user_id
|
|
};
|
|
|
|
let _ = storage
|
|
.identities
|
|
.create(&identity_key, user_id.as_bytes().to_vec())
|
|
.await;
|
|
|
|
Ok(user_id)
|
|
}
|
|
|
|
async fn mfa_enroll_start(
|
|
State(state): State<AppState>,
|
|
principal: crate::authz::Principal,
|
|
Json(req): Json<MfaEnrollStartRequest>,
|
|
) -> Result<Json<MfaEnrollStartResponse>, AuthnError> {
|
|
if req.user_id != principal.user_id {
|
|
return Err(AuthnError::InvalidInput);
|
|
}
|
|
|
|
let secret_base32 = generate_totp_secret_base32();
|
|
let now_ms = unix_ms();
|
|
let record = TotpEnrollmentRecord {
|
|
user_id: req.user_id.clone(),
|
|
secret_base32: secret_base32.clone(),
|
|
enabled: false,
|
|
created_at_ms: now_ms,
|
|
recovery_hashes: Vec::new(),
|
|
};
|
|
|
|
state
|
|
.storage
|
|
.mfa
|
|
.put(&totp_key(&principal.user_id), encode_stored(&record)?)
|
|
.await?;
|
|
|
|
Ok(Json(MfaEnrollStartResponse { secret_base32 }))
|
|
}
|
|
|
|
async fn mfa_enroll_confirm(
|
|
State(state): State<AppState>,
|
|
principal: crate::authz::Principal,
|
|
Json(req): Json<MfaEnrollConfirmRequest>,
|
|
) -> Result<Json<MfaEnrollConfirmResponse>, AuthnError> {
|
|
if req.user_id != principal.user_id {
|
|
return Err(AuthnError::InvalidInput);
|
|
}
|
|
|
|
let entry = state
|
|
.storage
|
|
.mfa
|
|
.get(&totp_key(&principal.user_id))
|
|
.await?
|
|
.ok_or(AuthnError::InvalidInput)?;
|
|
|
|
let mut stored: Stored<TotpEnrollmentRecord> =
|
|
serde_json::from_slice(&entry.value).map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
|
|
if stored.data.enabled {
|
|
return Ok(Json(MfaEnrollConfirmResponse {
|
|
status: "ok",
|
|
recovery_codes: Vec::new(),
|
|
}));
|
|
}
|
|
|
|
let now = SystemTime::now();
|
|
if !verify_totp_code(&stored.data.secret_base32, &req.code, now)? {
|
|
return Err(AuthnError::MfaInvalid);
|
|
}
|
|
|
|
stored.data.enabled = true;
|
|
let mut recovery_codes: Vec<String> = Vec::new();
|
|
if stored.data.recovery_hashes.is_empty() {
|
|
recovery_codes = generate_recovery_codes(10);
|
|
stored.data.recovery_hashes = recovery_codes
|
|
.iter()
|
|
.map(|c| hash_stable("recovery", c))
|
|
.collect();
|
|
}
|
|
let payload = serde_json::to_vec(&stored).map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
state
|
|
.storage
|
|
.mfa
|
|
.update(&totp_key(&principal.user_id), entry.revision, payload)
|
|
.await?;
|
|
|
|
Ok(Json(MfaEnrollConfirmResponse {
|
|
status: "ok",
|
|
recovery_codes,
|
|
}))
|
|
}
|
|
|
|
async fn mfa_challenge(
|
|
State(state): State<AppState>,
|
|
principal: crate::authz::Principal,
|
|
Json(req): Json<MfaChallengeRequest>,
|
|
) -> Result<Json<MfaChallengeResponse>, AuthnError> {
|
|
let key = totp_key(&principal.user_id);
|
|
|
|
for _ in 0..10 {
|
|
let entry = state
|
|
.storage
|
|
.mfa
|
|
.get(&key)
|
|
.await?
|
|
.ok_or(AuthnError::InvalidInput)?;
|
|
|
|
let mut stored: Stored<TotpEnrollmentRecord> =
|
|
serde_json::from_slice(&entry.value).map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
|
|
if !stored.data.enabled {
|
|
return Err(AuthnError::InvalidInput);
|
|
}
|
|
|
|
let now = SystemTime::now();
|
|
if req.code.chars().all(|c| c.is_ascii_digit()) && req.code.len() == 6 {
|
|
if verify_totp_code(&stored.data.secret_base32, &req.code, now)? {
|
|
return Ok(Json(MfaChallengeResponse { status: "ok" }));
|
|
}
|
|
metrics::counter!("gateway_mfa_fail_total", "kind" => "totp").increment(1);
|
|
return Err(AuthnError::MfaInvalid);
|
|
}
|
|
|
|
let presented_hash = hash_stable("recovery", &req.code);
|
|
if let Some(pos) = stored
|
|
.data
|
|
.recovery_hashes
|
|
.iter()
|
|
.position(|h| h.as_bytes().ct_eq(presented_hash.as_bytes()).into())
|
|
{
|
|
stored.data.recovery_hashes.remove(pos);
|
|
let payload =
|
|
serde_json::to_vec(&stored).map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
match state
|
|
.storage
|
|
.mfa
|
|
.update(&key, entry.revision, payload)
|
|
.await
|
|
{
|
|
Ok(_) => {
|
|
return Ok(Json(MfaChallengeResponse { status: "ok" }));
|
|
}
|
|
Err(StorageError::CasMismatch) => continue,
|
|
Err(e) => return Err(e.into()),
|
|
}
|
|
}
|
|
|
|
metrics::counter!("gateway_mfa_fail_total", "kind" => "recovery").increment(1);
|
|
return Err(AuthnError::MfaInvalid);
|
|
}
|
|
|
|
Err(AuthnError::Storage("mfa cas failed".to_string()))
|
|
}
|
|
|
|
async fn session_user_id(storage: &GatewayStorage, session_id: &str) -> Result<String, AuthnError> {
|
|
let entry = storage
|
|
.refresh_sessions
|
|
.get(&format!("v1/sessions/{session_id}"))
|
|
.await?
|
|
.ok_or(AuthnError::InvalidRefresh)?;
|
|
|
|
let stored: Stored<crate::storage::RefreshSessionRecord> =
|
|
serde_json::from_slice(&entry.value).map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
Ok(stored.data.user_id)
|
|
}
|
|
|
|
async fn issue_password_reset_token(
|
|
storage: &GatewayStorage,
|
|
user_id: &str,
|
|
ttl: Duration,
|
|
) -> Result<String, AuthnError> {
|
|
let token = uuid::Uuid::new_v4().to_string();
|
|
let token_hash = hash_stable("reset", &token);
|
|
let now_ms = unix_ms();
|
|
let record = PasswordResetRecord {
|
|
user_id: user_id.to_string(),
|
|
token_hash: token_hash.clone(),
|
|
created_at_ms: now_ms,
|
|
expires_at_ms: now_ms + ttl.as_millis() as i64,
|
|
used_at_ms: None,
|
|
};
|
|
|
|
storage
|
|
.password_resets
|
|
.create(&reset_key(&token_hash), encode_stored(&record)?)
|
|
.await?;
|
|
|
|
Ok(token)
|
|
}
|
|
|
|
async fn reset_password(
|
|
storage: &GatewayStorage,
|
|
presented_token: &str,
|
|
new_password: &str,
|
|
) -> Result<(), AuthnError> {
|
|
let token_hash = hash_stable("reset", presented_token);
|
|
let key = reset_key(&token_hash);
|
|
|
|
for _ in 0..10 {
|
|
let entry = storage
|
|
.password_resets
|
|
.get(&key)
|
|
.await?
|
|
.ok_or(AuthnError::InvalidReset)?;
|
|
|
|
let mut stored: Stored<PasswordResetRecord> =
|
|
serde_json::from_slice(&entry.value).map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
|
|
let now_ms = unix_ms();
|
|
if stored.data.used_at_ms.is_some() {
|
|
return Err(AuthnError::ResetUsed);
|
|
}
|
|
if now_ms >= stored.data.expires_at_ms {
|
|
return Err(AuthnError::ResetExpired);
|
|
}
|
|
|
|
stored.data.used_at_ms = Some(now_ms);
|
|
let payload =
|
|
serde_json::to_vec(&stored).map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
match storage
|
|
.password_resets
|
|
.update(&key, entry.revision, payload)
|
|
.await
|
|
{
|
|
Ok(_) => {
|
|
let password_hash = hash_password(new_password)?;
|
|
let cred_key = password_key(&stored.data.user_id);
|
|
let cred_entry = storage
|
|
.password_credentials
|
|
.get(&cred_key)
|
|
.await?
|
|
.ok_or(AuthnError::InvalidReset)?;
|
|
let mut stored_cred: Stored<PasswordCredentialRecord> =
|
|
serde_json::from_slice(&cred_entry.value)
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
stored_cred.data.password_hash = password_hash;
|
|
stored_cred.data.updated_at_ms = now_ms;
|
|
let cred_payload = serde_json::to_vec(&stored_cred)
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
storage
|
|
.password_credentials
|
|
.update(&cred_key, cred_entry.revision, cred_payload)
|
|
.await?;
|
|
|
|
revoke_all_refresh_sessions_for_user(storage, &stored.data.user_id).await?;
|
|
return Ok(());
|
|
}
|
|
Err(StorageError::CasMismatch) => continue,
|
|
Err(e) => return Err(e.into()),
|
|
}
|
|
}
|
|
|
|
Err(AuthnError::Storage("reset cas failed".to_string()))
|
|
}
|
|
|
|
async fn revoke_all_refresh_sessions_for_user(
|
|
storage: &GatewayStorage,
|
|
user_id: &str,
|
|
) -> Result<(), AuthnError> {
|
|
let keys = storage.refresh_sessions.list_keys("v1/sessions/").await?;
|
|
for key in keys {
|
|
if let Some(session_id) = key.strip_prefix("v1/sessions/") {
|
|
let entry = storage.refresh_sessions.get(&key).await?;
|
|
let Some(entry) = entry else {
|
|
continue;
|
|
};
|
|
let stored: Stored<crate::storage::RefreshSessionRecord> =
|
|
serde_json::from_slice(&entry.value)
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
if stored.data.user_id == user_id {
|
|
let _ = storage.revoke_refresh_session(session_id).await;
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
|
pub struct AccessClaims {
|
|
pub sub: String,
|
|
pub session_id: String,
|
|
pub iat: i64,
|
|
pub exp: i64,
|
|
}
|
|
|
|
fn issue_access_token(
|
|
cfg: &AuthnConfig,
|
|
user_id: &str,
|
|
session_id: &str,
|
|
) -> Result<String, AuthnError> {
|
|
let now = Utc::now().timestamp();
|
|
let exp = now + cfg.access_ttl.as_secs() as i64;
|
|
let claims = AccessClaims {
|
|
sub: user_id.to_string(),
|
|
session_id: session_id.to_string(),
|
|
iat: now,
|
|
exp,
|
|
};
|
|
|
|
let token = jsonwebtoken::encode(
|
|
&jsonwebtoken::Header::default(),
|
|
&claims,
|
|
&jsonwebtoken::EncodingKey::from_secret(
|
|
cfg.jwt_secrets
|
|
.first()
|
|
.ok_or_else(|| AuthnError::Storage("no jwt secrets configured".to_string()))?,
|
|
),
|
|
)
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
|
|
Ok(token)
|
|
}
|
|
|
|
fn validate_password(password: &str) -> Result<(), AuthnError> {
|
|
if password.len() < 8 {
|
|
return Err(AuthnError::InvalidInput);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Canonicalises an email address for storage and lookup.
///
/// Surrounding whitespace is trimmed and ASCII letters are lowercased.
/// Returns `None` unless the result has a non-empty local part and a non-empty
/// domain separated by `@`. For example `""`, `"@"`, `"a@"` and `"@b.com"`
/// are all rejected.
fn normalize_email(email: &str) -> Option<String> {
    let normalized = email.trim().to_ascii_lowercase();
    // Split on the last '@' so the domain is always the right-most component.
    let (local, domain) = normalized.rsplit_once('@')?;
    if local.is_empty() || domain.is_empty() {
        return None;
    }
    Some(normalized)
}
|
|
|
|
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this yields `0`.
fn unix_ms() -> i64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    };
    since_epoch.as_millis() as i64
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct StoredRef<'a, T> {
|
|
v: u32,
|
|
data: &'a T,
|
|
}
|
|
|
|
fn encode_stored<T: Serialize>(value: &T) -> Result<Vec<u8>, AuthnError> {
|
|
serde_json::to_vec(&StoredRef {
|
|
v: crate::storage::SCHEMA_VERSION,
|
|
data: value,
|
|
})
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))
|
|
}
|
|
|
|
fn hash_stable(domain: &str, value: &str) -> String {
|
|
let mut hasher = sha2::Sha256::new();
|
|
hasher.update(domain.as_bytes());
|
|
hasher.update([0u8]);
|
|
hasher.update(value.as_bytes());
|
|
hex::encode(hasher.finalize())
|
|
}
|
|
|
|
/// Storage key of the user record for `user_id`.
fn user_key(user_id: &str) -> String {
    String::from("v1/users/") + user_id
}
|
|
|
|
/// Storage key of the email-hash → user index entry.
fn user_by_email_hash_key(email_hash: &str) -> String {
    ["v1/users/by_email_hash/", email_hash].concat()
}
|
|
|
|
/// Storage key of the password credential for `user_id`.
fn password_key(user_id: &str) -> String {
    let mut key = String::from("v1/password/");
    key.push_str(user_id);
    key
}
|
|
|
|
/// Storage key of a password-reset record, addressed by the token's hash.
fn reset_key(token_hash: &str) -> String {
    "v1/resets/".to_owned() + token_hash
}
|
|
|
|
fn oidc_state_key(state: &str) -> String {
|
|
let h = hash_stable("oidc_state", state);
|
|
format!("v1/oidc/google/state/{h}")
|
|
}
|
|
|
|
/// Storage key of the service-token record for `user_id`.
fn service_token_key(user_id: &str) -> String {
    format!("v1/service_tokens/{}", user_id)
}
|
|
|
|
/// Storage key of the TOTP enrollment record for `user_id`.
fn totp_key(user_id: &str) -> String {
    ["v1/totp", user_id].join("/")
}
|
|
|
|
fn hash_password(password: &str) -> Result<String, AuthnError> {
|
|
let salt = SaltString::generate(&mut rand_core::OsRng);
|
|
let argon2 = Argon2::default();
|
|
let hash = argon2
|
|
.hash_password(password.as_bytes(), &salt)
|
|
.map_err(|e| AuthnError::Storage(e.to_string()))?
|
|
.to_string();
|
|
Ok(hash)
|
|
}
|
|
|
|
fn verify_password(password: &str, hash: &str) -> Result<bool, AuthnError> {
|
|
let parsed = PasswordHash::new(hash).map_err(|e| AuthnError::Storage(e.to_string()))?;
|
|
let argon2 = Argon2::default();
|
|
Ok(argon2.verify_password(password.as_bytes(), &parsed).is_ok())
|
|
}
|
|
|
|
fn generate_totp_secret_base32() -> String {
|
|
let raw = uuid::Uuid::new_v4().as_bytes().to_vec();
|
|
base32::encode(base32::Alphabet::Rfc4648 { padding: false }, &raw)
|
|
}
|
|
|
|
fn generate_recovery_codes(count: usize) -> Vec<String> {
|
|
let mut out = Vec::with_capacity(count);
|
|
let mut seen = std::collections::HashSet::new();
|
|
while out.len() < count {
|
|
let code = uuid::Uuid::new_v4()
|
|
.simple()
|
|
.to_string()
|
|
.chars()
|
|
.take(8)
|
|
.collect::<String>()
|
|
.to_ascii_uppercase();
|
|
if seen.insert(code.clone()) {
|
|
out.push(code);
|
|
}
|
|
}
|
|
out
|
|
}
|
|
|
|
fn verify_totp_code(secret_base32: &str, code: &str, now: SystemTime) -> Result<bool, AuthnError> {
|
|
let secret = base32::decode(base32::Alphabet::Rfc4648 { padding: false }, secret_base32)
|
|
.ok_or(AuthnError::InvalidInput)?;
|
|
|
|
let now_secs = now.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs();
|
|
let step = 30u64;
|
|
let counter = now_secs / step;
|
|
|
|
for offset in [0i64, -1, 1] {
|
|
let c = if offset.is_negative() {
|
|
counter.saturating_sub(offset.unsigned_abs())
|
|
} else {
|
|
counter.saturating_add(offset as u64)
|
|
};
|
|
let expected = totp_code(&secret, c, 6)?;
|
|
if expected.as_bytes().ct_eq(code.as_bytes()).into() {
|
|
return Ok(true);
|
|
}
|
|
}
|
|
|
|
Ok(false)
|
|
}
|
|
|
|
fn totp_code(secret: &[u8], counter: u64, digits: u32) -> Result<String, AuthnError> {
|
|
let mut msg = [0u8; 8];
|
|
msg.copy_from_slice(&counter.to_be_bytes());
|
|
|
|
let mut mac = Hmac::<Sha1>::new_from_slice(secret).map_err(|_| AuthnError::InvalidInput)?;
|
|
mac.update(&msg);
|
|
let result = mac.finalize().into_bytes();
|
|
|
|
let offset = (result[19] & 0x0f) as usize;
|
|
let binary = ((result[offset] as u32 & 0x7f) << 24)
|
|
| ((result[offset + 1] as u32) << 16)
|
|
| ((result[offset + 2] as u32) << 8)
|
|
| (result[offset + 3] as u32);
|
|
|
|
let modulus = 10u32.pow(digits);
|
|
let code = binary % modulus;
|
|
Ok(format!("{:0width$}", code, width = digits as usize))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use tower::util::ServiceExt;

    /// Builds a `POST` request with a JSON content type, optionally carrying a
    /// bearer token in the `authorization` header.
    fn post_json(
        uri: &str,
        bearer: Option<&str>,
        body: impl Into<axum::body::Body>,
    ) -> axum::http::Request<axum::body::Body> {
        let mut builder = axum::http::Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json");
        if let Some(token) = bearer {
            builder = builder.header("authorization", format!("Bearer {token}"));
        }
        builder.body(body.into()).unwrap()
    }

    #[test]
    fn password_hashing_and_verification_works() {
        let hash = hash_password("correct horse battery staple").unwrap();
        assert!(hash.contains("$argon2id$"));
        assert!(verify_password("correct horse battery staple", &hash).unwrap());
        assert!(!verify_password("wrong", &hash).unwrap());
    }

    #[tokio::test]
    async fn refresh_rotation_invalidates_old_token() {
        let metrics = crate::observability::init_metrics_for_tests();
        let routing = crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(
            crate::routing::RoutingConfig::empty(),
        )))
        .await
        .unwrap();
        let storage = crate::storage::GatewayStorage::new_in_memory();
        let authn = AuthnConfig::for_tests();
        let app = crate::app(crate::AppState {
            metrics,
            routing,
            storage,
            authn,
        });

        let signup_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/signup",
                None,
                r#"{"email":"a@b.com","password":"password123"}"#,
            ))
            .await
            .unwrap();
        assert_eq!(signup_res.status(), StatusCode::OK);
        let body = axum::body::to_bytes(signup_res.into_body(), usize::MAX)
            .await
            .unwrap();
        let created: AuthResponse = serde_json::from_slice(&body).unwrap();

        let refresh_body = serde_json::to_vec(&RefreshRequest {
            session_id: created.session_id.clone(),
            refresh_token: created.refresh_token.clone(),
        })
        .unwrap();
        let refresh_res = app
            .clone()
            .oneshot(post_json("/v1/auth/refresh", None, refresh_body))
            .await
            .unwrap();
        assert_eq!(refresh_res.status(), StatusCode::OK);
        let body = axum::body::to_bytes(refresh_res.into_body(), usize::MAX)
            .await
            .unwrap();
        let refreshed: AuthResponse = serde_json::from_slice(&body).unwrap();

        // Replaying the original (now rotated) refresh token must be rejected.
        let refresh_again_body = serde_json::to_vec(&RefreshRequest {
            session_id: created.session_id.clone(),
            refresh_token: created.refresh_token.clone(),
        })
        .unwrap();
        let refresh_again = app
            .oneshot(post_json("/v1/auth/refresh", None, refresh_again_body))
            .await
            .unwrap();
        assert_eq!(refresh_again.status(), StatusCode::UNAUTHORIZED);

        assert_ne!(refreshed.refresh_token, created.refresh_token);
    }

    #[tokio::test]
    async fn forgot_reset_token_is_one_time_and_expires() {
        let storage = GatewayStorage::new_in_memory();
        let user_id = "u1";

        // A token with a 1 ms TTL has expired by the time it is presented.
        let token = issue_password_reset_token(&storage, user_id, Duration::from_millis(1))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(5)).await;
        let res = reset_password(&storage, &token, "password123").await;
        assert!(matches!(res, Err(AuthnError::ResetExpired)));

        let token2 = issue_password_reset_token(&storage, user_id, Duration::from_secs(60))
            .await
            .unwrap();

        let cred = PasswordCredentialRecord {
            user_id: user_id.to_string(),
            password_hash: hash_password("password123").unwrap(),
            updated_at_ms: unix_ms(),
        };
        // Write the credential exactly as production code does (current schema version).
        storage
            .password_credentials
            .create(&password_key(user_id), encode_stored(&cred).unwrap())
            .await
            .unwrap();

        reset_password(&storage, &token2, "newpassword123")
            .await
            .unwrap();

        // The stored credential must now verify against the new password only.
        let entry = storage
            .password_credentials
            .get(&password_key(user_id))
            .await
            .unwrap()
            .expect("password credential should still exist after reset");
        let stored: Stored<PasswordCredentialRecord> =
            serde_json::from_slice(&entry.value).unwrap();
        assert!(verify_password("newpassword123", &stored.data.password_hash).unwrap());
        assert!(!verify_password("password123", &stored.data.password_hash).unwrap());

        let again = reset_password(&storage, &token2, "anotherpassword123").await;
        assert!(matches!(again, Err(AuthnError::ResetUsed)));
    }

    #[test]
    fn totp_verification_accepts_valid_code() {
        let secret = generate_totp_secret_base32();
        let now = SystemTime::now();
        let counter = now.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs() / 30;
        let secret_bytes =
            base32::decode(base32::Alphabet::Rfc4648 { padding: false }, &secret).unwrap();

        let code = totp_code(&secret_bytes, counter, 6).unwrap();
        assert!(verify_totp_code(&secret, &code, now).unwrap());

        // Choose a code guaranteed to lie outside the accepted window. A fixed
        // literal such as "000000" would match a valid window code about 3 times
        // in a million runs, making the test flaky.
        let window: Vec<String> = [counter.saturating_sub(1), counter, counter.saturating_add(1)]
            .iter()
            .map(|&c| totp_code(&secret_bytes, c, 6).unwrap())
            .collect();
        let wrong = (0..10u32)
            .map(|d| d.to_string().repeat(6))
            .find(|candidate| !window.contains(candidate))
            .expect("at most 3 of 10 candidates can be in the window");
        assert!(!verify_totp_code(&secret, &wrong, now).unwrap());
    }

    #[tokio::test]
    async fn mfa_enrollment_and_challenge_work_with_totp_and_recovery_codes() {
        let metrics = crate::observability::init_metrics_for_tests();
        let routing = crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(
            crate::routing::RoutingConfig::empty(),
        )))
        .await
        .unwrap();
        let storage = crate::storage::GatewayStorage::new_in_memory();
        let authn = AuthnConfig::for_tests();
        let app = crate::app(crate::AppState {
            metrics,
            routing,
            storage,
            authn: authn.clone(),
        });

        let signup_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/signup",
                None,
                r#"{"email":"mfa@b.com","password":"password123"}"#,
            ))
            .await
            .unwrap();
        assert_eq!(signup_res.status(), StatusCode::OK);
        let body = axum::body::to_bytes(signup_res.into_body(), usize::MAX)
            .await
            .unwrap();
        let created: AuthResponse = serde_json::from_slice(&body).unwrap();
        let claims = authn.verify_access_token(&created.access_token).unwrap();
        let bearer = Some(created.access_token.as_str());

        let enroll_start_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/mfa/enroll/start",
                bearer,
                format!(r#"{{"user_id":"{}"}}"#, claims.sub),
            ))
            .await
            .unwrap();
        assert_eq!(enroll_start_res.status(), StatusCode::OK);
        let body = axum::body::to_bytes(enroll_start_res.into_body(), usize::MAX)
            .await
            .unwrap();
        let enroll_start: serde_json::Value = serde_json::from_slice(&body).unwrap();
        let secret_base32 = enroll_start
            .get("secret_base32")
            .and_then(|v| v.as_str())
            .unwrap()
            .to_string();

        let now = SystemTime::now();
        let counter = now.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs() / 30;
        let secret_bytes =
            base32::decode(base32::Alphabet::Rfc4648 { padding: false }, &secret_base32).unwrap();
        let code = totp_code(&secret_bytes, counter, 6).unwrap();

        let enroll_confirm_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/mfa/enroll/confirm",
                bearer,
                format!(r#"{{"user_id":"{}","code":"{}"}}"#, claims.sub, code),
            ))
            .await
            .unwrap();
        assert_eq!(enroll_confirm_res.status(), StatusCode::OK);
        let body = axum::body::to_bytes(enroll_confirm_res.into_body(), usize::MAX)
            .await
            .unwrap();
        let enroll_confirm: serde_json::Value = serde_json::from_slice(&body).unwrap();
        let recovery_codes: Vec<String> = enroll_confirm
            .get("recovery_codes")
            .and_then(|v| v.as_array())
            .unwrap()
            .iter()
            .filter_map(|v| v.as_str().map(|s| s.to_string()))
            .collect();
        assert_eq!(recovery_codes.len(), 10);

        let challenge_totp_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/mfa/challenge",
                bearer,
                format!(r#"{{"code":"{}"}}"#, code),
            ))
            .await
            .unwrap();
        assert_eq!(challenge_totp_res.status(), StatusCode::OK);

        let recovery = recovery_codes[0].clone();
        let challenge_recovery_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/mfa/challenge",
                bearer,
                format!(r#"{{"code":"{}"}}"#, recovery),
            ))
            .await
            .unwrap();
        assert_eq!(challenge_recovery_res.status(), StatusCode::OK);

        // A recovery code is single-use.
        let challenge_recovery_again = app
            .oneshot(post_json(
                "/v1/auth/mfa/challenge",
                bearer,
                format!(r#"{{"code":"{}"}}"#, recovery),
            ))
            .await
            .unwrap();
        assert_eq!(challenge_recovery_again.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn rate_limits_trigger_for_signin_and_errors_do_not_echo_secrets() {
        if let Some(lock) = RATE_LIMITER.get() {
            lock.lock().await.clear();
        }

        let metrics = crate::observability::init_metrics_for_tests();
        let routing = crate::routing::RouterState::new(Arc::new(crate::routing::FixedSource::new(
            crate::routing::RoutingConfig::empty(),
        )))
        .await
        .unwrap();
        let storage = crate::storage::GatewayStorage::new_in_memory();
        let authn = AuthnConfig::for_tests();
        let app = crate::app(crate::AppState {
            metrics,
            routing,
            storage,
            authn,
        });

        let signup_res = app
            .clone()
            .oneshot(post_json(
                "/v1/auth/signup",
                None,
                r#"{"email":"rl@b.com","password":"password123"}"#,
            ))
            .await
            .unwrap();
        assert_eq!(signup_res.status(), StatusCode::OK);

        for attempt in 0..11 {
            let resp = app
                .clone()
                .oneshot(post_json(
                    "/v1/auth/signin",
                    None,
                    r#"{"email":"rl@b.com","password":"supersecret"}"#,
                ))
                .await
                .unwrap();
            let status = resp.status();

            let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
                .await
                .unwrap();
            let body_str = String::from_utf8_lossy(&body);
            assert!(!body_str.contains("supersecret"));

            let expected = if attempt < 10 {
                StatusCode::UNAUTHORIZED
            } else {
                StatusCode::TOO_MANY_REQUESTS
            };
            assert_eq!(status, expected);
        }
    }
}
|