chore: full stack stability and migration fixes, plus react UI progress
This commit is contained in:
@@ -2,10 +2,12 @@ use crate::parser::{Operator, QueryParams, SelectNode, FilterNode};
|
||||
use auth::AuthContext;
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Json},
|
||||
Extension,
|
||||
};
|
||||
use crate::schema_cache::{SchemaCache, ForeignKeyInfo};
|
||||
use std::sync::Arc;
|
||||
use common::Config;
|
||||
use futures::future::BoxFuture;
|
||||
use serde_json::{json, Value};
|
||||
@@ -13,10 +15,14 @@ use sqlx::{Column, PgPool, Row, TypeInfo};
|
||||
use std::collections::HashMap;
|
||||
use uuid::Uuid;
|
||||
|
||||
type SelectClauseFuture<'a> = BoxFuture<'a, Result<(String, Vec<String>), (StatusCode, String)>>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DataState {
|
||||
pub db: PgPool,
|
||||
pub replica_pool: Option<PgPool>,
|
||||
pub config: Config,
|
||||
pub cache: Arc<SchemaCache>,
|
||||
}
|
||||
|
||||
const ALLOWED_ROLES: &[&str] = &["anon", "authenticated", "service_role"];
|
||||
@@ -65,13 +71,45 @@ fn json_value_to_sql_value(v: Value) -> SqlValue {
|
||||
|
||||
pub async fn get_rows(
|
||||
State(state): State<DataState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
Extension(db): Extension<PgPool>,
|
||||
headers: HeaderMap,
|
||||
Extension(auth_ctx): Extension<AuthContext>,
|
||||
Path(table): Path<String>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
let query_params = QueryParams::parse(params);
|
||||
let mut query_params = QueryParams::parse(params);
|
||||
|
||||
// Parse Range header: Range: items=0-9
|
||||
let range = headers.get("Range")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| {
|
||||
let s = s.strip_prefix("items=").unwrap_or(s);
|
||||
let parts: Vec<&str> = s.split('-').collect();
|
||||
if parts.len() == 2 {
|
||||
let start = parts[0].parse::<usize>().ok()?;
|
||||
let end = parts[1].parse::<usize>().ok()?;
|
||||
Some((start, end))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
if let Some((start, end)) = range {
|
||||
query_params.offset = Some(start);
|
||||
query_params.limit = Some(end - start + 1);
|
||||
}
|
||||
|
||||
// Parse Prefer header for count
|
||||
let want_count = headers.get("Prefer")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.contains("count=exact"))
|
||||
.unwrap_or(false);
|
||||
|
||||
// Parse Accept header for single object
|
||||
let want_single = headers.get("Accept")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.contains("vnd.pgrst.object+json"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_valid_identifier(&table) {
|
||||
return Err((StatusCode::BAD_REQUEST, "Invalid table name".to_string()));
|
||||
@@ -83,6 +121,14 @@ pub async fn get_rows(
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
// Handle Schema selection
|
||||
if let Some(profile) = headers.get("Accept-Profile").and_then(|v| v.to_str().ok()) {
|
||||
if is_valid_identifier(profile) {
|
||||
let schema_query = format!("SET LOCAL search_path TO {}, public", profile);
|
||||
sqlx::query(&schema_query).execute(&mut *tx).await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
}
|
||||
}
|
||||
|
||||
// Set RLS variables
|
||||
validate_role(&auth_ctx.role)?;
|
||||
let role_query = format!("SET LOCAL role = '{}'", auth_ctx.role);
|
||||
@@ -126,19 +172,22 @@ pub async fn get_rows(
|
||||
|
||||
// --- Construct Query ---
|
||||
// Use pool for schema introspection to avoid borrowing tx
|
||||
let select_clause = build_select_clause(&query_params.select, &table, &db).await?;
|
||||
let (select_clause, extra_filters) = build_select_clause(&query_params.select, &table, &db, state.cache.clone()).await?;
|
||||
|
||||
let mut sql = format!("SELECT {} FROM {}", select_clause, table);
|
||||
let mut values: Vec<SqlValue> = Vec::new();
|
||||
let mut param_index = 1;
|
||||
|
||||
if !query_params.filters.is_empty() {
|
||||
let all_filters = &query_params.filters;
|
||||
|
||||
if !all_filters.is_empty() || !extra_filters.is_empty() {
|
||||
sql.push_str(" WHERE ");
|
||||
let conditions: Vec<String> = query_params
|
||||
.filters
|
||||
let mut conditions: Vec<String> = all_filters
|
||||
.iter()
|
||||
.map(|f| build_filter_clause(f, &mut param_index, &mut values))
|
||||
.collect();
|
||||
|
||||
conditions.extend(extra_filters);
|
||||
sql.push_str(&conditions.join(" AND "));
|
||||
}
|
||||
|
||||
@@ -183,7 +232,70 @@ pub async fn get_rows(
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let json_rows = rows_to_json(rows);
|
||||
Ok(Json(json_rows))
|
||||
let row_count = json_rows.len();
|
||||
|
||||
let mut total_count = None;
|
||||
if want_count {
|
||||
let mut count_sql = format!("SELECT COUNT(*) FROM {}", table);
|
||||
let mut count_values: Vec<SqlValue> = Vec::new();
|
||||
let mut count_param_index = 1;
|
||||
|
||||
if !query_params.filters.is_empty() {
|
||||
count_sql.push_str(" WHERE ");
|
||||
let conditions: Vec<String> = query_params
|
||||
.filters
|
||||
.iter()
|
||||
.map(|f| build_filter_clause(f, &mut count_param_index, &mut count_values))
|
||||
.collect();
|
||||
count_sql.push_str(&conditions.join(" AND "));
|
||||
}
|
||||
|
||||
let mut count_query = sqlx::query_as::<_, (i64,)>(&count_sql);
|
||||
for v in count_values {
|
||||
count_query = match v {
|
||||
SqlValue::String(s) => count_query.bind(s),
|
||||
SqlValue::Int(n) => count_query.bind(n),
|
||||
SqlValue::Float(f) => count_query.bind(f),
|
||||
SqlValue::Bool(b) => count_query.bind(b),
|
||||
SqlValue::Uuid(u) => count_query.bind(u),
|
||||
SqlValue::Json(j) => count_query.bind(j),
|
||||
SqlValue::Null => count_query.bind(Option::<String>::None),
|
||||
};
|
||||
}
|
||||
|
||||
if let Ok(count_row) = count_query.fetch_one(&db).await {
|
||||
total_count = Some(count_row.0);
|
||||
}
|
||||
}
|
||||
|
||||
if want_single {
|
||||
if row_count > 1 {
|
||||
return Err((StatusCode::NOT_ACCEPTABLE, "Multiple rows returned for single object request".to_string()));
|
||||
}
|
||||
if row_count == 0 {
|
||||
return Err((StatusCode::NOT_ACCEPTABLE, "No rows returned for single object request".to_string()));
|
||||
}
|
||||
|
||||
let mut response = Json(json_rows[0].clone()).into_response();
|
||||
if let Some(total) = total_count {
|
||||
let range_val = format!("0-0/{}", total);
|
||||
if let Ok(hv) = range_val.parse() {
|
||||
response.headers_mut().insert("Content-Range", hv);
|
||||
}
|
||||
}
|
||||
Ok(response)
|
||||
} else {
|
||||
let mut response = Json(json_rows).into_response();
|
||||
if let Some(total) = total_count {
|
||||
let start = query_params.offset.unwrap_or(0);
|
||||
let end = if row_count == 0 { start } else { start + row_count - 1 };
|
||||
let range_val = format!("{}-{}/{}", start, end, total);
|
||||
if let Ok(hv) = range_val.parse() {
|
||||
response.headers_mut().insert("Content-Range", hv);
|
||||
}
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
fn build_filter_clause(
|
||||
@@ -241,6 +353,10 @@ fn build_filter_clause(
|
||||
format!("({})", clauses.join(" AND "))
|
||||
}
|
||||
}
|
||||
FilterNode::Not(inner) => {
|
||||
let inner_clause = build_filter_clause(inner, param_index, values);
|
||||
format!("NOT ({})", inner_clause)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,13 +365,15 @@ fn build_select_clause<'a>(
|
||||
nodes: &'a [SelectNode],
|
||||
table: &'a str,
|
||||
pool: &'a PgPool,
|
||||
) -> BoxFuture<'a, Result<String, (StatusCode, String)>> {
|
||||
cache: Arc<SchemaCache>,
|
||||
) -> SelectClauseFuture<'a> {
|
||||
Box::pin(async move {
|
||||
if nodes.is_empty() {
|
||||
return Ok("*".to_string());
|
||||
return Ok(("*".to_string(), vec![]));
|
||||
}
|
||||
|
||||
let mut clauses = Vec::new();
|
||||
let mut filters = Vec::new();
|
||||
for node in nodes {
|
||||
match node {
|
||||
SelectNode::Column(c) => {
|
||||
@@ -265,20 +383,19 @@ fn build_select_clause<'a>(
|
||||
clauses.push(format!("\"{}\"", c));
|
||||
}
|
||||
}
|
||||
SelectNode::Relation(rel, inner) => {
|
||||
let fk_info = find_foreign_key(table, rel, pool)
|
||||
SelectNode::Relation(rel, inner_nodes, is_inner) => {
|
||||
let fk_info = find_foreign_key(table, rel, pool, cache.clone())
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||
|
||||
if let Some((local_col, foreign_table, foreign_col)) = fk_info {
|
||||
let inner_select = if inner.is_empty() {
|
||||
"*".to_string()
|
||||
let (inner_select, inner_filters) = if inner_nodes.is_empty() {
|
||||
("*".to_string(), vec![])
|
||||
} else {
|
||||
build_select_clause(inner, &foreign_table, pool).await?
|
||||
build_select_clause(inner_nodes, &foreign_table, pool, cache.clone()).await?
|
||||
};
|
||||
|
||||
let subquery = if foreign_col.starts_with("REV:") {
|
||||
let actual_foreign_col = &foreign_col[4..];
|
||||
let subquery = if let Some(actual_foreign_col) = foreign_col.strip_prefix("REV:") {
|
||||
format!(
|
||||
"(SELECT json_agg(t) FROM (SELECT {} FROM {} WHERE {} = {}.{}) t) as \"{}\"",
|
||||
inner_select, foreign_table, actual_foreign_col, table, local_col, rel
|
||||
@@ -290,6 +407,24 @@ fn build_select_clause<'a>(
|
||||
)
|
||||
};
|
||||
clauses.push(subquery);
|
||||
|
||||
// Merge inner filters (for nested !inner)
|
||||
filters.extend(inner_filters);
|
||||
|
||||
if *is_inner {
|
||||
let exists_filter = if let Some(actual_foreign_col) = foreign_col.strip_prefix("REV:") {
|
||||
format!(
|
||||
"EXISTS (SELECT 1 FROM {} WHERE {} = {}.{})",
|
||||
foreign_table, actual_foreign_col, table, local_col
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"EXISTS (SELECT 1 FROM {} WHERE {} = {}.{})",
|
||||
foreign_table, foreign_col, table, local_col
|
||||
)
|
||||
};
|
||||
filters.push(exists_filter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -299,7 +434,7 @@ fn build_select_clause<'a>(
|
||||
return Err((StatusCode::BAD_REQUEST, "No valid columns selected".to_string()));
|
||||
}
|
||||
|
||||
Ok(clauses.join(", "))
|
||||
Ok((clauses.join(", "), filters))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -308,13 +443,11 @@ async fn find_foreign_key(
|
||||
table: &str,
|
||||
relation: &str,
|
||||
pool: &PgPool,
|
||||
cache: Arc<SchemaCache>,
|
||||
) -> Result<Option<(String, String, String)>, String> {
|
||||
// Basic introspection to find FK.
|
||||
// We look for a table named `relation` or a column named `relation_id`.
|
||||
// PostgREST logic is complex, here's a simplified version:
|
||||
// 1. Check if `relation` is a table name.
|
||||
// 2. Find FK between `table` and `relation`.
|
||||
|
||||
if let Some(cached) = cache.get_fk(table, relation).await {
|
||||
return Ok(cached.map(|c| (c.local_col, c.foreign_table, c.foreign_col)));
|
||||
}
|
||||
let query = r#"
|
||||
SELECT
|
||||
kcu.column_name as local_col,
|
||||
@@ -341,10 +474,14 @@ async fn find_foreign_key(
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
if let Some(r) = row {
|
||||
cache.insert_fk(table, relation, Some(ForeignKeyInfo {
|
||||
local_col: r.0.clone(),
|
||||
foreign_table: r.1.clone(),
|
||||
foreign_col: r.2.clone(),
|
||||
})).await;
|
||||
return Ok(Some(r));
|
||||
}
|
||||
|
||||
// Try reverse (many-to-one): relation table has FK to our table
|
||||
let reverse_query = r#"
|
||||
SELECT
|
||||
ccu.column_name as local_col,
|
||||
@@ -371,9 +508,6 @@ async fn find_foreign_key(
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
if let Some(r) = row {
|
||||
// For reverse relations (one-to-many), we want to aggregate them.
|
||||
// Returning a tuple that signifies reverse relation might be tricky with the same signature.
|
||||
// Let's hack it: return foreign_col as "REV:foreign_col".
|
||||
return Ok(Some((r.0, r.1, format!("REV:{}", r.2))));
|
||||
}
|
||||
|
||||
@@ -425,13 +559,11 @@ fn rows_to_json(rows: Vec<sqlx::postgres::PgRow>) -> Vec<Value> {
|
||||
} else if type_name == "VECTOR" {
|
||||
match row.try_get::<String, _>(name) {
|
||||
Ok(s) => {
|
||||
// Parse string "[1,2,3]" to JSON array
|
||||
serde_json::from_str(&s).unwrap_or(json!(s))
|
||||
},
|
||||
Err(_) => Value::Null,
|
||||
}
|
||||
} else {
|
||||
// Fallback for types that can't be directly read as String
|
||||
match row.try_get::<String, _>(name) {
|
||||
Ok(s) => json!(s),
|
||||
Err(_) => match row.try_get::<Value, _>(name) {
|
||||
@@ -449,24 +581,35 @@ fn rows_to_json(rows: Vec<sqlx::postgres::PgRow>) -> Vec<Value> {
|
||||
}
|
||||
|
||||
pub async fn insert_row(
|
||||
State(state): State<DataState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
State(_state): State<DataState>,
|
||||
Extension(db): Extension<PgPool>,
|
||||
headers: HeaderMap,
|
||||
Extension(auth_ctx): Extension<AuthContext>,
|
||||
Path(table): Path<String>,
|
||||
Json(payload): Json<Value>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
if !is_valid_identifier(&table) {
|
||||
return Err((StatusCode::BAD_REQUEST, "Invalid table name".to_string()));
|
||||
}
|
||||
|
||||
// Start transaction for RLS
|
||||
let is_upsert = headers.get("Prefer")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.contains("resolution=merge-duplicates"))
|
||||
.unwrap_or(false);
|
||||
|
||||
let mut tx = db
|
||||
.begin()
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
// Set RLS variables
|
||||
// Handle Schema selection
|
||||
if let Some(profile) = headers.get("Content-Profile").and_then(|v| v.to_str().ok()) {
|
||||
if is_valid_identifier(profile) {
|
||||
let schema_query = format!("SET LOCAL search_path TO {}, public", profile);
|
||||
sqlx::query(&schema_query).execute(&mut *tx).await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
}
|
||||
}
|
||||
|
||||
validate_role(&auth_ctx.role)?;
|
||||
let role_query = format!("SET LOCAL role = '{}'", auth_ctx.role);
|
||||
sqlx::query(&role_query)
|
||||
@@ -517,7 +660,6 @@ pub async fn insert_row(
|
||||
return Err((StatusCode::BAD_REQUEST, "Payload empty".to_string()));
|
||||
}
|
||||
|
||||
// Use keys from the first row as the columns
|
||||
let first_row = rows_to_insert[0].as_object().ok_or((StatusCode::BAD_REQUEST, "Rows must be objects".to_string()))?;
|
||||
let columns: Vec<String> = first_row.keys().cloned().collect();
|
||||
|
||||
@@ -542,21 +684,36 @@ pub async fn insert_row(
|
||||
for col in &columns {
|
||||
row_placeholders.push(format!("${}", param_index));
|
||||
param_index += 1;
|
||||
|
||||
// Get value or Null
|
||||
let val = obj.get(col).cloned().unwrap_or(Value::Null);
|
||||
bind_values.push(json_value_to_sql_value(val));
|
||||
}
|
||||
values_sql.push(format!("({})", row_placeholders.join(", ")));
|
||||
}
|
||||
|
||||
let sql = format!(
|
||||
"INSERT INTO {} ({}) VALUES {} RETURNING *",
|
||||
let mut sql = format!(
|
||||
"INSERT INTO {} ({}) VALUES {} ",
|
||||
table, col_str, values_sql.join(", ")
|
||||
);
|
||||
|
||||
let mut query = sqlx::query(&sql);
|
||||
if is_upsert {
|
||||
// Simplified upsert: assume 'id' is the conflict target if it exists, otherwise use first column
|
||||
let conflict_target = if columns.contains(&"id".to_string()) { "id" } else { &columns[0] };
|
||||
let update_sets = columns.iter()
|
||||
.filter(|c| *c != conflict_target)
|
||||
.map(|c| format!("\"{}\" = EXCLUDED.\"{}\"", c, c))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
if update_sets.is_empty() {
|
||||
sql.push_str(&format!("ON CONFLICT (\"{}\") DO NOTHING ", conflict_target));
|
||||
} else {
|
||||
sql.push_str(&format!("ON CONFLICT (\"{}\") DO UPDATE SET {} ", conflict_target, update_sets));
|
||||
}
|
||||
}
|
||||
|
||||
sql.push_str("RETURNING *");
|
||||
|
||||
let mut query = sqlx::query(&sql);
|
||||
for v in bind_values {
|
||||
match v {
|
||||
SqlValue::String(s) => query = query.bind(s),
|
||||
@@ -584,13 +741,13 @@ pub async fn insert_row(
|
||||
|
||||
|
||||
pub async fn delete_rows(
|
||||
State(state): State<DataState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
State(_state): State<DataState>,
|
||||
Extension(db): Extension<PgPool>,
|
||||
headers: HeaderMap,
|
||||
Extension(auth_ctx): Extension<AuthContext>,
|
||||
Path(table): Path<String>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
let query_params = QueryParams::parse(params);
|
||||
|
||||
if !is_valid_identifier(&table) {
|
||||
@@ -602,6 +759,14 @@ pub async fn delete_rows(
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
// Handle Schema selection
|
||||
if let Some(profile) = headers.get("Content-Profile").and_then(|v| v.to_str().ok()) {
|
||||
if is_valid_identifier(profile) {
|
||||
let schema_query = format!("SET LOCAL search_path TO {}, public", profile);
|
||||
sqlx::query(&schema_query).execute(&mut *tx).await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
}
|
||||
}
|
||||
|
||||
validate_role(&auth_ctx.role)?;
|
||||
let role_query = format!("SET LOCAL role = '{}'", auth_ctx.role);
|
||||
sqlx::query(&role_query)
|
||||
@@ -685,14 +850,14 @@ pub async fn delete_rows(
|
||||
}
|
||||
|
||||
pub async fn update_rows(
|
||||
State(state): State<DataState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
State(_state): State<DataState>,
|
||||
Extension(db): Extension<PgPool>,
|
||||
headers: HeaderMap,
|
||||
Extension(auth_ctx): Extension<AuthContext>,
|
||||
Path(table): Path<String>,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
Json(payload): Json<Value>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
if !is_valid_identifier(&table) {
|
||||
return Err((StatusCode::BAD_REQUEST, "Invalid table name".to_string()));
|
||||
}
|
||||
@@ -704,6 +869,14 @@ pub async fn update_rows(
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
// Handle Schema selection
|
||||
if let Some(profile) = headers.get("Content-Profile").and_then(|v| v.to_str().ok()) {
|
||||
if is_valid_identifier(profile) {
|
||||
let schema_query = format!("SET LOCAL search_path TO {}, public", profile);
|
||||
sqlx::query(&schema_query).execute(&mut *tx).await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
}
|
||||
}
|
||||
|
||||
validate_role(&auth_ctx.role)?;
|
||||
let role_query = format!("SET LOCAL role = '{}'", auth_ctx.role);
|
||||
sqlx::query(&role_query)
|
||||
@@ -806,9 +979,11 @@ pub async fn update_rows(
|
||||
pub async fn rpc(
|
||||
State(state): State<DataState>,
|
||||
db: Option<Extension<PgPool>>,
|
||||
headers: HeaderMap,
|
||||
Extension(auth_ctx): Extension<AuthContext>,
|
||||
Path(function): Path<String>,
|
||||
Json(payload): Json<Value>,
|
||||
Query(query_params): Query<HashMap<String, String>>,
|
||||
payload: Option<Json<Value>>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let db = db.map(|Extension(p)| p).unwrap_or_else(|| state.db.clone());
|
||||
if !is_valid_identifier(&function) {
|
||||
@@ -820,6 +995,14 @@ pub async fn rpc(
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
// Handle Schema selection
|
||||
if let Some(profile) = headers.get("Content-Profile").and_then(|v| v.to_str().ok()) {
|
||||
if is_valid_identifier(profile) {
|
||||
let schema_query = format!("SET LOCAL search_path TO {}, public", profile);
|
||||
sqlx::query(&schema_query).execute(&mut *tx).await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
}
|
||||
}
|
||||
|
||||
validate_role(&auth_ctx.role)?;
|
||||
let role_query = format!("SET LOCAL role = '{}'", auth_ctx.role);
|
||||
sqlx::query(&role_query)
|
||||
@@ -860,21 +1043,30 @@ pub async fn rpc(
|
||||
}
|
||||
}
|
||||
|
||||
let obj = payload.as_object().ok_or((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Payload must be a JSON object".to_string(),
|
||||
))?;
|
||||
let mut args_map = serde_json::Map::new();
|
||||
|
||||
// 1. Params from URL
|
||||
for (k, v) in query_params {
|
||||
args_map.insert(k, Value::String(v));
|
||||
}
|
||||
|
||||
// 2. Params from JSON body
|
||||
if let Some(Json(Value::Object(obj))) = payload {
|
||||
for (k, v) in obj {
|
||||
args_map.insert(k, v);
|
||||
}
|
||||
}
|
||||
|
||||
let mut args = Vec::new();
|
||||
let mut values: Vec<SqlValue> = Vec::new();
|
||||
let mut p_idx = 1;
|
||||
|
||||
for (k, v) in obj {
|
||||
if !is_valid_identifier(k) {
|
||||
for (k, v) in args_map {
|
||||
if !is_valid_identifier(&k) {
|
||||
return Err((StatusCode::BAD_REQUEST, "Invalid argument name".to_string()));
|
||||
}
|
||||
args.push(format!("{} => ${}", k, p_idx));
|
||||
values.push(json_value_to_sql_value(v.clone()));
|
||||
values.push(json_value_to_sql_value(v));
|
||||
p_idx += 1;
|
||||
}
|
||||
|
||||
@@ -968,4 +1160,30 @@ mod tests {
|
||||
assert!(!is_valid_identifier(""));
|
||||
assert!(!is_valid_identifier("table.name"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_replica_routing_logic() {
|
||||
use axum::http::HeaderMap;
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
// Default: read-only (SELECT) -> replica (implied by default true)
|
||||
let is_read_only = headers.get("x-read-replica")
|
||||
.map(|v| v.to_str().unwrap_or("false") == "true")
|
||||
.unwrap_or(true);
|
||||
assert!(is_read_only);
|
||||
|
||||
// Explicitly opt-out of replica
|
||||
headers.insert("x-read-replica", "false".parse().unwrap());
|
||||
let is_read_only = headers.get("x-read-replica")
|
||||
.map(|v| v.to_str().unwrap_or("false") == "true")
|
||||
.unwrap_or(true);
|
||||
assert!(!is_read_only);
|
||||
|
||||
// Explicitly opt-in to replica
|
||||
headers.insert("x-read-replica", "true".parse().unwrap());
|
||||
let is_read_only = headers.get("x-read-replica")
|
||||
.map(|v| v.to_str().unwrap_or("false") == "true")
|
||||
.unwrap_or(true);
|
||||
assert!(is_read_only);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
pub mod handlers;
|
||||
pub mod parser;
|
||||
pub mod schema_cache;
|
||||
#[cfg(test)]
|
||||
pub mod parser_m4_tests;
|
||||
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
@@ -9,7 +12,7 @@ use handlers::DataState;
|
||||
|
||||
pub fn router() -> Router<DataState> {
|
||||
Router::new()
|
||||
.route("/rpc/:function", post(handlers::rpc))
|
||||
.route("/rpc/:function", post(handlers::rpc).get(handlers::rpc))
|
||||
.route(
|
||||
"/:table",
|
||||
get(handlers::get_rows)
|
||||
|
||||
@@ -12,6 +12,9 @@ pub enum Operator {
|
||||
Ilike,
|
||||
In,
|
||||
Is,
|
||||
Contains, // cs.
|
||||
ContainedBy, // cd.
|
||||
TextSearch, // fts.
|
||||
}
|
||||
|
||||
impl Operator {
|
||||
@@ -27,6 +30,9 @@ impl Operator {
|
||||
"ilike" => Some(Operator::Ilike),
|
||||
"in" => Some(Operator::In),
|
||||
"is" => Some(Operator::Is),
|
||||
"cs" => Some(Operator::Contains),
|
||||
"cd" => Some(Operator::ContainedBy),
|
||||
"fts" => Some(Operator::TextSearch),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -43,6 +49,9 @@ impl Operator {
|
||||
Operator::Ilike => "ILIKE",
|
||||
Operator::In => "IN",
|
||||
Operator::Is => "IS",
|
||||
Operator::Contains => "@>",
|
||||
Operator::ContainedBy => "<@",
|
||||
Operator::TextSearch => "@@",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -62,7 +71,7 @@ pub enum Direction {
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum SelectNode {
|
||||
Column(String),
|
||||
Relation(String, Vec<SelectNode>),
|
||||
Relation(String, Vec<SelectNode>, bool), // bool is is_inner
|
||||
}
|
||||
|
||||
impl SelectNode {
|
||||
@@ -98,16 +107,26 @@ impl SelectNode {
|
||||
nodes
|
||||
}
|
||||
|
||||
fn parse_single(s: &str) -> Self {
|
||||
let s = s.trim();
|
||||
if let Some(idx) = s.find('(') {
|
||||
if s.ends_with(')') {
|
||||
let relation = &s[..idx];
|
||||
let inner = &s[idx + 1..s.len() - 1];
|
||||
return SelectNode::Relation(relation.to_string(), Self::parse(inner));
|
||||
fn parse_single(input: &str) -> Self {
|
||||
let input = input.trim();
|
||||
if input.contains('(') {
|
||||
let parts: Vec<&str> = input.splitn(2, '(').collect();
|
||||
let mut rel_part = parts[0].trim();
|
||||
let mut is_inner = false;
|
||||
|
||||
if rel_part.ends_with("!inner") {
|
||||
is_inner = true;
|
||||
rel_part = &rel_part[..rel_part.len()-6];
|
||||
} else if rel_part.ends_with("!left") {
|
||||
is_inner = false;
|
||||
rel_part = &rel_part[..rel_part.len()-5];
|
||||
}
|
||||
|
||||
let inner_str = &parts[1][..parts[1].len() - 1];
|
||||
SelectNode::Relation(rel_part.to_string(), Self::parse(inner_str), is_inner)
|
||||
} else {
|
||||
SelectNode::Column(input.to_string())
|
||||
}
|
||||
SelectNode::Column(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,6 +139,7 @@ pub enum FilterNode {
|
||||
},
|
||||
Or(Vec<FilterNode>),
|
||||
And(Vec<FilterNode>),
|
||||
Not(Box<FilterNode>),
|
||||
}
|
||||
|
||||
impl FilterNode {
|
||||
@@ -157,6 +177,8 @@ impl FilterNode {
|
||||
} else {
|
||||
Some(FilterNode::And(nodes))
|
||||
}
|
||||
} else if let Some(inner_value) = value.strip_prefix("not.") {
|
||||
FilterNode::parse(key, inner_value).map(|inner| FilterNode::Not(Box::new(inner)))
|
||||
} else {
|
||||
// Check for filters: column=operator.value or column=value (eq implicit)
|
||||
let parts: Vec<&str> = value.splitn(2, '.').collect();
|
||||
|
||||
66
data_api/src/parser_m4_tests.rs
Normal file
66
data_api/src/parser_m4_tests.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::parser::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_or_filter() {
|
||||
let filters = FilterNode::parse("or", "(title.eq.Hello,title.eq.World)");
|
||||
assert!(matches!(filters, Some(FilterNode::Or(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_not_filter() {
|
||||
let filters = FilterNode::parse("status", "not.eq.draft");
|
||||
if let Some(FilterNode::Not(inner)) = filters {
|
||||
if let FilterNode::Condition { column, operator, value } = *inner {
|
||||
assert_eq!(column, "status");
|
||||
assert_eq!(operator, Operator::Eq);
|
||||
assert_eq!(value, "draft");
|
||||
} else {
|
||||
panic!("Inner should be a condition");
|
||||
}
|
||||
} else {
|
||||
panic!("Expected Not filter, got {:?}", filters);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_contains_jsonb() {
|
||||
let filters = FilterNode::parse("tags", "cs.{a,b}");
|
||||
assert!(matches!(filters, Some(FilterNode::Condition { operator: Operator::Contains, .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_contained_by() {
|
||||
let filters = FilterNode::parse("tags", "cd.{a,b,c}");
|
||||
assert!(matches!(filters, Some(FilterNode::Condition { operator: Operator::ContainedBy, .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_text_search() {
|
||||
let filters = FilterNode::parse("content", "fts.hello+world");
|
||||
assert!(matches!(filters, Some(FilterNode::Condition { operator: Operator::TextSearch, .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_select_with_nesting() {
|
||||
let select = SelectNode::parse("*,author:users(name,posts(*))");
|
||||
assert_eq!(select.len(), 2);
|
||||
assert!(matches!(select[0], SelectNode::Column(_)));
|
||||
if let SelectNode::Relation(rel, inner, _) = &select[1] {
|
||||
assert_eq!(rel, "author:users");
|
||||
assert_eq!(inner.len(), 2);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_inner_join() {
|
||||
let select = SelectNode::parse("id,profiles!inner(username)");
|
||||
assert_eq!(select.len(), 2);
|
||||
if let SelectNode::Relation(rel, inner, is_inner) = &select[1] {
|
||||
assert_eq!(rel, "profiles");
|
||||
assert!(is_inner);
|
||||
assert_eq!(inner.len(), 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
43
data_api/src/schema_cache.rs
Normal file
43
data_api/src/schema_cache.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use moka::future::Cache;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ForeignKeyInfo {
|
||||
pub local_col: String,
|
||||
pub foreign_table: String,
|
||||
pub foreign_col: String,
|
||||
}
|
||||
|
||||
pub struct SchemaCache {
|
||||
// Key: (table_name, relation_name)
|
||||
fk_cache: Cache<(String, String), Option<ForeignKeyInfo>>,
|
||||
}
|
||||
|
||||
impl Default for SchemaCache {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl SchemaCache {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
fk_cache: Cache::builder()
|
||||
.max_capacity(1000)
|
||||
.time_to_live(Duration::from_secs(3600))
|
||||
.build(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_fk(&self, table: &str, relation: &str) -> Option<Option<ForeignKeyInfo>> {
|
||||
self.fk_cache.get(&(table.to_string(), relation.to_string())).await
|
||||
}
|
||||
|
||||
pub async fn insert_fk(&self, table: &str, relation: &str, info: Option<ForeignKeyInfo>) {
|
||||
self.fk_cache.insert((table.to_string(), relation.to_string()), info).await;
|
||||
}
|
||||
|
||||
pub async fn invalidate_all(&self) {
|
||||
self.fk_cache.invalidate_all();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user