added initial roadmap and implementation
This commit is contained in:
893
data_api/src/handlers.rs
Normal file
893
data_api/src/handlers.rs
Normal file
@@ -0,0 +1,893 @@
|
||||
use crate::parser::{Operator, QueryParams, SelectNode, FilterNode};
|
||||
use auth::AuthContext;
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Json},
|
||||
Extension,
|
||||
};
|
||||
use common::Config;
|
||||
use futures::future::BoxFuture;
|
||||
use serde_json::{json, Value};
|
||||
use sqlx::{Column, PgPool, Row, TypeInfo};
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DataState {
|
||||
pub db: PgPool,
|
||||
pub config: Config,
|
||||
}
|
||||
|
||||
enum SqlValue {
|
||||
String(String),
|
||||
Int(i64),
|
||||
Float(f64),
|
||||
Bool(bool),
|
||||
Uuid(Uuid),
|
||||
Json(Value),
|
||||
Null,
|
||||
}
|
||||
|
||||
fn json_value_to_sql_value(v: Value) -> SqlValue {
|
||||
match v {
|
||||
Value::String(s) => {
|
||||
if let Ok(u) = Uuid::parse_str(&s) {
|
||||
SqlValue::Uuid(u)
|
||||
} else {
|
||||
SqlValue::String(s)
|
||||
}
|
||||
},
|
||||
Value::Number(n) => {
|
||||
if let Some(i) = n.as_i64() {
|
||||
SqlValue::Int(i)
|
||||
} else if let Some(f) = n.as_f64() {
|
||||
SqlValue::Float(f)
|
||||
} else {
|
||||
SqlValue::String(n.to_string())
|
||||
}
|
||||
},
|
||||
Value::Bool(b) => SqlValue::Bool(b),
|
||||
Value::Object(_) | Value::Array(_) => SqlValue::Json(v),
|
||||
Value::Null => SqlValue::Null,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_rows(
|
||||
State(state): State<DataState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
Extension(auth_ctx): Extension<AuthContext>,
|
||||
Path(table): Path<String>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
let query_params = QueryParams::parse(params);
|
||||
|
||||
if !is_valid_identifier(&table) {
|
||||
return Err((StatusCode::BAD_REQUEST, "Invalid table name".to_string()));
|
||||
}
|
||||
|
||||
// Start transaction for RLS
|
||||
let mut tx = db
|
||||
.begin()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
// Set RLS variables
|
||||
let role_query = format!("SET LOCAL role = '{}'", auth_ctx.role);
|
||||
sqlx::query(&role_query)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set role: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(claims) = &auth_ctx.claims {
|
||||
let sub_query = "SELECT set_config('request.jwt.claim.sub', $1, true)";
|
||||
sqlx::query(sub_query)
|
||||
.bind(&claims.sub)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set claims: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(email) = &claims.email {
|
||||
let email_query = "SELECT set_config('request.jwt.claim.email', $1, true)";
|
||||
sqlx::query(email_query)
|
||||
.bind(email)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set claims: {}", e),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Construct Query ---
|
||||
// Use pool for schema introspection to avoid borrowing tx
|
||||
let select_clause = build_select_clause(&query_params.select, &table, &db).await?;
|
||||
|
||||
let mut sql = format!("SELECT {} FROM {}", select_clause, table);
|
||||
let mut values: Vec<SqlValue> = Vec::new();
|
||||
let mut param_index = 1;
|
||||
|
||||
if !query_params.filters.is_empty() {
|
||||
sql.push_str(" WHERE ");
|
||||
let conditions: Vec<String> = query_params
|
||||
.filters
|
||||
.iter()
|
||||
.map(|f| build_filter_clause(f, &mut param_index, &mut values))
|
||||
.collect();
|
||||
sql.push_str(&conditions.join(" AND "));
|
||||
}
|
||||
|
||||
if let Some(order) = query_params.order {
|
||||
if is_valid_identifier(&order.column) {
|
||||
let dir = match order.direction {
|
||||
crate::parser::Direction::Asc => "ASC",
|
||||
crate::parser::Direction::Desc => "DESC",
|
||||
};
|
||||
sql.push_str(&format!(" ORDER BY {} {}", order.column, dir));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(limit) = query_params.limit {
|
||||
sql.push_str(&format!(" LIMIT {}", limit));
|
||||
}
|
||||
|
||||
if let Some(offset) = query_params.offset {
|
||||
sql.push_str(&format!(" OFFSET {}", offset));
|
||||
}
|
||||
|
||||
let mut query = sqlx::query(&sql);
|
||||
for v in values {
|
||||
match v {
|
||||
SqlValue::String(s) => query = query.bind(s),
|
||||
SqlValue::Int(n) => query = query.bind(n),
|
||||
SqlValue::Float(f) => query = query.bind(f),
|
||||
SqlValue::Bool(b) => query = query.bind(b),
|
||||
SqlValue::Uuid(u) => query = query.bind(u),
|
||||
SqlValue::Json(j) => query = query.bind(j),
|
||||
SqlValue::Null => query = query.bind(Option::<String>::None),
|
||||
};
|
||||
}
|
||||
|
||||
let rows = query
|
||||
.fetch_all(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
tx.commit()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let json_rows = rows_to_json(rows);
|
||||
Ok(Json(json_rows))
|
||||
}
|
||||
|
||||
fn build_filter_clause(
|
||||
node: &FilterNode,
|
||||
param_index: &mut usize,
|
||||
values: &mut Vec<SqlValue>,
|
||||
) -> String {
|
||||
match node {
|
||||
FilterNode::Condition { column, operator, value } => {
|
||||
if !is_valid_identifier(column) {
|
||||
return "false".to_string();
|
||||
}
|
||||
let clause = match operator {
|
||||
Operator::In => {
|
||||
format!("{} {} (${})", column, operator.to_sql(), param_index)
|
||||
}
|
||||
_ => format!("{} {} ${}", column, operator.to_sql(), param_index),
|
||||
};
|
||||
|
||||
let val = if let Ok(i) = value.parse::<i64>() {
|
||||
SqlValue::Int(i)
|
||||
} else if let Ok(f) = value.parse::<f64>() {
|
||||
SqlValue::Float(f)
|
||||
} else if let Ok(b) = value.parse::<bool>() {
|
||||
SqlValue::Bool(b)
|
||||
} else if let Ok(u) = Uuid::parse_str(value) {
|
||||
SqlValue::Uuid(u)
|
||||
} else {
|
||||
SqlValue::String(value.clone())
|
||||
};
|
||||
|
||||
values.push(val);
|
||||
*param_index += 1;
|
||||
clause
|
||||
}
|
||||
FilterNode::Or(nodes) => {
|
||||
let clauses: Vec<String> = nodes
|
||||
.iter()
|
||||
.map(|n| build_filter_clause(n, param_index, values))
|
||||
.collect();
|
||||
if clauses.is_empty() {
|
||||
"false".to_string()
|
||||
} else {
|
||||
format!("({})", clauses.join(" OR "))
|
||||
}
|
||||
}
|
||||
FilterNode::And(nodes) => {
|
||||
let clauses: Vec<String> = nodes
|
||||
.iter()
|
||||
.map(|n| build_filter_clause(n, param_index, values))
|
||||
.collect();
|
||||
if clauses.is_empty() {
|
||||
"true".to_string()
|
||||
} else {
|
||||
format!("({})", clauses.join(" AND "))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn build_select_clause<'a>(
|
||||
nodes: &'a [SelectNode],
|
||||
table: &'a str,
|
||||
pool: &'a PgPool,
|
||||
) -> BoxFuture<'a, Result<String, (StatusCode, String)>> {
|
||||
Box::pin(async move {
|
||||
if nodes.is_empty() {
|
||||
return Ok("*".to_string());
|
||||
}
|
||||
|
||||
let mut clauses = Vec::new();
|
||||
for node in nodes {
|
||||
match node {
|
||||
SelectNode::Column(c) => {
|
||||
if c == "*" {
|
||||
clauses.push("*".to_string());
|
||||
} else if is_valid_identifier(c) {
|
||||
clauses.push(format!("\"{}\"", c));
|
||||
}
|
||||
}
|
||||
SelectNode::Relation(rel, inner) => {
|
||||
let fk_info = find_foreign_key(table, rel, pool)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||
|
||||
if let Some((local_col, foreign_table, foreign_col)) = fk_info {
|
||||
let inner_select = if inner.is_empty() {
|
||||
"*".to_string()
|
||||
} else {
|
||||
build_select_clause(inner, &foreign_table, pool).await?
|
||||
};
|
||||
|
||||
let subquery = if foreign_col.starts_with("REV:") {
|
||||
let actual_foreign_col = &foreign_col[4..];
|
||||
format!(
|
||||
"(SELECT json_agg(t) FROM (SELECT {} FROM {} WHERE {} = {}.{}) t) as \"{}\"",
|
||||
inner_select, foreign_table, actual_foreign_col, table, local_col, rel
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"(SELECT row_to_json(t) FROM (SELECT {} FROM {} WHERE {} = {}.{}) t) as \"{}\"",
|
||||
inner_select, foreign_table, foreign_col, table, local_col, rel
|
||||
)
|
||||
};
|
||||
clauses.push(subquery);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if clauses.is_empty() {
|
||||
return Err((StatusCode::BAD_REQUEST, "No valid columns selected".to_string()));
|
||||
}
|
||||
|
||||
Ok(clauses.join(", "))
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
async fn find_foreign_key(
|
||||
table: &str,
|
||||
relation: &str,
|
||||
pool: &PgPool,
|
||||
) -> Result<Option<(String, String, String)>, String> {
|
||||
// Basic introspection to find FK.
|
||||
// We look for a table named `relation` or a column named `relation_id`.
|
||||
// PostgREST logic is complex, here's a simplified version:
|
||||
// 1. Check if `relation` is a table name.
|
||||
// 2. Find FK between `table` and `relation`.
|
||||
|
||||
let query = r#"
|
||||
SELECT
|
||||
kcu.column_name as local_col,
|
||||
ccu.table_name as foreign_table,
|
||||
ccu.column_name as foreign_col
|
||||
FROM
|
||||
information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_name = $1
|
||||
AND ccu.table_name = $2;
|
||||
"#;
|
||||
|
||||
let row = sqlx::query_as::<_, (String, String, String)>(query)
|
||||
.bind(table)
|
||||
.bind(relation)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
if let Some(r) = row {
|
||||
return Ok(Some(r));
|
||||
}
|
||||
|
||||
// Try reverse (many-to-one): relation table has FK to our table
|
||||
let reverse_query = r#"
|
||||
SELECT
|
||||
ccu.column_name as local_col,
|
||||
tc.table_name as foreign_table,
|
||||
kcu.column_name as foreign_col
|
||||
FROM
|
||||
information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_name = $2
|
||||
AND ccu.table_name = $1;
|
||||
"#;
|
||||
|
||||
let row = sqlx::query_as::<_, (String, String, String)>(reverse_query)
|
||||
.bind(table)
|
||||
.bind(relation)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
if let Some(r) = row {
|
||||
// For reverse relations (one-to-many), we want to aggregate them.
|
||||
// Returning a tuple that signifies reverse relation might be tricky with the same signature.
|
||||
// Let's hack it: return foreign_col as "REV:foreign_col".
|
||||
return Ok(Some((r.0, r.1, format!("REV:{}", r.2))));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
|
||||
fn rows_to_json(rows: Vec<sqlx::postgres::PgRow>) -> Vec<Value> {
|
||||
let mut json_rows = Vec::new();
|
||||
for row in rows {
|
||||
let mut obj = serde_json::Map::new();
|
||||
for col in row.columns() {
|
||||
let name = col.name();
|
||||
let type_info = col.type_info();
|
||||
let type_name = type_info.name();
|
||||
|
||||
tracing::info!("Column: {}, Type: {}", name, type_name);
|
||||
|
||||
let val: Value = if type_name == "BOOL" {
|
||||
json!(row.try_get::<bool, _>(name).unwrap_or(false))
|
||||
} else if type_name == "INT2" {
|
||||
json!(row.try_get::<i16, _>(name).unwrap_or(0))
|
||||
} else if type_name == "INT4" {
|
||||
json!(row.try_get::<i32, _>(name).unwrap_or(0))
|
||||
} else if type_name == "INT8" {
|
||||
json!(row.try_get::<i64, _>(name).unwrap_or(0))
|
||||
} else if ["FLOAT4", "FLOAT8"].contains(&type_name) {
|
||||
json!(row.try_get::<f64, _>(name).unwrap_or(0.0))
|
||||
} else if ["JSON", "JSONB"].contains(&type_name) {
|
||||
row.try_get::<Value, _>(name).unwrap_or(Value::Null)
|
||||
} else if type_name == "UUID" {
|
||||
if let Ok(u) = row.try_get::<Uuid, _>(name) {
|
||||
json!(u.to_string())
|
||||
} else {
|
||||
Value::Null
|
||||
}
|
||||
} else if type_name == "TIMESTAMPTZ" {
|
||||
if let Ok(ts) = row.try_get::<chrono::DateTime<chrono::Utc>, _>(name) {
|
||||
json!(ts)
|
||||
} else {
|
||||
Value::Null
|
||||
}
|
||||
} else if type_name == "TIMESTAMP" {
|
||||
if let Ok(ts) = row.try_get::<chrono::NaiveDateTime, _>(name) {
|
||||
json!(ts.to_string())
|
||||
} else {
|
||||
Value::Null
|
||||
}
|
||||
} else {
|
||||
// Fallback for types that can't be directly read as String
|
||||
match row.try_get::<String, _>(name) {
|
||||
Ok(s) => json!(s),
|
||||
Err(_) => match row.try_get::<Value, _>(name) {
|
||||
Ok(v) => v,
|
||||
Err(_) => Value::Null,
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
obj.insert(name.to_string(), val);
|
||||
}
|
||||
json_rows.push(Value::Object(obj));
|
||||
}
|
||||
json_rows
|
||||
}
|
||||
|
||||
pub async fn insert_row(
|
||||
State(state): State<DataState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
Extension(auth_ctx): Extension<AuthContext>,
|
||||
Path(table): Path<String>,
|
||||
Json(payload): Json<Value>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
if !is_valid_identifier(&table) {
|
||||
return Err((StatusCode::BAD_REQUEST, "Invalid table name".to_string()));
|
||||
}
|
||||
|
||||
// Start transaction for RLS
|
||||
let mut tx = db
|
||||
.begin()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
// Set RLS variables
|
||||
let role_query = format!("SET LOCAL role = '{}'", auth_ctx.role);
|
||||
sqlx::query(&role_query)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set role: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(claims) = &auth_ctx.claims {
|
||||
let sub_query = "SELECT set_config('request.jwt.claim.sub', $1, true)";
|
||||
sqlx::query(sub_query)
|
||||
.bind(&claims.sub)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set claims: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(email) = &claims.email {
|
||||
let email_query = "SELECT set_config('request.jwt.claim.email', $1, true)";
|
||||
sqlx::query(email_query)
|
||||
.bind(email)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set claims: {}", e),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
let rows_to_insert = match payload {
|
||||
Value::Array(arr) => arr,
|
||||
Value::Object(obj) => vec![Value::Object(obj)],
|
||||
_ => return Err((StatusCode::BAD_REQUEST, "Payload must be a JSON object or array".to_string())),
|
||||
};
|
||||
|
||||
if rows_to_insert.is_empty() {
|
||||
return Err((StatusCode::BAD_REQUEST, "Payload empty".to_string()));
|
||||
}
|
||||
|
||||
// Use keys from the first row as the columns
|
||||
let first_row = rows_to_insert[0].as_object().ok_or((StatusCode::BAD_REQUEST, "Rows must be objects".to_string()))?;
|
||||
let columns: Vec<String> = first_row.keys().cloned().collect();
|
||||
|
||||
if columns.is_empty() {
|
||||
return Err((StatusCode::BAD_REQUEST, "No columns to insert".to_string()));
|
||||
}
|
||||
|
||||
let col_str = columns
|
||||
.iter()
|
||||
.map(|c| format!("\"{}\"", c))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
let mut values_sql = Vec::new();
|
||||
let mut bind_values: Vec<SqlValue> = Vec::new();
|
||||
let mut param_index = 1;
|
||||
|
||||
for row in rows_to_insert {
|
||||
let obj = row.as_object().ok_or((StatusCode::BAD_REQUEST, "Rows must be objects".to_string()))?;
|
||||
let mut row_placeholders = Vec::new();
|
||||
|
||||
for col in &columns {
|
||||
row_placeholders.push(format!("${}", param_index));
|
||||
param_index += 1;
|
||||
|
||||
// Get value or Null
|
||||
let val = obj.get(col).cloned().unwrap_or(Value::Null);
|
||||
bind_values.push(json_value_to_sql_value(val));
|
||||
}
|
||||
values_sql.push(format!("({})", row_placeholders.join(", ")));
|
||||
}
|
||||
|
||||
let sql = format!(
|
||||
"INSERT INTO {} ({}) VALUES {} RETURNING *",
|
||||
table, col_str, values_sql.join(", ")
|
||||
);
|
||||
|
||||
let mut query = sqlx::query(&sql);
|
||||
|
||||
for v in bind_values {
|
||||
match v {
|
||||
SqlValue::String(s) => query = query.bind(s),
|
||||
SqlValue::Int(n) => query = query.bind(n),
|
||||
SqlValue::Float(f) => query = query.bind(f),
|
||||
SqlValue::Bool(b) => query = query.bind(b),
|
||||
SqlValue::Uuid(u) => query = query.bind(u),
|
||||
SqlValue::Json(j) => query = query.bind(j),
|
||||
SqlValue::Null => query = query.bind(Option::<String>::None),
|
||||
};
|
||||
}
|
||||
|
||||
let rows = query
|
||||
.fetch_all(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
tx.commit()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let json_rows = rows_to_json(rows);
|
||||
Ok((StatusCode::CREATED, Json(json_rows)))
|
||||
}
|
||||
|
||||
|
||||
pub async fn delete_rows(
|
||||
State(state): State<DataState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
Extension(auth_ctx): Extension<AuthContext>,
|
||||
Path(table): Path<String>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
let query_params = QueryParams::parse(params);
|
||||
|
||||
if !is_valid_identifier(&table) {
|
||||
return Err((StatusCode::BAD_REQUEST, "Invalid table name".to_string()));
|
||||
}
|
||||
|
||||
let mut tx = db
|
||||
.begin()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let role_query = format!("SET LOCAL role = '{}'", auth_ctx.role);
|
||||
sqlx::query(&role_query)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set role: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(claims) = &auth_ctx.claims {
|
||||
let sub_query = "SELECT set_config('request.jwt.claim.sub', $1, true)";
|
||||
sqlx::query(sub_query)
|
||||
.bind(&claims.sub)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set claims: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(email) = &claims.email {
|
||||
let email_query = "SELECT set_config('request.jwt.claim.email', $1, true)";
|
||||
sqlx::query(email_query)
|
||||
.bind(email)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set claims: {}", e),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
let mut sql = format!("DELETE FROM {}", table);
|
||||
let mut values: Vec<SqlValue> = Vec::new();
|
||||
let mut param_index = 1;
|
||||
|
||||
if !query_params.filters.is_empty() {
|
||||
sql.push_str(" WHERE ");
|
||||
let conditions: Vec<String> = query_params
|
||||
.filters
|
||||
.iter()
|
||||
.map(|f| build_filter_clause(f, &mut param_index, &mut values))
|
||||
.collect();
|
||||
sql.push_str(&conditions.join(" AND "));
|
||||
}
|
||||
|
||||
let mut query = sqlx::query(&sql);
|
||||
for v in values {
|
||||
match v {
|
||||
SqlValue::String(s) => query = query.bind(s),
|
||||
SqlValue::Int(n) => query = query.bind(n),
|
||||
SqlValue::Float(f) => query = query.bind(f),
|
||||
SqlValue::Bool(b) => query = query.bind(b),
|
||||
SqlValue::Uuid(u) => query = query.bind(u),
|
||||
SqlValue::Json(j) => query = query.bind(j),
|
||||
SqlValue::Null => query = query.bind(Option::<String>::None),
|
||||
};
|
||||
}
|
||||
|
||||
query
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Delete Rows error: SQL={}, Error={:?}", sql, e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string())
|
||||
})?;
|
||||
|
||||
tx.commit()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
pub async fn update_rows(
|
||||
State(state): State<DataState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
Extension(auth_ctx): Extension<AuthContext>,
|
||||
Path(table): Path<String>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
Json(payload): Json<Value>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
if !is_valid_identifier(&table) {
|
||||
return Err((StatusCode::BAD_REQUEST, "Invalid table name".to_string()));
|
||||
}
|
||||
|
||||
let query_params = QueryParams::parse(params);
|
||||
|
||||
let mut tx = db
|
||||
.begin()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let role_query = format!("SET LOCAL role = '{}'", auth_ctx.role);
|
||||
sqlx::query(&role_query)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set role: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(claims) = &auth_ctx.claims {
|
||||
let sub_query = "SELECT set_config('request.jwt.claim.sub', $1, true)";
|
||||
sqlx::query(sub_query)
|
||||
.bind(&claims.sub)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set claims: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(email) = &claims.email {
|
||||
let email_query = "SELECT set_config('request.jwt.claim.email', $1, true)";
|
||||
sqlx::query(email_query)
|
||||
.bind(email)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set claims: {}", e),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
let obj = payload.as_object().ok_or((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Payload must be a JSON object".to_string(),
|
||||
))?;
|
||||
if obj.is_empty() {
|
||||
return Err((StatusCode::BAD_REQUEST, "Payload empty".to_string()));
|
||||
}
|
||||
|
||||
let mut final_sql = format!("UPDATE {} SET ", table);
|
||||
let mut final_values: Vec<SqlValue> = Vec::new();
|
||||
let mut p_idx = 1;
|
||||
|
||||
let mut sets = Vec::new();
|
||||
for (k, v) in obj {
|
||||
sets.push(format!("\"{}\" = ${}", k, p_idx));
|
||||
final_values.push(json_value_to_sql_value(v.clone()));
|
||||
p_idx += 1;
|
||||
}
|
||||
final_sql.push_str(&sets.join(", "));
|
||||
|
||||
if !query_params.filters.is_empty() {
|
||||
final_sql.push_str(" WHERE ");
|
||||
let mut conds = Vec::new();
|
||||
|
||||
for f in &query_params.filters {
|
||||
conds.push(build_filter_clause(f, &mut p_idx, &mut final_values));
|
||||
}
|
||||
final_sql.push_str(&conds.join(" AND "));
|
||||
}
|
||||
|
||||
let mut query = sqlx::query(&final_sql);
|
||||
|
||||
for v in final_values {
|
||||
match v {
|
||||
SqlValue::String(s) => query = query.bind(s),
|
||||
SqlValue::Int(n) => query = query.bind(n),
|
||||
SqlValue::Float(f) => query = query.bind(f),
|
||||
SqlValue::Bool(b) => query = query.bind(b),
|
||||
SqlValue::Uuid(u) => query = query.bind(u),
|
||||
SqlValue::Json(j) => query = query.bind(j),
|
||||
SqlValue::Null => query = query.bind(Option::<String>::None),
|
||||
};
|
||||
}
|
||||
|
||||
query
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Update Rows error: SQL={}, Error={:?}", final_sql, e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string())
|
||||
})?;
|
||||
|
||||
tx.commit()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
pub async fn rpc(
|
||||
State(state): State<DataState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
Extension(auth_ctx): Extension<AuthContext>,
|
||||
Path(function): Path<String>,
|
||||
Json(payload): Json<Value>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
if !is_valid_identifier(&function) {
|
||||
return Err((StatusCode::BAD_REQUEST, "Invalid function name".to_string()));
|
||||
}
|
||||
|
||||
let mut tx = db
|
||||
.begin()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let role_query = format!("SET LOCAL role = '{}'", auth_ctx.role);
|
||||
sqlx::query(&role_query)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set role: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(claims) = &auth_ctx.claims {
|
||||
let sub_query = "SELECT set_config('request.jwt.claim.sub', $1, true)";
|
||||
sqlx::query(sub_query)
|
||||
.bind(&claims.sub)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set claims: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(email) = &claims.email {
|
||||
let email_query = "SELECT set_config('request.jwt.claim.email', $1, true)";
|
||||
sqlx::query(email_query)
|
||||
.bind(email)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to set claims: {}", e),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
let obj = payload.as_object().ok_or((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Payload must be a JSON object".to_string(),
|
||||
))?;
|
||||
|
||||
let mut args = Vec::new();
|
||||
let mut values: Vec<SqlValue> = Vec::new();
|
||||
let mut p_idx = 1;
|
||||
|
||||
for (k, v) in obj {
|
||||
if !is_valid_identifier(k) {
|
||||
return Err((StatusCode::BAD_REQUEST, "Invalid argument name".to_string()));
|
||||
}
|
||||
args.push(format!("{} => ${}", k, p_idx));
|
||||
values.push(json_value_to_sql_value(v.clone()));
|
||||
p_idx += 1;
|
||||
}
|
||||
|
||||
let sql = if args.is_empty() {
|
||||
format!("SELECT * FROM {}()", function)
|
||||
} else {
|
||||
format!("SELECT * FROM {}({})", function, args.join(", "))
|
||||
};
|
||||
|
||||
let mut query = sqlx::query(&sql);
|
||||
|
||||
for v in values {
|
||||
match v {
|
||||
SqlValue::String(s) => query = query.bind(s),
|
||||
SqlValue::Int(n) => query = query.bind(n),
|
||||
SqlValue::Float(f) => query = query.bind(f),
|
||||
SqlValue::Bool(b) => query = query.bind(b),
|
||||
SqlValue::Uuid(u) => query = query.bind(u),
|
||||
SqlValue::Json(j) => query = query.bind(j),
|
||||
SqlValue::Null => query = query.bind(Option::<String>::None),
|
||||
};
|
||||
}
|
||||
|
||||
let rows = query
|
||||
.fetch_all(&mut *tx)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
tx.commit()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let json_rows = rows_to_json(rows);
|
||||
Ok(Json(json_rows))
|
||||
}
|
||||
|
||||
fn is_valid_identifier(s: &str) -> bool {
|
||||
s.chars().all(|c| c.is_alphanumeric() || c == '_') && !s.is_empty()
|
||||
}
|
||||
Reference in New Issue
Block a user