added initial roadmap and implementation
This commit is contained in:
249
auth/src/handlers.rs
Normal file
249
auth/src/handlers.rs
Normal file
@@ -0,0 +1,249 @@
|
||||
use crate::middleware::AuthContext;
|
||||
use crate::models::{AuthResponse, SignInRequest, SignUpRequest, User};
|
||||
use crate::utils::{
|
||||
generate_refresh_token, generate_token, hash_password, hash_refresh_token, issue_refresh_token, verify_password,
|
||||
};
|
||||
use axum::{
|
||||
extract::{Extension, Query, State},
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use common::Config;
|
||||
use common::ProjectContext;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use sqlx::{Executor, PgPool, Postgres};
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
use validator::Validate;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthState {
|
||||
pub db: PgPool,
|
||||
pub config: Config,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RefreshTokenGrant {
|
||||
refresh_token: String,
|
||||
}
|
||||
|
||||
pub async fn signup(
|
||||
State(state): State<AuthState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
project_ctx: Option<Extension<ProjectContext>>,
|
||||
Json(payload): Json<SignUpRequest>,
|
||||
) -> Result<Json<AuthResponse>, (StatusCode, String)> {
|
||||
payload.validate().map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
// Check if user exists
|
||||
let user_exists = sqlx::query("SELECT id FROM users WHERE email = $1")
|
||||
.bind(&payload.email)
|
||||
.fetch_optional(&db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
if user_exists.is_some() {
|
||||
return Err((StatusCode::BAD_REQUEST, "User already exists".to_string()));
|
||||
}
|
||||
|
||||
let hashed_password = hash_password(&payload.password)
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let user = sqlx::query_as::<_, User>(
|
||||
r#"
|
||||
INSERT INTO users (email, encrypted_password, raw_user_meta_data)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING *
|
||||
"#,
|
||||
)
|
||||
.bind(&payload.email)
|
||||
.bind(hashed_password)
|
||||
.bind(payload.data.unwrap_or(serde_json::json!({})))
|
||||
.fetch_one(&db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
|
||||
ctx.jwt_secret.as_str()
|
||||
} else {
|
||||
state.config.jwt_secret.as_str()
|
||||
};
|
||||
|
||||
let (token, expires_in) = generate_token(user.id, &user.email, "authenticated", jwt_secret)
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let refresh_token = issue_refresh_token(&db, user.id, Uuid::new_v4(), None).await?;
|
||||
Ok(Json(AuthResponse {
|
||||
access_token: token,
|
||||
token_type: "bearer".to_string(),
|
||||
expires_in,
|
||||
refresh_token,
|
||||
user,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn login(
|
||||
State(state): State<AuthState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
project_ctx: Option<Extension<ProjectContext>>,
|
||||
Json(payload): Json<SignInRequest>,
|
||||
) -> Result<Json<AuthResponse>, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE email = $1")
|
||||
.bind(&payload.email)
|
||||
.fetch_optional(&db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||||
.ok_or((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"Invalid email or password".to_string(),
|
||||
))?;
|
||||
|
||||
if !verify_password(&payload.password, &user.encrypted_password)
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||||
{
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"Invalid email or password".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
|
||||
ctx.jwt_secret.as_str()
|
||||
} else {
|
||||
state.config.jwt_secret.as_str()
|
||||
};
|
||||
|
||||
let (token, expires_in) = generate_token(user.id, &user.email, "authenticated", jwt_secret)
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let refresh_token = issue_refresh_token(&db, user.id, Uuid::new_v4(), None).await?;
|
||||
Ok(Json(AuthResponse {
|
||||
access_token: token,
|
||||
token_type: "bearer".to_string(),
|
||||
expires_in,
|
||||
refresh_token,
|
||||
user,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn get_user(
|
||||
State(state): State<AuthState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
Extension(auth_ctx): Extension<AuthContext>,
|
||||
) -> Result<Json<User>, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
let claims = auth_ctx
|
||||
.claims
|
||||
.ok_or((StatusCode::UNAUTHORIZED, "Not authenticated".to_string()))?;
|
||||
|
||||
let user_id = Uuid::parse_str(&claims.sub)
|
||||
.map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid user ID".to_string()))?;
|
||||
|
||||
let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||||
.ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?;
|
||||
|
||||
Ok(Json(user))
|
||||
}
|
||||
|
||||
pub async fn token(
|
||||
State(state): State<AuthState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
project_ctx: Option<Extension<ProjectContext>>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
Json(payload): Json<Value>,
|
||||
) -> Result<Json<AuthResponse>, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
let grant_type = params
|
||||
.get("grant_type")
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("password");
|
||||
|
||||
match grant_type {
|
||||
"password" => {
|
||||
let req: SignInRequest = serde_json::from_value(payload)
|
||||
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
|
||||
req.validate().map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
|
||||
login(State(state), Some(Extension(db)), project_ctx, Json(req)).await
|
||||
}
|
||||
"refresh_token" => {
|
||||
let req: RefreshTokenGrant = serde_json::from_value(payload)
|
||||
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
|
||||
|
||||
let token_hash = hash_refresh_token(&req.refresh_token);
|
||||
let mut tx = db
|
||||
.begin()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let (revoked_token_hash, user_id, session_id) =
|
||||
sqlx::query_as::<_, (String, Uuid, Option<Uuid>)>(
|
||||
r#"
|
||||
UPDATE refresh_tokens
|
||||
SET revoked = true, updated_at = now()
|
||||
WHERE token = $1 AND revoked = false
|
||||
RETURNING token, user_id, session_id
|
||||
"#,
|
||||
)
|
||||
.bind(&token_hash)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||||
.ok_or((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"Invalid refresh token".to_string(),
|
||||
))?;
|
||||
|
||||
let session_id = session_id.ok_or((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Missing session".to_string(),
|
||||
))?;
|
||||
|
||||
let new_refresh_token = issue_refresh_token(
|
||||
&mut *tx,
|
||||
user_id,
|
||||
session_id,
|
||||
Some(revoked_token_hash.as_str()),
|
||||
)
|
||||
.await?;
|
||||
|
||||
tx.commit()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||||
.ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?;
|
||||
|
||||
let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
|
||||
ctx.jwt_secret.as_str()
|
||||
} else {
|
||||
state.config.jwt_secret.as_str()
|
||||
};
|
||||
|
||||
let (access_token, expires_in) =
|
||||
generate_token(user.id, &user.email, "authenticated", jwt_secret)
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
Ok(Json(AuthResponse {
|
||||
access_token,
|
||||
token_type: "bearer".to_string(),
|
||||
expires_in,
|
||||
refresh_token: new_refresh_token,
|
||||
user,
|
||||
}))
|
||||
}
|
||||
_ => Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Unsupported grant_type".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
19
auth/src/lib.rs
Normal file
19
auth/src/lib.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
pub mod handlers;
|
||||
pub mod middleware;
|
||||
pub mod models;
|
||||
pub mod oauth;
|
||||
pub mod utils;
|
||||
|
||||
use axum::routing::{get, post};
|
||||
pub use axum::Router;
|
||||
pub use handlers::AuthState;
|
||||
pub use middleware::{auth_middleware, AuthContext, AuthMiddlewareState};
|
||||
|
||||
pub fn router() -> Router<AuthState> {
|
||||
Router::new()
|
||||
.route("/signup", post(handlers::signup))
|
||||
.route("/token", post(handlers::token))
|
||||
.route("/authorize", get(oauth::authorize))
|
||||
.route("/callback/:provider", get(oauth::callback))
|
||||
.route("/user", get(handlers::get_user))
|
||||
}
|
||||
122
auth/src/middleware.rs
Normal file
122
auth/src/middleware.rs
Normal file
@@ -0,0 +1,122 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use common::{Config, ProjectContext};
|
||||
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthMiddlewareState {
|
||||
pub config: Config,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Claims {
|
||||
pub sub: String,
|
||||
pub email: Option<String>,
|
||||
pub role: String,
|
||||
pub exp: usize,
|
||||
pub iss: String,
|
||||
pub aud: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthContext {
|
||||
pub claims: Option<Claims>,
|
||||
pub role: String,
|
||||
}
|
||||
|
||||
pub async fn auth_middleware(
|
||||
State(state): State<AuthMiddlewareState>,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
// 1. Try to get ProjectContext (if available)
|
||||
// If we are running in multi-tenant mode, ProjectContext should be present.
|
||||
// If not, we fall back to global config (legacy/single-tenant).
|
||||
let project_ctx = req.extensions().get::<ProjectContext>().cloned();
|
||||
|
||||
// Allow public OAuth routes
|
||||
let path = req.uri().path();
|
||||
if path.contains("/authorize") || path.contains("/callback") {
|
||||
return Ok(next.run(req).await);
|
||||
}
|
||||
|
||||
// Determine the secret to use
|
||||
let jwt_secret = if let Some(ctx) = &project_ctx {
|
||||
ctx.jwt_secret.clone()
|
||||
} else {
|
||||
state.config.jwt_secret.clone()
|
||||
};
|
||||
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get("Authorization")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let apikey_header = req
|
||||
.headers()
|
||||
.get("apikey")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// Logic:
|
||||
// 1. Bearer Token takes precedence for identity (Claims).
|
||||
// 2. API Key is checked if no Bearer token, OR it acts as the "Client Key" (anon/service).
|
||||
// Usually Supabase requires 'apikey' header ALWAYS, and Authorization header OPTIONAL (for user context).
|
||||
|
||||
let token = if let Some(auth) = auth_header {
|
||||
auth.strip_prefix("Bearer ").map(|t| t.to_string())
|
||||
} else {
|
||||
// If no Auth header, check apikey header as fallback (e.g. for anon requests)
|
||||
apikey_header.clone()
|
||||
};
|
||||
|
||||
if let Some(token) = token {
|
||||
let mut validation = Validation::new(Algorithm::HS256);
|
||||
validation.validate_exp = true;
|
||||
validation.validate_aud = false;
|
||||
// validation.set_audience(&["authenticated"]); // If we used audience
|
||||
|
||||
match decode::<Claims>(
|
||||
&token,
|
||||
&DecodingKey::from_secret(jwt_secret.as_bytes()),
|
||||
&validation,
|
||||
) {
|
||||
Ok(token_data) => {
|
||||
let claims = token_data.claims;
|
||||
let role = claims.role.clone();
|
||||
|
||||
let ctx = AuthContext {
|
||||
claims: Some(claims),
|
||||
role,
|
||||
};
|
||||
req.extensions_mut().insert(ctx);
|
||||
return Ok(next.run(req).await);
|
||||
}
|
||||
Err(_) => {
|
||||
// Invalid token
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No valid token found.
|
||||
// Assign "anon" role if apikey is valid anon key?
|
||||
// Or just default to "anon" role without claims?
|
||||
// Supabase usually requires a valid JWT even for anon. The 'anon key' IS a JWT with role='anon'.
|
||||
|
||||
// So if decoding failed above, we returned Unauthorized.
|
||||
// If no header provided at all?
|
||||
// We should allow public routes to proceed?
|
||||
// But this middleware is applied to ALL routes in /rest, /auth etc.
|
||||
// /auth/v1/signup needs to be accessible.
|
||||
// But wait, even signup requires the 'anon' key in Supabase.
|
||||
|
||||
// So: strict check.
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
64
auth/src/models.rs
Normal file
64
auth/src/models.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
use uuid::Uuid;
|
||||
use validator::Validate;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, FromRow, Clone)]
|
||||
pub struct User {
|
||||
pub id: Uuid,
|
||||
pub email: String,
|
||||
#[serde(skip)]
|
||||
pub encrypted_password: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub last_sign_in_at: Option<DateTime<Utc>>,
|
||||
pub raw_app_meta_data: serde_json::Value,
|
||||
pub raw_user_meta_data: serde_json::Value,
|
||||
pub is_super_admin: Option<bool>,
|
||||
pub confirmed_at: Option<DateTime<Utc>>,
|
||||
pub email_confirmed_at: Option<DateTime<Utc>>,
|
||||
pub phone: Option<String>,
|
||||
pub phone_confirmed_at: Option<DateTime<Utc>>,
|
||||
pub confirmation_token: Option<String>,
|
||||
pub recovery_token: Option<String>,
|
||||
pub email_change_token_new: Option<String>,
|
||||
pub email_change: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Validate)]
|
||||
pub struct SignUpRequest {
|
||||
#[validate(email)]
|
||||
pub email: String,
|
||||
#[validate(length(min = 6, message = "Password must be at least 6 characters"))]
|
||||
pub password: String,
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Validate)]
|
||||
pub struct SignInRequest {
|
||||
#[validate(email)]
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct AuthResponse {
|
||||
pub access_token: String,
|
||||
pub token_type: String,
|
||||
pub expires_in: i64,
|
||||
pub refresh_token: String,
|
||||
pub user: User,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, FromRow)]
|
||||
pub struct RefreshToken {
|
||||
pub id: i64, // BigSerial
|
||||
pub token: String,
|
||||
pub user_id: Uuid,
|
||||
pub revoked: bool,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub parent: Option<String>,
|
||||
pub session_id: Option<Uuid>,
|
||||
}
|
||||
307
auth/src/oauth.rs
Normal file
307
auth/src/oauth.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
use crate::utils::{generate_token, issue_refresh_token};
|
||||
use crate::AuthState;
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Redirect},
|
||||
Json,
|
||||
extract::Extension,
|
||||
};
|
||||
use common::{Config, ProjectContext};
|
||||
use oauth2::{
|
||||
basic::{BasicErrorResponseType, BasicTokenType},
|
||||
AuthUrl, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
|
||||
EmptyExtraTokenFields, EndpointNotSet, EndpointSet, HttpRequest, HttpResponse,
|
||||
RedirectUrl, RevocationErrorResponseType, Scope, StandardErrorResponse,
|
||||
StandardRevocableToken, StandardTokenIntrospectionResponse, StandardTokenResponse,
|
||||
TokenResponse, TokenUrl,
|
||||
};
|
||||
use reqwest::Client as ReqwestClient;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OAuthRequest {
|
||||
pub provider: String,
|
||||
pub redirect_to: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OAuthCallback {
|
||||
pub code: String,
|
||||
pub state: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct UserProfile {
|
||||
email: String,
|
||||
name: Option<String>,
|
||||
avatar_url: Option<String>,
|
||||
provider_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OAuthHttpError(String);
|
||||
impl std::fmt::Display for OAuthHttpError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "OAuth HTTP Error: {}", self.0)
|
||||
}
|
||||
}
|
||||
impl std::error::Error for OAuthHttpError {}
|
||||
|
||||
// Define the client type that matches our usage (AuthUrl + TokenUrl set)
|
||||
type OAuthClient = Client<
|
||||
StandardErrorResponse<BasicErrorResponseType>,
|
||||
StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>,
|
||||
StandardTokenIntrospectionResponse<EmptyExtraTokenFields, BasicTokenType>,
|
||||
StandardRevocableToken,
|
||||
StandardErrorResponse<RevocationErrorResponseType>,
|
||||
EndpointSet, // HasAuthUrl
|
||||
EndpointNotSet,
|
||||
EndpointNotSet,
|
||||
EndpointNotSet,
|
||||
EndpointSet, // HasTokenUrl
|
||||
>;
|
||||
|
||||
pub async fn async_http_client(
|
||||
request: HttpRequest,
|
||||
) -> Result<HttpResponse, OAuthHttpError> {
|
||||
let client = reqwest::Client::builder()
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.build()
|
||||
.map_err(|e| OAuthHttpError(e.to_string()))?;
|
||||
|
||||
let mut request_builder = client
|
||||
.request(request.method().clone(), request.uri().to_string());
|
||||
|
||||
for (name, value) in request.headers() {
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
|
||||
request_builder = request_builder.body(request.into_body());
|
||||
|
||||
let response = request_builder.send().await.map_err(|e| OAuthHttpError(e.to_string()))?;
|
||||
|
||||
let mut builder = axum::http::Response::builder()
|
||||
.status(response.status());
|
||||
|
||||
for (name, value) in response.headers() {
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
|
||||
builder
|
||||
.body(response.bytes().await.map_err(|e| OAuthHttpError(e.to_string()))?.to_vec())
|
||||
.map_err(|e| OAuthHttpError(e.to_string()))
|
||||
}
|
||||
|
||||
fn get_client(provider: &str, config: &Config) -> Result<OAuthClient, String> {
|
||||
let (client_id, client_secret, auth_url, token_url) = match provider {
|
||||
"google" => (
|
||||
config.google_client_id.clone().ok_or("Google Client ID not set")?,
|
||||
config.google_client_secret.clone().ok_or("Google Client Secret not set")?,
|
||||
"https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"https://oauth2.googleapis.com/token",
|
||||
),
|
||||
"github" => (
|
||||
config.github_client_id.clone().ok_or("GitHub Client ID not set")?,
|
||||
config.github_client_secret.clone().ok_or("GitHub Client Secret not set")?,
|
||||
"https://github.com/login/oauth/authorize",
|
||||
"https://github.com/login/oauth/access_token",
|
||||
),
|
||||
_ => return Err(format!("Unknown provider: {}", provider)),
|
||||
};
|
||||
|
||||
let redirect_uri = if config.redirect_uri.ends_with('/') {
|
||||
format!("{}{}", config.redirect_uri, provider)
|
||||
} else {
|
||||
format!("{}/{}", config.redirect_uri, provider)
|
||||
};
|
||||
|
||||
let client = Client::new(ClientId::new(client_id))
|
||||
.set_client_secret(ClientSecret::new(client_secret))
|
||||
.set_auth_uri(AuthUrl::new(auth_url.to_string()).map_err(|e| e.to_string())?)
|
||||
.set_token_uri(TokenUrl::new(token_url.to_string()).map_err(|e| e.to_string())?)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_uri).map_err(|e| e.to_string())?);
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
pub async fn authorize(
|
||||
State(state): State<AuthState>,
|
||||
Query(query): Query<OAuthRequest>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let client = get_client(&query.provider, &state.config)
|
||||
.map_err(|e| (StatusCode::BAD_REQUEST, e))?;
|
||||
|
||||
let mut auth_request = client.authorize_url(CsrfToken::new_random);
|
||||
|
||||
match query.provider.as_str() {
|
||||
"google" => {
|
||||
auth_request = auth_request
|
||||
.add_scope(Scope::new("email".to_string()))
|
||||
.add_scope(Scope::new("profile".to_string()));
|
||||
}
|
||||
"github" => {
|
||||
auth_request = auth_request
|
||||
.add_scope(Scope::new("user:email".to_string()));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let (auth_url, _csrf_token) = auth_request.url();
|
||||
|
||||
// TODO: Store csrf_token in cookie/session for validation
|
||||
|
||||
Ok(Redirect::to(auth_url.as_str()))
|
||||
}
|
||||
|
||||
pub async fn callback(
|
||||
State(state): State<AuthState>,
|
||||
db: Option<Extension<sqlx::PgPool>>,
|
||||
project_ctx: Option<Extension<ProjectContext>>,
|
||||
Path(provider): Path<String>,
|
||||
Query(query): Query<OAuthCallback>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
let client = get_client(&provider, &state.config)
|
||||
.map_err(|e| (StatusCode::BAD_REQUEST, e))?;
|
||||
|
||||
let token_result = client
|
||||
.exchange_code(AuthorizationCode::new(query.code))
|
||||
.request_async(&async_http_client)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Token exchange failed: {}", e)))?;
|
||||
|
||||
let access_token = token_result.access_token().secret();
|
||||
|
||||
let user_profile = fetch_user_profile(&provider, access_token).await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||
|
||||
// Check if user exists by email
|
||||
let existing_user = sqlx::query_as::<_, crate::models::User>("SELECT * FROM users WHERE email = $1")
|
||||
.bind(&user_profile.email)
|
||||
.fetch_optional(&db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let user = if let Some(u) = existing_user {
|
||||
// Update user meta data if needed? For now, just return existing user.
|
||||
// We might want to record that they logged in with this provider.
|
||||
u
|
||||
} else {
|
||||
// Create new user
|
||||
let raw_meta = json!({
|
||||
"name": user_profile.name,
|
||||
"avatar_url": user_profile.avatar_url,
|
||||
"provider": provider,
|
||||
"provider_id": user_profile.provider_id
|
||||
});
|
||||
|
||||
sqlx::query_as::<_, crate::models::User>(
|
||||
r#"
|
||||
INSERT INTO users (email, encrypted_password, raw_user_meta_data)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING *
|
||||
"#,
|
||||
)
|
||||
.bind(&user_profile.email)
|
||||
.bind("oauth_user_no_password") // Placeholder
|
||||
.bind(raw_meta)
|
||||
.fetch_one(&db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||||
};
|
||||
|
||||
let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
|
||||
ctx.jwt_secret.as_str()
|
||||
} else {
|
||||
state.config.jwt_secret.as_str()
|
||||
};
|
||||
|
||||
let (token, expires_in) = generate_token(user.id, &user.email, "authenticated", jwt_secret)
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let refresh_token: String = issue_refresh_token(&db, user.id, Uuid::new_v4(), None)
|
||||
.await
|
||||
.map_err(|(code, msg)| (StatusCode::from_u16(code.as_u16()).unwrap(), msg))?;
|
||||
|
||||
Ok(Json(json!({
|
||||
"access_token": token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": expires_in,
|
||||
"refresh_token": refresh_token,
|
||||
"user": user
|
||||
})))
|
||||
}
|
||||
|
||||
async fn fetch_user_profile(provider: &str, token: &str) -> Result<UserProfile, String> {
|
||||
let client = ReqwestClient::new();
|
||||
match provider {
|
||||
"google" => {
|
||||
let resp = client.get("https://www.googleapis.com/oauth2/v2/userinfo")
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?
|
||||
.json::<Value>()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let email = resp["email"].as_str().ok_or("No email found")?.to_string();
|
||||
let name = resp["name"].as_str().map(|s| s.to_string());
|
||||
let avatar_url = resp["picture"].as_str().map(|s| s.to_string());
|
||||
let provider_id = resp["id"].as_str().ok_or("No ID found")?.to_string();
|
||||
|
||||
Ok(UserProfile {
|
||||
email,
|
||||
name,
|
||||
avatar_url,
|
||||
provider_id,
|
||||
})
|
||||
},
|
||||
"github" => {
|
||||
let resp = client.get("https://api.github.com/user")
|
||||
.bearer_auth(token)
|
||||
.header("User-Agent", "madbase")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?
|
||||
.json::<Value>()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let email = if let Some(e) = resp["email"].as_str() {
|
||||
e.to_string()
|
||||
} else {
|
||||
// Fetch private emails
|
||||
let emails = client.get("https://api.github.com/user/emails")
|
||||
.bearer_auth(token)
|
||||
.header("User-Agent", "madbase")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?
|
||||
.json::<Vec<Value>>()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let primary = emails.iter().find(|e| e["primary"].as_bool().unwrap_or(false))
|
||||
.ok_or("No primary email found")?;
|
||||
|
||||
primary["email"].as_str().ok_or("No email found")?.to_string()
|
||||
};
|
||||
|
||||
let name = resp["name"].as_str().map(|s| s.to_string());
|
||||
let avatar_url = resp["avatar_url"].as_str().map(|s| s.to_string());
|
||||
let provider_id = resp["id"].as_i64().map(|id| id.to_string()).ok_or("No ID found")?.to_string();
|
||||
|
||||
Ok(UserProfile {
|
||||
email,
|
||||
name,
|
||||
avatar_url,
|
||||
provider_id,
|
||||
})
|
||||
},
|
||||
_ => Err("Unknown provider".to_string())
|
||||
}
|
||||
}
|
||||
118
auth/src/utils.rs
Normal file
118
auth/src/utils.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use argon2::{
|
||||
password_hash::{
|
||||
rand_core::{OsRng, RngCore},
|
||||
PasswordHash, PasswordHasher, PasswordVerifier, SaltString,
|
||||
},
|
||||
Argon2,
|
||||
};
|
||||
use chrono::{Duration, Utc};
|
||||
use jsonwebtoken::{encode, EncodingKey, Header};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Claims {
|
||||
pub sub: String,
|
||||
pub email: Option<String>,
|
||||
pub role: String,
|
||||
pub exp: usize,
|
||||
pub iss: String,
|
||||
pub aud: Option<String>,
|
||||
pub iat: usize,
|
||||
}
|
||||
|
||||
pub fn hash_password(password: &str) -> anyhow::Result<String> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
let password_hash = argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.map_err(|e| anyhow::anyhow!(e))?
|
||||
.to_string();
|
||||
Ok(password_hash)
|
||||
}
|
||||
|
||||
pub fn verify_password(password: &str, password_hash: &str) -> anyhow::Result<bool> {
|
||||
let parsed_hash = PasswordHash::new(password_hash).map_err(|e| anyhow::anyhow!(e))?;
|
||||
Ok(Argon2::default()
|
||||
.verify_password(password.as_bytes(), &parsed_hash)
|
||||
.is_ok())
|
||||
}
|
||||
|
||||
pub fn generate_refresh_token() -> String {
|
||||
let mut bytes = [0u8; 32];
|
||||
OsRng.fill_bytes(&mut bytes);
|
||||
hex_encode(&bytes)
|
||||
}
|
||||
|
||||
pub fn hash_refresh_token(raw: &str) -> String {
|
||||
let digest = Sha256::digest(raw.as_bytes());
|
||||
hex_encode(&digest)
|
||||
}
|
||||
|
||||
pub fn generate_token(
|
||||
user_id: Uuid,
|
||||
email: &str,
|
||||
role: &str,
|
||||
jwt_secret: &str,
|
||||
) -> anyhow::Result<(String, i64)> {
|
||||
let now = Utc::now();
|
||||
let expiration = now
|
||||
.checked_add_signed(Duration::seconds(3600)) // 1 hour
|
||||
.expect("valid timestamp")
|
||||
.timestamp();
|
||||
|
||||
let claims = Claims {
|
||||
sub: user_id.to_string(),
|
||||
email: Some(email.to_string()),
|
||||
role: role.to_string(),
|
||||
exp: expiration as usize,
|
||||
iss: "madbase".to_string(),
|
||||
aud: Some("authenticated".to_string()),
|
||||
iat: now.timestamp() as usize,
|
||||
};
|
||||
|
||||
let token = encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(jwt_secret.as_bytes()),
|
||||
)?;
|
||||
|
||||
Ok((token, 3600))
|
||||
}
|
||||
|
||||
fn hex_encode(bytes: &[u8]) -> String {
|
||||
let mut out = String::with_capacity(bytes.len() * 2);
|
||||
for b in bytes {
|
||||
use std::fmt::Write;
|
||||
let _ = write!(&mut out, "{:02x}", b);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub async fn issue_refresh_token(
|
||||
executor: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
|
||||
user_id: Uuid,
|
||||
session_id: Uuid,
|
||||
parent: Option<&str>,
|
||||
) -> Result<String, (axum::http::StatusCode, String)> {
|
||||
let token = generate_refresh_token();
|
||||
let token_hash = hash_refresh_token(&token);
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO refresh_tokens (token, user_id, session_id, parent)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
"#,
|
||||
)
|
||||
.bind(&token_hash)
|
||||
.bind(user_id)
|
||||
.bind(session_id)
|
||||
.bind(parent)
|
||||
.execute(executor)
|
||||
.await
|
||||
.map_err(|e| (axum::http::StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user