use crate::error::ApiError; use sqlx::{PgPool, Postgres, Transaction}; const ALLOWED_ROLES: &[&str] = &["anon", "authenticated", "service_role"]; pub struct RlsTransaction { pub tx: Transaction<'static, Postgres>, } impl RlsTransaction { /// Begin a transaction with RLS context set. /// `role` must be one of: anon, authenticated, service_role. /// `sub` is the JWT subject claim (user ID), used for RLS policies. pub async fn begin( pool: &PgPool, role: &str, sub: Option<&str>, ) -> Result { let mut tx = pool.begin().await?; // Validate and set role if !ALLOWED_ROLES.contains(&role) { return Err(ApiError::Forbidden("Invalid role".into())); } let role_query = format!("SET LOCAL role = '{}'", role); sqlx::query(&role_query).execute(&mut *tx).await?; // Set JWT claims for RLS policies if let Some(sub) = sub { sqlx::query("SELECT set_config('request.jwt.claim.sub', $1, true)") .bind(sub) .execute(&mut *tx) .await?; } Ok(Self { tx }) } pub async fn commit(self) -> Result<(), ApiError> { self.tx.commit().await.map_err(ApiError::from) } } impl std::ops::Deref for RlsTransaction { type Target = Transaction<'static, Postgres>; fn deref(&self) -> &Self::Target { &self.tx } } impl std::ops::DerefMut for RlsTransaction { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.tx } } #[cfg(test)] mod tests { use super::*; #[test] fn test_rls_transaction_rejects_bad_role() { // Verify role validation without needing a DB connection assert!(ALLOWED_ROLES.contains(&"anon")); assert!(ALLOWED_ROLES.contains(&"authenticated")); assert!(ALLOWED_ROLES.contains(&"service_role")); assert!(!ALLOWED_ROLES.contains(&"admin")); assert!(!ALLOWED_ROLES.contains(&"superuser")); assert!(!ALLOWED_ROLES.contains(&"'; DROP TABLE users; --")); } #[tokio::test] #[ignore] // Requires running PostgreSQL — run with: cargo test -- --ignored async fn test_rls_transaction_sets_role() { let pool = PgPool::connect("postgres://postgres:postgres@localhost:5432/postgres") .await .expect("DB connection required"); let mut rls = RlsTransaction::begin(&pool, "authenticated", Some("user-123")).await.unwrap(); let row: (String,) = sqlx::query_as("SELECT current_setting('role')") .fetch_one(&mut *rls.tx) .await .unwrap(); assert_eq!(row.0, "authenticated"); } #[tokio::test] #[ignore] // Requires running PostgreSQL — run with: cargo test -- --ignored async fn test_rls_transaction_sets_claims() { let pool = PgPool::connect("postgres://postgres:postgres@localhost:5432/postgres") .await .expect("DB connection required"); let mut rls = RlsTransaction::begin(&pool, "authenticated", Some("user-abc-123")).await.unwrap(); let row: (String,) = sqlx::query_as("SELECT current_setting('request.jwt.claim.sub')") .fetch_one(&mut *rls.tx) .await .unwrap(); assert_eq!(row.0, "user-abc-123"); } }