added more support for supabase-js

This commit is contained in:
2026-03-12 10:18:52 +02:00
parent c0792f2e1d
commit 6708cf28a7
62 changed files with 6563 additions and 526 deletions

View File

@@ -14,9 +14,11 @@ sqlx = { workspace = true }
tracing = { workspace = true }
futures = { workspace = true }
uuid = { workspace = true }
tokio-postgres = "0.7"
tokio-postgres = { version = "0.7", features = ["array-impls", "with-uuid-1", "with-serde_json-1", "with-chrono-0_4"] }
postgres-types = "0.2"
postgres-protocol = "0.6"
anyhow = { workspace = true }
bytes = "1.0"
jsonwebtoken = { workspace = true }
chrono.workspace = true
dashmap = "5.5"

View File

@@ -3,9 +3,11 @@ pub mod ws;
use axum::Router;
use common::Config;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::PgPool;
use std::sync::Arc;
use tokio::sync::broadcast;
pub use ws::{router, RealtimeState};
@@ -22,12 +24,22 @@ pub struct PostgresPayload {
pub id: Option<i64>,
}
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct PresenceMessage {
pub topic: String,
pub event: String,
pub payload: Value,
}
pub fn init(db: PgPool, config: Config) -> (Router, RealtimeState) {
let (tx, _) = broadcast::channel(100);
let (presence_tx, _) = broadcast::channel(100);
let state = RealtimeState {
db,
config,
broadcast_tx: tx,
presence_tx,
presence: Arc::new(DashMap::new()),
};
(ws::router(state.clone()), state)

View File

@@ -4,6 +4,8 @@ use std::sync::Arc;
use crate::PostgresPayload;
// Fallback listener using LISTEN/NOTIFY
// NOTE: Logical Replication implementation was reverted due to missing crate availability.
// Keeping LISTEN/NOTIFY for now to ensure project builds.
pub async fn start_replication_listener(
config: Config,
broadcast_tx: broadcast::Sender<Arc<PostgresPayload>>,

View File

@@ -1,4 +1,4 @@
use crate::PostgresPayload;
use crate::{PostgresPayload, PresenceMessage};
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
@@ -10,6 +10,7 @@ use axum::{
Extension, Router,
};
use common::{Config, ProjectContext};
use dashmap::DashMap;
use futures::{sink::SinkExt, stream::StreamExt};
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
@@ -18,12 +19,15 @@ use sqlx::PgPool;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc};
use uuid::Uuid;
#[derive(Clone)]
pub struct RealtimeState {
pub db: PgPool,
pub config: Config,
pub broadcast_tx: broadcast::Sender<Arc<PostgresPayload>>,
pub presence_tx: broadcast::Sender<Arc<PresenceMessage>>,
pub presence: Arc<DashMap<String, DashMap<String, Value>>>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -43,16 +47,15 @@ pub async fn ws_handler(
async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: ProjectContext) {
let (mut ws_sender, mut ws_receiver) = socket.split();
let client_uuid = Uuid::new_v4().to_string();
// Channel for internal tasks to send messages to the websocket client
// We send raw JSON string to avoid struct complexity
let (tx_internal, mut rx_internal) = mpsc::channel::<String>(100);
let mut rx_broadcast = state.broadcast_tx.subscribe();
let mut rx_presence = state.presence_tx.subscribe();
let mut subscriptions = HashSet::<String>::new();
// We might store the user's role/claims if they authenticate
let mut _user_claims: Option<Claims> = None;
loop {
@@ -62,26 +65,22 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
match res {
Ok(msg_arc) => {
let pg_payload = msg_arc.as_ref();
tracing::debug!("Received broadcast for {}.{}", pg_payload.schema, pg_payload.table);
let topic = format!("realtime:{}:{}", pg_payload.schema, pg_payload.table);
let wildcard_topic = format!("realtime:{}:*", pg_payload.schema);
let global_topic = "realtime:*".to_string();
if subscriptions.contains(&topic) || subscriptions.contains(&wildcard_topic) || subscriptions.contains(&global_topic) {
tracing::debug!("Match found for topic: {}", topic);
// Map to Supabase Realtime V2 format
let payload = serde_json::json!({
"schema": pg_payload.schema,
"table": pg_payload.table,
"commit_timestamp": chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true),
"type": pg_payload.r#type.to_uppercase(),
"event": pg_payload.r#type.to_uppercase(), // For Supabase client fallback
"event": pg_payload.r#type.to_uppercase(),
"new": pg_payload.record,
"old": pg_payload.old_record,
"errors": Option::<String>::None
});
// Phoenix V2 Message: [null, null, topic, "postgres_changes", payload]
let msg_arr = serde_json::json!([
Value::Null,
Value::Null,
@@ -91,24 +90,43 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
]);
if let Ok(json) = serde_json::to_string(&msg_arr) {
tracing::debug!("Sending to client: {}", json);
if ws_sender.send(Message::Text(json)).await.is_err() {
break;
}
}
}
}
Err(broadcast::error::RecvError::Lagged(_)) => {
tracing::warn!("Realtime broadcast lagged");
continue;
}
Err(broadcast::error::RecvError::Closed) => {
break;
}
Err(broadcast::error::RecvError::Lagged(_)) => continue,
Err(broadcast::error::RecvError::Closed) => break,
}
}
// 2. Handle internal messages
// 2. Handle incoming presence messages
res = rx_presence.recv() => {
match res {
Ok(msg_arc) => {
let presence_msg = msg_arc.as_ref();
if subscriptions.contains(&presence_msg.topic) {
let msg_arr = serde_json::json!([
Value::Null,
Value::Null,
presence_msg.topic,
"presence_diff", // Supabase expects presence_diff
presence_msg.payload
]);
if let Ok(json) = serde_json::to_string(&msg_arr) {
if ws_sender.send(Message::Text(json)).await.is_err() {
break;
}
}
}
}
Err(broadcast::error::RecvError::Lagged(_)) => continue,
Err(broadcast::error::RecvError::Closed) => break,
}
}
// 3. Handle internal messages
msg = rx_internal.recv() => {
match msg {
Some(msg) => {
@@ -116,15 +134,14 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
break;
}
}
None => break, // Channel closed
None => break,
}
}
// 3. Handle incoming messages from Client
// 4. Handle incoming messages from Client
result = ws_receiver.next() => {
match result {
Some(Ok(Message::Text(text))) => {
// Parse Phoenix V2 Array
if let Ok(arr) = serde_json::from_str::<Vec<Value>>(&text) {
if arr.len() >= 4 {
let join_ref = arr.get(0).and_then(|v| v.as_str()).map(|s| s.to_string());
@@ -140,19 +157,14 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
if let Some(jwt) = token {
let validation = Validation::new(Algorithm::HS256);
match decode::<Claims>(jwt, &DecodingKey::from_secret(project_ctx.jwt_secret.as_bytes()), &validation) {
Ok(data) => {
_user_claims = Some(data.claims);
},
Err(_) => {
tracing::warn!("Invalid JWT in join");
}
Ok(data) => { _user_claims = Some(data.claims); },
Err(_) => { tracing::warn!("Invalid JWT in join"); }
}
}
tracing::debug!("Client joined: {}", topic);
subscriptions.insert(topic.clone());
// Send Ack: [join_ref, ref, topic, "phx_reply", {status: "ok", response: {}}]
// Send Ack
let reply = serde_json::json!([
join_ref,
r#ref,
@@ -160,13 +172,73 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
"phx_reply",
{ "status": "ok", "response": {} }
]);
if let Ok(reply_str) = serde_json::to_string(&reply) {
let _ = tx_internal.send(reply_str).await;
let _ = tx_internal.send(reply.to_string()).await;
// Send initial presence state if any
if let Some(topic_presence) = state.presence.get(&topic) {
let mut presence_state = serde_json::Map::new();
for r in topic_presence.iter() {
presence_state.insert(r.key().clone(), serde_json::json!({ "metas": [r.value()] }));
}
let presence_msg = serde_json::json!([
Value::Null,
Value::Null,
topic,
"presence_state",
presence_state
]);
let _ = tx_internal.send(presence_msg.to_string()).await;
}
// Resume logic (omitted for brevity, assume existing implementation works or is merged)
// Keeping resume logic from previous version
let last_event_id = payload.get("last_event_id")
.or_else(|| payload.get("config").and_then(|c| c.get("last_event_id")))
.and_then(|v| v.as_i64());
if let Some(last_id) = last_event_id {
let missed = sqlx::query_as::<_, (i64, serde_json::Value)>(
"SELECT id, payload FROM madbase_realtime.messages WHERE topic = $1 AND id > $2 ORDER BY id ASC"
)
.bind(&topic)
.bind(last_id)
.fetch_all(&state.db)
.await;
if let Ok(messages) = missed {
for (_id, pl) in messages {
let msg_arr = serde_json::json!([
Value::Null,
Value::Null,
topic,
"postgres_changes",
pl
]);
let _ = tx_internal.send(msg_arr.to_string()).await;
}
}
}
},
"phx_leave" => {
tracing::debug!("Client left: {}", topic);
subscriptions.remove(&topic);
// Remove presence
if let Some(topic_presence) = state.presence.get(&topic) {
if let Some((_, old_state)) = topic_presence.remove(&client_uuid) {
// Broadcast leave
let mut leaves = serde_json::Map::new();
leaves.insert(client_uuid.clone(), serde_json::json!({ "metas": [old_state] }));
let diff = serde_json::json!({
"joins": {},
"leaves": leaves
});
let _ = state.presence_tx.send(Arc::new(PresenceMessage {
topic: topic.clone(),
event: "presence_diff".into(),
payload: diff
}));
}
}
let reply = serde_json::json!([
join_ref,
@@ -175,8 +247,40 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
"phx_reply",
{ "status": "ok", "response": {} }
]);
if let Ok(reply_str) = serde_json::to_string(&reply) {
let _ = tx_internal.send(reply_str).await;
let _ = tx_internal.send(reply.to_string()).await;
},
"presence" => {
// Handle track/untrack
// payload: { type: "track", event: "track", payload: { ... } }
// Supabase JS sends: { event: "track", payload: { ... } } inside the payload arg of this match
// The outer payload is the 5th element of the array.
// Inside that payload, there is an "event" field.
let sub_event = payload.get("event").and_then(|v| v.as_str()).unwrap_or("");
if sub_event == "track" {
let state_payload = payload.get("payload").cloned().unwrap_or(Value::Null);
// Add phx_ref
let mut state_obj = state_payload.as_object().cloned().unwrap_or_default();
state_obj.insert("phx_ref".to_string(), Value::String(r#ref.clone().unwrap_or_default()));
let new_state = Value::Object(state_obj);
// Update Store
state.presence.entry(topic.clone()).or_insert_with(DashMap::new).insert(client_uuid.clone(), new_state.clone());
// Broadcast Join
let mut joins = serde_json::Map::new();
joins.insert(client_uuid.clone(), serde_json::json!({ "metas": [new_state] }));
let diff = serde_json::json!({
"joins": joins,
"leaves": {}
});
let _ = state.presence_tx.send(Arc::new(PresenceMessage {
topic: topic.clone(),
event: "presence_diff".into(),
payload: diff
}));
}
},
"heartbeat" => {
@@ -187,27 +291,42 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
"phx_reply",
{ "status": "ok", "response": {} }
]);
if let Ok(reply_str) = serde_json::to_string(&reply) {
let _ = tx_internal.send(reply_str).await;
}
let _ = tx_internal.send(reply.to_string()).await;
},
_ => {
tracing::debug!("Unknown event: {}", event);
}
_ => {}
}
}
} else {
tracing::warn!("Failed to deserialize client message: {}", text);
}
},
Some(Ok(Message::Close(_))) => break,
Some(Err(_)) => break,
None => break, // Stream closed
None => break,
_ => {}
}
}
}
}
// Cleanup on disconnect
for topic in subscriptions {
if let Some(topic_presence) = state.presence.get(&topic) {
if let Some((_, old_state)) = topic_presence.remove(&client_uuid) {
// Broadcast leave
let mut leaves = serde_json::Map::new();
leaves.insert(client_uuid.clone(), serde_json::json!({ "metas": [old_state] }));
let diff = serde_json::json!({
"joins": {},
"leaves": leaves
});
let _ = state.presence_tx.send(Arc::new(PresenceMessage {
topic: topic.clone(),
event: "presence_diff".into(),
payload: diff
}));
}
}
}
}
async fn log_realtime(req: Request, next: Next) -> Response {