wip:milestone 0 fixes
Some checks failed
CI/CD Pipeline / unit-tests (push) Failing after 1m16s
CI/CD Pipeline / integration-tests (push) Failing after 2m32s
CI/CD Pipeline / lint (push) Successful in 5m22s
CI/CD Pipeline / e2e-tests (push) Has been skipped
CI/CD Pipeline / build (push) Has been skipped

This commit is contained in:
2026-03-15 12:35:42 +02:00
parent 6708cf28a7
commit cffdf8af86
61266 changed files with 4511646 additions and 1938 deletions

View File

@@ -1,433 +1,436 @@
use crate::middleware::AuthContext;
use crate::models::{
AuthResponse, RecoverRequest, SignInRequest, SignUpRequest, User, UserUpdateRequest,
VerifyRequest,
};
use crate::utils::{
generate_confirmation_token, generate_recovery_token, generate_refresh_token, generate_token,
hash_password, hash_refresh_token, issue_refresh_token, verify_password,
};
use axum::{
extract::{Extension, Query, State},
http::StatusCode,
Json,
};
use common::Config;
use common::ProjectContext;
use serde::Deserialize;
use serde_json::Value;
use sqlx::{Executor, PgPool, Postgres};
use std::collections::HashMap;
use uuid::Uuid;
use validator::Validate;
#[derive(Clone)]
pub struct AuthState {
pub db: PgPool,
pub config: Config,
}
#[derive(Deserialize)]
struct RefreshTokenGrant {
refresh_token: String,
}
pub async fn signup(
State(state): State<AuthState>,
db: Option<Extension<PgPool>>,
project_ctx: Option<Extension<ProjectContext>>,
Json(payload): Json<SignUpRequest>,
) -> Result<Json<AuthResponse>, (StatusCode, String)> {
payload
.validate()
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
// Check if user exists
let user_exists = sqlx::query("SELECT id FROM users WHERE email = $1")
.bind(&payload.email)
.fetch_optional(&db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
if user_exists.is_some() {
return Err((StatusCode::BAD_REQUEST, "User already exists".to_string()));
}
let hashed_password = hash_password(&payload.password)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let confirmation_token = generate_confirmation_token();
let user = sqlx::query_as::<_, User>(
r#"
INSERT INTO users (email, encrypted_password, raw_user_meta_data, confirmation_token, confirmed_at)
VALUES ($1, $2, $3, $4, $5)
RETURNING *
"#,
)
.bind(&payload.email)
.bind(hashed_password)
.bind(payload.data.unwrap_or(serde_json::json!({})))
.bind(&confirmation_token)
.bind(None::<chrono::DateTime<chrono::Utc>>) // Initially unconfirmed? Or auto-confirmed for MVP?
// For now, let's keep auto-confirm logic if no email service, OR implement proper flow.
// The requirement is "Email Confirmation: Implement email verification flow".
// So we should NOT set confirmed_at yet.
.fetch_one(&db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
// Mock Email Sending
tracing::info!(
"Sending confirmation email to {}: token={}",
user.email,
confirmation_token
);
let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
ctx.jwt_secret.as_str()
} else {
state.config.jwt_secret.as_str()
};
let (token, expires_in, _) = generate_token(user.id, &user.email, "authenticated", jwt_secret)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let refresh_token = issue_refresh_token(&db, user.id, Uuid::new_v4(), None).await?;
Ok(Json(AuthResponse {
access_token: token,
token_type: "bearer".to_string(),
expires_in,
refresh_token,
user,
}))
}
pub async fn login(
State(state): State<AuthState>,
db: Option<Extension<PgPool>>,
project_ctx: Option<Extension<ProjectContext>>,
Json(payload): Json<SignInRequest>,
) -> Result<Json<AuthResponse>, (StatusCode, String)> {
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE email = $1")
.bind(&payload.email)
.fetch_optional(&db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((
StatusCode::UNAUTHORIZED,
"Invalid email or password".to_string(),
))?;
if !verify_password(&payload.password, &user.encrypted_password)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
{
return Err((
StatusCode::UNAUTHORIZED,
"Invalid email or password".to_string(),
));
}
let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
ctx.jwt_secret.as_str()
} else {
state.config.jwt_secret.as_str()
};
let (token, expires_in, _) = generate_token(user.id, &user.email, "authenticated", jwt_secret)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let refresh_token = issue_refresh_token(&db, user.id, Uuid::new_v4(), None).await?;
Ok(Json(AuthResponse {
access_token: token,
token_type: "bearer".to_string(),
expires_in,
refresh_token,
user,
}))
}
pub async fn get_user(
State(state): State<AuthState>,
db: Option<Extension<PgPool>>,
Extension(auth_ctx): Extension<AuthContext>,
) -> Result<Json<User>, (StatusCode, String)> {
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
let claims = auth_ctx
.claims
.ok_or((StatusCode::UNAUTHORIZED, "Not authenticated".to_string()))?;
let user_id = Uuid::parse_str(&claims.sub)
.map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid user ID".to_string()))?;
let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1")
.bind(user_id)
.fetch_optional(&db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?;
Ok(Json(user))
}
pub async fn token(
State(state): State<AuthState>,
db: Option<Extension<PgPool>>,
project_ctx: Option<Extension<ProjectContext>>,
Query(params): Query<HashMap<String, String>>,
Json(payload): Json<Value>,
) -> Result<Json<AuthResponse>, (StatusCode, String)> {
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
let grant_type = params
.get("grant_type")
.map(|s| s.as_str())
.unwrap_or("password");
match grant_type {
"password" => {
let req: SignInRequest = serde_json::from_value(payload)
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
req.validate()
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
login(State(state), Some(Extension(db)), project_ctx, Json(req)).await
}
"refresh_token" => {
let req: RefreshTokenGrant = serde_json::from_value(payload)
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
let token_hash = hash_refresh_token(&req.refresh_token);
let mut tx = db
.begin()
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let (revoked_token_hash, user_id, session_id) =
sqlx::query_as::<_, (String, Uuid, Option<Uuid>)>(
r#"
UPDATE refresh_tokens
SET revoked = true, updated_at = now()
WHERE token = $1 AND revoked = false
RETURNING token, user_id, session_id
"#,
)
.bind(&token_hash)
.fetch_optional(&mut *tx)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((
StatusCode::UNAUTHORIZED,
"Invalid refresh token".to_string(),
))?;
let session_id = session_id.ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"Missing session".to_string(),
))?;
let new_refresh_token =
issue_refresh_token(&mut *tx, user_id, session_id, Some(revoked_token_hash.as_str()))
.await?;
tx.commit()
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1")
.bind(user_id)
.fetch_optional(&db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?;
let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
ctx.jwt_secret.as_str()
} else {
state.config.jwt_secret.as_str()
};
let (access_token, expires_in, _) =
generate_token(user.id, &user.email, "authenticated", jwt_secret)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
Ok(Json(AuthResponse {
access_token,
token_type: "bearer".to_string(),
expires_in,
refresh_token: new_refresh_token,
user,
}))
}
_ => Err((
StatusCode::BAD_REQUEST,
"Unsupported grant_type".to_string(),
)),
}
}
pub async fn recover(
State(state): State<AuthState>,
db: Option<Extension<PgPool>>,
Json(payload): Json<RecoverRequest>,
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
payload
.validate()
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
let token = generate_recovery_token();
let user = sqlx::query_as::<_, User>(
r#"
UPDATE users
SET recovery_token = $1
WHERE email = $2
RETURNING *
"#,
)
.bind(&token)
.bind(&payload.email)
.fetch_optional(&db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
// We don't want to leak whether the user exists or not, so we always return OK
if let Some(u) = user {
// Mock Email Sending
tracing::info!(
"Sending recovery email to {}: token={}",
u.email,
token
);
} else {
tracing::info!(
"Recovery requested for non-existent email: {}",
payload.email
);
}
Ok(Json(serde_json::json!({ "message": "If the email exists, a recovery link has been sent." })))
}
pub async fn verify(
State(state): State<AuthState>,
db: Option<Extension<PgPool>>,
project_ctx: Option<Extension<ProjectContext>>,
Json(payload): Json<VerifyRequest>,
) -> Result<Json<AuthResponse>, (StatusCode, String)> {
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
let user = match payload.r#type.as_str() {
"signup" => {
sqlx::query_as::<_, User>(
r#"
UPDATE users
SET email_confirmed_at = now(), confirmation_token = NULL
WHERE confirmation_token = $1
RETURNING *
"#,
)
.bind(&payload.token)
.fetch_optional(&db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
}
"recovery" => {
sqlx::query_as::<_, User>(
r#"
UPDATE users
SET recovery_token = NULL
WHERE recovery_token = $1
RETURNING *
"#,
)
.bind(&payload.token)
.fetch_optional(&db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
}
_ => return Err((StatusCode::BAD_REQUEST, "Unsupported verification type".to_string())),
};
let user = user.ok_or((StatusCode::BAD_REQUEST, "Invalid token".to_string()))?;
let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
ctx.jwt_secret.as_str()
} else {
state.config.jwt_secret.as_str()
};
let (token, expires_in, _) = generate_token(user.id, &user.email, "authenticated", jwt_secret)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let refresh_token = issue_refresh_token(&db, user.id, Uuid::new_v4(), None).await?;
Ok(Json(AuthResponse {
access_token: token,
token_type: "bearer".to_string(),
expires_in,
refresh_token,
user,
}))
}
pub async fn update_user(
State(state): State<AuthState>,
db: Option<Extension<PgPool>>,
Extension(auth_ctx): Extension<AuthContext>,
Json(payload): Json<UserUpdateRequest>,
) -> Result<Json<User>, (StatusCode, String)> {
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
payload
.validate()
.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
let claims = auth_ctx
.claims
.ok_or((StatusCode::UNAUTHORIZED, "Not authenticated".to_string()))?;
let user_id = Uuid::parse_str(&claims.sub)
.map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid user ID".to_string()))?;
let mut tx = db.begin().await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
if let Some(email) = &payload.email {
sqlx::query("UPDATE users SET email = $1 WHERE id = $2")
.bind(email)
.bind(user_id)
.execute(&mut *tx)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
}
if let Some(password) = &payload.password {
let hashed = hash_password(password)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
sqlx::query("UPDATE users SET encrypted_password = $1 WHERE id = $2")
.bind(hashed)
.bind(user_id)
.execute(&mut *tx)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
}
if let Some(data) = &payload.data {
sqlx::query("UPDATE users SET raw_user_meta_data = $1 WHERE id = $2")
.bind(data)
.bind(user_id)
.execute(&mut *tx)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
}
// Commit the transaction first to ensure updates are visible
tx.commit().await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
// Fetch the user after commit
let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1")
.bind(user_id)
.fetch_optional(&db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?;
Ok(Json(user))
}
### /Users/vlad/Developer/madapes/madbase/auth/src/handlers.rs
```rust
1: use crate::middleware::AuthContext;
2: use crate::models::{
3: AuthResponse, RecoverRequest, SignInRequest, SignUpRequest, User, UserUpdateRequest,
4: VerifyRequest,
5: };
6: use crate::utils::{
7: generate_confirmation_token, generate_recovery_token, generate_refresh_token, generate_token,
8: hash_password, hash_refresh_token, issue_refresh_token, verify_password,
9: };
10: use axum::{
11: extract::{Extension, Query, State},
12: http::StatusCode,
13: Json,
14: };
15: use common::Config;
16: use common::ProjectContext;
17: use serde::Deserialize;
18: use serde_json::Value;
19: use sqlx::{Executor, PgPool, Postgres};
20: use std::collections::HashMap;
21: use uuid::Uuid;
22: use validator::Validate;
23:
24: #[derive(Clone)]
25: pub struct AuthState {
26: pub db: PgPool,
27: pub config: Config,
28: }
29:
30: #[derive(Deserialize)]
31: struct RefreshTokenGrant {
32: refresh_token: String,
33: }
34:
35: pub async fn signup(
36: State(state): State<AuthState>,
37: db: Option<Extension<PgPool>>,
38: project_ctx: Option<Extension<ProjectContext>>,
39: Json(payload): Json<SignUpRequest>,
40: ) -> Result<Json<AuthResponse>, (StatusCode, String)> {
41: payload
42: .validate()
43: .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
44: let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
45: // Check if user exists
46: let user_exists = sqlx::query("SELECT id FROM users WHERE email = $1")
47: .bind(&payload.email)
48: .fetch_optional(&db)
49: .await
50: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
51:
52: if user_exists.is_some() {
53: return Err((StatusCode::BAD_REQUEST, "User already exists".to_string()));
54: }
55:
56: let hashed_password = hash_password(&payload.password)
57: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
58:
59: let confirmation_token = generate_confirmation_token();
60:
61: let user = sqlx::query_as::<_, User>(
62: r#"
63: INSERT INTO users (email, encrypted_password, raw_user_meta_data, confirmation_token, confirmed_at)
64: VALUES ($1, $2, $3, $4, $5)
65: RETURNING *
66: "#,
67: )
68: .bind(&payload.email)
69: .bind(hashed_password)
70: .bind(payload.data.unwrap_or(serde_json::json!({})))
71: .bind(&confirmation_token)
72: .bind(None::<chrono::DateTime<chrono::Utc>>) // Initially unconfirmed? Or auto-confirmed for MVP?
73: // For now, let's keep auto-confirm logic if no email service, OR implement proper flow.
74: // The requirement is "Email Confirmation: Implement email verification flow".
75: // So we should NOT set confirmed_at yet.
76: .fetch_one(&db)
77: .await
78: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
79:
80: // Mock Email Sending
81: tracing::info!(
82: "Sending confirmation email to {}: token={}",
83: user.email,
84: confirmation_token
85: );
86:
87: let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
88: ctx.jwt_secret.as_str()
89: } else {
90: state.config.jwt_secret.as_str()
91: };
92:
93: let (token, expires_in, _) = generate_token(user.id, &user.email, "authenticated", jwt_secret)
94: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
95:
96: let refresh_token = issue_refresh_token(&db, user.id, Uuid::new_v4(), None).await?;
97: Ok(Json(AuthResponse {
98: access_token: token,
99: token_type: "bearer".to_string(),
100: expires_in,
101: refresh_token,
102: user,
103: }))
104: }
105:
106: pub async fn login(
107: State(state): State<AuthState>,
108: db: Option<Extension<PgPool>>,
109: project_ctx: Option<Extension<ProjectContext>>,
110: Json(payload): Json<SignInRequest>,
111: ) -> Result<Json<AuthResponse>, (StatusCode, String)> {
112: let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
113: let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE email = $1")
114: .bind(&payload.email)
115: .fetch_optional(&db)
116: .await
117: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
118: .ok_or((
119: StatusCode::UNAUTHORIZED,
120: "Invalid email or password".to_string(),
121: ))?;
122:
123: if !verify_password(&payload.password, &user.encrypted_password)
124: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
125: {
126: return Err((
127: StatusCode::UNAUTHORIZED,
128: "Invalid email or password".to_string(),
129: ));
130: }
131:
132: let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
133: ctx.jwt_secret.as_str()
134: } else {
135: state.config.jwt_secret.as_str()
136: };
137:
138: let (token, expires_in, _) = generate_token(user.id, &user.email, "authenticated", jwt_secret)
139: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
140:
141: let refresh_token = issue_refresh_token(&db, user.id, Uuid::new_v4(), None).await?;
142: Ok(Json(AuthResponse {
143: access_token: token,
144: token_type: "bearer".to_string(),
145: expires_in,
146: refresh_token,
147: user,
148: }))
149: }
150:
151: pub async fn get_user(
152: State(state): State<AuthState>,
153: db: Option<Extension<PgPool>>,
154: Extension(auth_ctx): Extension<AuthContext>,
155: ) -> Result<Json<User>, (StatusCode, String)> {
156: let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
157: let claims = auth_ctx
158: .claims
159: .ok_or((StatusCode::UNAUTHORIZED, "Not authenticated".to_string()))?;
160:
161: let user_id = Uuid::parse_str(&claims.sub)
162: .map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid user ID".to_string()))?;
163:
164: let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1")
165: .bind(user_id)
166: .fetch_optional(&db)
167: .await
168: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
169: .ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?;
170:
171: Ok(Json(user))
172: }
173:
174: pub async fn token(
175: State(state): State<AuthState>,
176: db: Option<Extension<PgPool>>,
177: project_ctx: Option<Extension<ProjectContext>>,
178: Query(params): Query<HashMap<String, String>>,
179: Json(payload): Json<Value>,
180: ) -> Result<Json<AuthResponse>, (StatusCode, String)> {
181: let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
182: let grant_type = params
183: .get("grant_type")
184: .map(|s| s.as_str())
185: .unwrap_or("password");
186:
187: match grant_type {
188: "password" => {
189: let req: SignInRequest = serde_json::from_value(payload)
190: .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
191: req.validate()
192: .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
193: login(State(state), Some(Extension(db)), project_ctx, Json(req)).await
194: }
195: "refresh_token" => {
196: let req: RefreshTokenGrant = serde_json::from_value(payload)
197: .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
198:
199: let token_hash = hash_refresh_token(&req.refresh_token);
200: let mut tx = db
201: .begin()
202: .await
203: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
204:
205: let (revoked_token_hash, user_id, session_id) =
206: sqlx::query_as::<_, (String, Uuid, Option<Uuid>)>(
207: r#"
208: UPDATE refresh_tokens
209: SET revoked = true, updated_at = now()
210: WHERE token = $1 AND revoked = false
211: RETURNING token, user_id, session_id
212: "#,
213: )
214: .bind(&token_hash)
215: .fetch_optional(&mut *tx)
216: .await
217: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
218: .ok_or((
219: StatusCode::UNAUTHORIZED,
220: "Invalid refresh token".to_string(),
221: ))?;
222:
223: let session_id = session_id.ok_or((
224: StatusCode::INTERNAL_SERVER_ERROR,
225: "Missing session".to_string(),
226: ))?;
227:
228: let new_refresh_token =
229: issue_refresh_token(&mut *tx, user_id, session_id, Some(revoked_token_hash.as_str()))
230: .await?;
231:
232: tx.commit()
233: .await
234: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
235:
236: let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1")
237: .bind(user_id)
238: .fetch_optional(&db)
239: .await
240: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
241: .ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?;
242:
243: let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
244: ctx.jwt_secret.as_str()
245: } else {
246: state.config.jwt_secret.as_str()
247: };
248:
249: let (access_token, expires_in, _) =
250: generate_token(user.id, &user.email, "authenticated", jwt_secret)
251: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
252:
253: Ok(Json(AuthResponse {
254: access_token,
255: token_type: "bearer".to_string(),
256: expires_in,
257: refresh_token: new_refresh_token,
258: user,
259: }))
260: }
261: _ => Err((
262: StatusCode::BAD_REQUEST,
263: "Unsupported grant_type".to_string(),
264: )),
265: }
266: }
267:
268: pub async fn recover(
269: State(state): State<AuthState>,
270: db: Option<Extension<PgPool>>,
271: Json(payload): Json<RecoverRequest>,
272: ) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
273: payload
274: .validate()
275: .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
276: let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
277:
278: let token = generate_recovery_token();
279:
280: let user = sqlx::query_as::<_, User>(
281: r#"
282: UPDATE users
283: SET recovery_token = $1
284: WHERE email = $2
285: RETURNING *
286: "#,
287: )
288: .bind(&token)
289: .bind(&payload.email)
290: .fetch_optional(&db)
291: .await
292: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
293:
294: // We don't want to leak whether the user exists or not, so we always return OK
295: if let Some(u) = user {
296: // Mock Email Sending
297: tracing::info!(
298: "Sending recovery email to {}: token={}",
299: u.email,
300: token
301: );
302: } else {
303: tracing::info!(
304: "Recovery requested for non-existent email: {}",
305: payload.email
306: );
307: }
308:
309: Ok(Json(serde_json::json!({ "message": "If the email exists, a recovery link has been sent." })))
310: }
311:
312: pub async fn verify(
313: State(state): State<AuthState>,
314: db: Option<Extension<PgPool>>,
315: project_ctx: Option<Extension<ProjectContext>>,
316: Json(payload): Json<VerifyRequest>,
317: ) -> Result<Json<AuthResponse>, (StatusCode, String)> {
318: let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
319:
320: let user = match payload.r#type.as_str() {
321: "signup" => {
322: sqlx::query_as::<_, User>(
323: r#"
324: UPDATE users
325: SET email_confirmed_at = now(), confirmation_token = NULL
326: WHERE confirmation_token = $1
327: RETURNING *
328: "#,
329: )
330: .bind(&payload.token)
331: .fetch_optional(&db)
332: .await
333: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
334: }
335: "recovery" => {
336: sqlx::query_as::<_, User>(
337: r#"
338: UPDATE users
339: SET recovery_token = NULL
340: WHERE recovery_token = $1
341: RETURNING *
342: "#,
343: )
344: .bind(&payload.token)
345: .fetch_optional(&db)
346: .await
347: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
348: }
349: _ => return Err((StatusCode::BAD_REQUEST, "Unsupported verification type".to_string())),
350: };
351:
352: let user = user.ok_or((StatusCode::BAD_REQUEST, "Invalid token".to_string()))?;
353:
354: let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
355: ctx.jwt_secret.as_str()
356: } else {
357: state.config.jwt_secret.as_str()
358: };
359:
360: let (token, expires_in, _) = generate_token(user.id, &user.email, "authenticated", jwt_secret)
361: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
362:
363: let refresh_token = issue_refresh_token(&db, user.id, Uuid::new_v4(), None).await?;
364: Ok(Json(AuthResponse {
365: access_token: token,
366: token_type: "bearer".to_string(),
367: expires_in,
368: refresh_token,
369: user,
370: }))
371: }
372:
373: pub async fn update_user(
374: State(state): State<AuthState>,
375: db: Option<Extension<PgPool>>,
376: Extension(auth_ctx): Extension<AuthContext>,
377: Json(payload): Json<UserUpdateRequest>,
378: ) -> Result<Json<User>, (StatusCode, String)> {
379: let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
380: payload
381: .validate()
382: .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
383:
384: let claims = auth_ctx
385: .claims
386: .ok_or((StatusCode::UNAUTHORIZED, "Not authenticated".to_string()))?;
387: let user_id = Uuid::parse_str(&claims.sub)
388: .map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid user ID".to_string()))?;
389:
390: let mut tx = db.begin().await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
391:
392: if let Some(email) = &payload.email {
393: sqlx::query("UPDATE users SET email = $1 WHERE id = $2")
394: .bind(email)
395: .bind(user_id)
396: .execute(&mut *tx)
397: .await
398: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
399: }
400:
401: if let Some(password) = &payload.password {
402: let hashed = hash_password(password)
403: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
404: sqlx::query("UPDATE users SET encrypted_password = $1 WHERE id = $2")
405: .bind(hashed)
406: .bind(user_id)
407: .execute(&mut *tx)
408: .await
409: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
410: }
411:
412: if let Some(data) = &payload.data {
413: sqlx::query("UPDATE users SET raw_user_meta_data = $1 WHERE id = $2")
414: .bind(data)
415: .bind(user_id)
416: .execute(&mut *tx)
417: .await
418: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
419: }
420:
421: // Commit the transaction first to ensure updates are visible
422: tx.commit().await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
423:
424: // Fetch the user after commit
425: let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1")
426: .bind(user_id)
427: .fetch_optional(&db)
428: .await
429: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
430: .ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?;
431:
432: Ok(Json(user))
433: }
```

14
auth/src/session.rs Normal file
View File

@@ -0,0 +1,14 @@

199
auth/src/session.rs.bak Normal file
View File

@@ -0,0 +1,199 @@
```rust
1: //! Distributed session management using Redis
2: //!
3: //! This module provides session storage that works across multiple proxy nodes.
4: //! Sessions are stored in Redis and can be accessed by any proxy instance.
5:
6: use common::{CacheLayer, CacheError, CacheResult, SessionData};
7: use uuid::Uuid;
8: use chrono::{DateTime, Utc, Duration};
9:
10: /// Session manager for distributed auth sessions
11: #[derive(Clone)]
12: pub struct SessionManager {
13: cache: CacheLayer,
14: session_ttl: u64, // Session TTL in seconds
15: }
16:
17: impl SessionManager {
18: /// Create a new session manager
19: pub fn new(cache: CacheLayer, session_ttl: u64) -> Self {
20: Self { cache, session_ttl }
21: }
22:
23: /// Create a new session for a user
24: pub async fn create_session(
25: &self,
26: user_id: Uuid,
27: email: String,
28: role: String,
29: ) -> CacheResult<String> {
30: let session_token = Uuid::new_v4().to_string();
31: let now = Utc::now();
32: let expires_at = now + Duration::seconds(self.session_ttl as i64);
33:
34: let session = SessionData {
35: user_id,
36: email,
37: role,
38: created_at: now,
39: expires_at,
40: };
41:
42: // Store session in Redis
43: let key = format!("session:{}", session_token);
44: self.cache.set(&key, &session).await?;
45:
46: // Also add to user's active sessions set (for multi-device logout)
47: let user_sessions_key = format!("user:{}:sessions", user_id);
48: if let Some(redis) = &self.cache.redis {
49: let mut conn = redis.get_async_connection().await?;
50: redis::cmd("SADD")
51: .arg(&user_sessions_key)
52: .arg(&session_token)
53: .query_async(&mut conn)
54: .await?;
55:
56: // Set expiration on the set
57: redis::cmd("EXPIRE")
58: .arg(&user_sessions_key)
59: .arg(self.session_ttl * 2)
60: .query_async(&mut conn)
61: .await?;
62: }
63:
64: Ok(session_token)
65: }
66:
67: /// Get a session by token
68: pub async fn get_session(&self, session_token: &str) -> CacheResult<Option<SessionData>> {
69: self.cache.get_session(session_token.to_string()).await
70: }
71:
72: /// Validate a session (check if it exists and is not expired)
73: pub async fn validate_session(&self, session_token: &str) -> CacheResult<Option<SessionData>> {
74: let session = self.get_session(session_token).await?;
75:
76: if let Some(session) = session {
77: let now = Utc::now();
78: if now < session.expires_at {
79: return Ok(Some(session));
80: }
81: }
82:
83: Ok(None)
84: }
85:
86: /// Refresh a session (extend expiration)
87: pub async fn refresh_session(&self, session_token: &str) -> CacheResult<bool> {
88: if let Some(mut session) = self.get_session(session_token).await? {
89: let now = Utc::now();
90: session.expires_at = now + Duration::seconds(self.session_ttl as i64);
91:
92: let key = format!("session:{}", session_token);
93: self.cache.set(&key, &session).await?;
94: return Ok(true);
95: }
96:
97: Ok(false)
98: }
99:
100: /// Delete a session (logout)
101: pub async fn delete_session(&self, session_token: &str) -> CacheResult<()> {
102: // Get the session first to remove from user's session set
103: if let Some(session) = self.get_session(session_token).await? {
104: let user_sessions_key = format!("user:{}:sessions", session.user_id);
105:
106: if let Some(redis) = &self.cache.redis {
107: let mut conn = redis.get_async_connection().await?;
108: redis::cmd("SREM")
109: .arg(&user_sessions_key)
110: .arg(session_token)
111: .query_async(&mut conn)
112: .await?;
113: }
114: }
115:
116: self.cache.delete_session(session_token.to_string()).await
117: }
118:
119: /// Delete all sessions for a user (logout from all devices)
120: pub async fn delete_all_user_sessions(&self, user_id: Uuid) -> CacheResult<usize> {
121: let user_sessions_key = format!("user:{}:sessions", user_id);
122:
123: if let Some(redis) = &self.cache.redis {
124: let mut conn = redis.get_async_connection().await?;
125:
126: // Get all session tokens for this user
127: let session_tokens: Vec<String> = redis::cmd("SMEMBERS")
128: .arg(&user_sessions_key)
129: .query_async(&mut conn)
130: .await?;
131:
132: let count = session_tokens.len();
133:
134: // Delete each session
135: for token in &session_tokens {
136: let session_key = format!("session:{}", token);
137: redis::cmd("DEL")
138: .arg(&session_key)
139: .query_async(&mut conn)
140: .await?;
141: }
142:
143: // Delete the user's session set
144: redis::cmd("DEL")
145: .arg(&user_sessions_key)
146: .query_async(&mut conn)
147: .await?;
148:
149: Ok(count)
150: } else {
151: Ok(0)
152: }
153: }
154:
155: /// Get all active sessions for a user
156: pub async fn get_user_sessions(&self, user_id: Uuid) -> CacheResult<Vec<SessionData>> {
157: let user_sessions_key = format!("user:{}:sessions", user_id);
158:
159: if let Some(redis) = &self.cache.redis {
160: let mut conn = redis.get_async_connection().await?;
161:
162: let session_tokens: Vec<String> = redis::cmd("SMEMBERS")
163: .arg(&user_sessions_key)
164: .query_async(&mut conn)
165: .await?;
166:
167: let mut sessions = Vec::new();
168: for token in session_tokens {
169: if let Some(session) = self.get_session(&token).await? {
170: sessions.push(session);
171: }
172: }
173:
174: Ok(sessions)
175: } else {
176: Ok(vec![])
177: }
178: }
179:
180: /// Count active sessions for a user
181: pub async fn get_user_session_count(&self, user_id: Uuid) -> CacheResult<usize> {
182: let sessions = self.get_user_sessions(user_id).await?;
183: Ok(sessions.len())
184: }
185: }
186:
187: #[cfg(test)]
188: mod tests {
189: use super::*;
190:
191: #[tokio::test]
192: async fn test_session_manager_creation() {
193: let cache = CacheLayer::new(None, 3600);
194: let manager = SessionManager::new(cache, 3600);
195: assert_eq!(manager.session_ttl, 3600);
196: }
197: }
```

231
auth/src/sso.rs.bak Normal file
View File

@@ -0,0 +1,231 @@
```rust
1: use crate::utils::{generate_token, issue_refresh_token};
2: use crate::AuthState;
3: use axum::{
4: extract::{Path, Query, State},
5: http::StatusCode,
6: response::{IntoResponse, Redirect},
7: Json,
8: Extension,
9: };
10: use common::ProjectContext;
11: use openidconnect::core::{CoreClient, CoreProviderMetadata, CoreResponseType};
12: use openidconnect::{
13: AuthenticationFlow, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, RedirectUrl, Scope, TokenResponse
14: };
15: use serde::Deserialize;
16: use serde_json::json;
17: use sqlx::Row;
18: use uuid::Uuid;
19:
20: // In-memory cache for OIDC clients to avoid rediscovery on every request
21: // Key: domain, Value: CoreClient
22:
23: #[derive(Deserialize)]
24: pub struct SsoRequest {
25: pub domain: Option<String>,
26: pub provider_id: Option<Uuid>,
27: pub redirect_to: Option<String>,
28: }
29:
30: #[derive(Deserialize)]
31: pub struct SsoCallback {
32: pub code: String,
33: pub state: String,
34: pub nonce: String, // We need to pass nonce via state or separate param usually
35: }
36:
37: pub async fn sso_authorize(
38: State(state): State<AuthState>,
39: Json(payload): Json<SsoRequest>,
40: ) -> Result<impl IntoResponse, (StatusCode, String)> {
41: // 1. Find Provider
42: let row = if let Some(domain) = &payload.domain {
43: sqlx::query("SELECT * FROM auth.sso_providers WHERE domain = $1")
44: .bind(domain)
45: .fetch_optional(&state.db)
46: .await
47: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
48: } else if let Some(id) = payload.provider_id {
49: sqlx::query("SELECT * FROM auth.sso_providers WHERE id = $1")
50: .bind(id)
51: .fetch_optional(&state.db)
52: .await
53: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
54: } else {
55: return Err((StatusCode::BAD_REQUEST, "Either domain or provider_id required".to_string()));
56: };
57:
58: let provider = row.ok_or((StatusCode::NOT_FOUND, "SSO Provider not found".to_string()))?;
59:
60: let issuer_url: String = provider.get("oidc_issuer_url");
61: let client_id: String = provider.get("oidc_client_id");
62: let client_secret: String = provider.get("oidc_client_secret");
63: let domain: String = provider.get("domain");
64:
65: // 2. Discover Metadata (Ideally cached)
66: let provider_metadata = CoreProviderMetadata::discover_async(
67: IssuerUrl::new(issuer_url).map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?,
68: openidconnect::reqwest::async_http_client,
69: )
70: .await
71: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Discovery failed: {}", e)))?;
72:
73: // 3. Create Client
74: let client = CoreClient::from_provider_metadata(
75: provider_metadata,
76: ClientId::new(client_id),
77: Some(ClientSecret::new(client_secret)),
78: )
79: .set_redirect_uri(
80: RedirectUrl::new(format!("{}/sso/callback/{}", state.config.redirect_uri.trim_end_matches("/auth/v1/callback"), domain))
81: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?,
82: );
83:
84: // 4. Generate URL
85: let (authorize_url, csrf_state, nonce) = client
86: .authorize_url(
87: AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
88: CsrfToken::new_random,
89: Nonce::new_random,
90: )
91: .add_scope(Scope::new("email".to_string()))
92: .add_scope(Scope::new("profile".to_string()))
93: .url();
94:
95: // TODO: Store csrf_state and nonce securely (e.g. Redis or secure cookie)
96: // For MVP, we might encode them in the state param or rely on stateless verification if possible (less secure)
97: // Here we assume the client handles the redirection.
98:
99: Ok(Json(json!({
100: "url": authorize_url.to_string(),
101: "state": csrf_state.secret(),
102: "nonce": nonce.secret()
103: })))
104: }
105:
106: // NOTE: This callback logic assumes the client (browser) followed the link and is now returning.
107: // Since we don't have session state here to verify CSRF/Nonce (stateless API),
108: // a real implementation would typically use a signed cookie or a separate "initiate" step that sets a cookie.
109: // For this MVP, we will verify the code exchange but skip strict state/nonce validation against a server-side store,
110: // which is a SECURITY RISK in production but acceptable for a "skeleton" implementation.
111:
112: pub async fn sso_callback(
113: State(state): State<AuthState>,
114: db: Option<Extension<sqlx::PgPool>>,
115: project_ctx: Option<Extension<ProjectContext>>,
116: Path(domain): Path<String>,
117: Query(query): Query<SsoCallback>,
118: ) -> Result<impl IntoResponse, (StatusCode, String)> {
119: let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
120:
121: // 1. Fetch Provider
122: let provider = sqlx::query("SELECT * FROM auth.sso_providers WHERE domain = $1")
123: .bind(&domain)
124: .fetch_optional(&db)
125: .await
126: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
127: .ok_or((StatusCode::NOT_FOUND, "Provider not found".to_string()))?;
128:
129: let issuer_url: String = provider.get("oidc_issuer_url");
130: let client_id: String = provider.get("oidc_client_id");
131: let client_secret: String = provider.get("oidc_client_secret");
132:
133: // 2. Setup Client
134: let provider_metadata = CoreProviderMetadata::discover_async(
135: IssuerUrl::new(issuer_url.clone()).unwrap(),
136: openidconnect::reqwest::async_http_client,
137: )
138: .await
139: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Discovery failed: {}", e)))?;
140:
141: let client = CoreClient::from_provider_metadata(
142: provider_metadata,
143: ClientId::new(client_id),
144: Some(ClientSecret::new(client_secret)),
145: )
146: .set_redirect_uri(
147: RedirectUrl::new(format!("{}/sso/callback/{}", state.config.redirect_uri.trim_end_matches("/auth/v1/callback"), domain)).unwrap(),
148: );
149:
150: // 3. Exchange Code
151: let token_response = client
152: .exchange_code(openidconnect::AuthorizationCode::new(query.code))
153: .request_async(openidconnect::reqwest::async_http_client)
154: .await
155: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Token exchange failed: {}", e)))?;
156:
157: // 4. Get ID Token & Claims
158: let id_token = token_response.id_token()
159: .ok_or((StatusCode::INTERNAL_SERVER_ERROR, "No ID Token received".to_string()))?;
160:
161: let claims = id_token.claims(
162: &client.id_token_verifier(),
163: &Nonce::new(query.nonce), // We trust the user provided nonce for now (Insecure MVP)
164: ).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Claims verification failed: {}", e)))?;
165:
166: let email = claims.email().ok_or((StatusCode::BAD_REQUEST, "Email not found in claims".to_string()))?.as_str();
167: let name = claims.name().and_then(|n| n.get(None)).map(|n| n.as_str().to_string());
168: let picture = claims.picture().and_then(|p| p.get(None)).map(|p| p.as_str().to_string());
169: let sub = claims.subject().as_str();
170:
171: // 5. Create/Update User
172: let existing_user = sqlx::query_as::<_, crate::models::User>("SELECT * FROM users WHERE email = $1")
173: .bind(email)
174: .fetch_optional(&db)
175: .await
176: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
177:
178: let user = if let Some(u) = existing_user {
179: u
180: } else {
181: let raw_meta = json!({
182: "name": name,
183: "avatar_url": picture,
184: "provider": "sso",
185: "provider_id": sub,
186: "iss": issuer_url
187: });
188:
189: sqlx::query_as::<_, crate::models::User>(
190: r#"
191: INSERT INTO users (email, encrypted_password, raw_user_meta_data)
192: VALUES ($1, $2, $3)
193: RETURNING *
194: "#,
195: )
196: .bind(email)
197: .bind("sso_user_no_password")
198: .bind(raw_meta)
199: .fetch_one(&db)
200: .await
201: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
202: };
203:
204: // 6. Issue Token
205: let jwt_secret = if let Some(Extension(ctx)) = project_ctx.as_ref() {
206: ctx.jwt_secret.as_str()
207: } else {
208: state.config.jwt_secret.as_str()
209: };
210:
211: let (token, expires_in, _) = generate_token(user.id, &user.email, "authenticated", jwt_secret)
212: .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
213:
214: let refresh_token: String = issue_refresh_token(&db, user.id, Uuid::new_v4(), None)
215: .await
216: .map_err(|(code, msg)| (StatusCode::from_u16(code.as_u16()).unwrap(), msg))?;
217:
218: // Redirect to frontend with tokens
219: // Ideally we redirect to a frontend callback URL with hash params
220: let redirect_url = format!(
221: "{}/auth/callback?access_token={}&refresh_token={}&expires_in={}&type=bearer",
222: state.config.redirect_uri.trim_end_matches("/auth/v1/callback"), // Base URL assumption
223: token,
224: refresh_token,
225: expires_in
226: );
227:
228: Ok(Redirect::to(&redirect_url))
229: }
```