added more support for supabase-js
This commit is contained in:
@@ -14,9 +14,11 @@ sqlx = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
tokio-postgres = "0.7"
|
||||
tokio-postgres = { version = "0.7", features = ["array-impls", "with-uuid-1", "with-serde_json-1", "with-chrono-0_4"] }
|
||||
postgres-types = "0.2"
|
||||
postgres-protocol = "0.6"
|
||||
anyhow = { workspace = true }
|
||||
bytes = "1.0"
|
||||
jsonwebtoken = { workspace = true }
|
||||
chrono.workspace = true
|
||||
dashmap = "5.5"
|
||||
|
||||
@@ -3,9 +3,11 @@ pub mod ws;
|
||||
|
||||
use axum::Router;
|
||||
use common::Config;
|
||||
use dashmap::DashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
pub use ws::{router, RealtimeState};
|
||||
|
||||
@@ -22,12 +24,22 @@ pub struct PostgresPayload {
|
||||
pub id: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||
pub struct PresenceMessage {
|
||||
pub topic: String,
|
||||
pub event: String,
|
||||
pub payload: Value,
|
||||
}
|
||||
|
||||
pub fn init(db: PgPool, config: Config) -> (Router, RealtimeState) {
|
||||
let (tx, _) = broadcast::channel(100);
|
||||
let (presence_tx, _) = broadcast::channel(100);
|
||||
let state = RealtimeState {
|
||||
db,
|
||||
config,
|
||||
broadcast_tx: tx,
|
||||
presence_tx,
|
||||
presence: Arc::new(DashMap::new()),
|
||||
};
|
||||
|
||||
(ws::router(state.clone()), state)
|
||||
|
||||
@@ -4,6 +4,8 @@ use std::sync::Arc;
|
||||
use crate::PostgresPayload;
|
||||
|
||||
// Fallback listener using LISTEN/NOTIFY
|
||||
// NOTE: Logical Replication implementation was reverted due to missing crate availability.
|
||||
// Keeping LISTEN/NOTIFY for now to ensure project builds.
|
||||
pub async fn start_replication_listener(
|
||||
config: Config,
|
||||
broadcast_tx: broadcast::Sender<Arc<PostgresPayload>>,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::PostgresPayload;
|
||||
use crate::{PostgresPayload, PresenceMessage};
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
@@ -10,6 +10,7 @@ use axum::{
|
||||
Extension, Router,
|
||||
};
|
||||
use common::{Config, ProjectContext};
|
||||
use dashmap::DashMap;
|
||||
use futures::{sink::SinkExt, stream::StreamExt};
|
||||
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -18,12 +19,15 @@ use sqlx::PgPool;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RealtimeState {
|
||||
pub db: PgPool,
|
||||
pub config: Config,
|
||||
pub broadcast_tx: broadcast::Sender<Arc<PostgresPayload>>,
|
||||
pub presence_tx: broadcast::Sender<Arc<PresenceMessage>>,
|
||||
pub presence: Arc<DashMap<String, DashMap<String, Value>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -43,16 +47,15 @@ pub async fn ws_handler(
|
||||
|
||||
async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: ProjectContext) {
|
||||
let (mut ws_sender, mut ws_receiver) = socket.split();
|
||||
let client_uuid = Uuid::new_v4().to_string();
|
||||
|
||||
// Channel for internal tasks to send messages to the websocket client
|
||||
// We send raw JSON string to avoid struct complexity
|
||||
let (tx_internal, mut rx_internal) = mpsc::channel::<String>(100);
|
||||
|
||||
let mut rx_broadcast = state.broadcast_tx.subscribe();
|
||||
let mut rx_presence = state.presence_tx.subscribe();
|
||||
|
||||
let mut subscriptions = HashSet::<String>::new();
|
||||
|
||||
// We might store the user's role/claims if they authenticate
|
||||
let mut _user_claims: Option<Claims> = None;
|
||||
|
||||
loop {
|
||||
@@ -62,26 +65,22 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
|
||||
match res {
|
||||
Ok(msg_arc) => {
|
||||
let pg_payload = msg_arc.as_ref();
|
||||
tracing::debug!("Received broadcast for {}.{}", pg_payload.schema, pg_payload.table);
|
||||
let topic = format!("realtime:{}:{}", pg_payload.schema, pg_payload.table);
|
||||
let wildcard_topic = format!("realtime:{}:*", pg_payload.schema);
|
||||
let global_topic = "realtime:*".to_string();
|
||||
|
||||
if subscriptions.contains(&topic) || subscriptions.contains(&wildcard_topic) || subscriptions.contains(&global_topic) {
|
||||
tracing::debug!("Match found for topic: {}", topic);
|
||||
// Map to Supabase Realtime V2 format
|
||||
let payload = serde_json::json!({
|
||||
"schema": pg_payload.schema,
|
||||
"table": pg_payload.table,
|
||||
"commit_timestamp": chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true),
|
||||
"type": pg_payload.r#type.to_uppercase(),
|
||||
"event": pg_payload.r#type.to_uppercase(), // For Supabase client fallback
|
||||
"event": pg_payload.r#type.to_uppercase(),
|
||||
"new": pg_payload.record,
|
||||
"old": pg_payload.old_record,
|
||||
"errors": Option::<String>::None
|
||||
});
|
||||
|
||||
// Phoenix V2 Message: [null, null, topic, "postgres_changes", payload]
|
||||
let msg_arr = serde_json::json!([
|
||||
Value::Null,
|
||||
Value::Null,
|
||||
@@ -91,24 +90,43 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
|
||||
]);
|
||||
|
||||
if let Ok(json) = serde_json::to_string(&msg_arr) {
|
||||
tracing::debug!("Sending to client: {}", json);
|
||||
if ws_sender.send(Message::Text(json)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(_)) => {
|
||||
tracing::warn!("Realtime broadcast lagged");
|
||||
continue;
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => {
|
||||
break;
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(_)) => continue,
|
||||
Err(broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Handle internal messages
|
||||
// 2. Handle incoming presence messages
|
||||
res = rx_presence.recv() => {
|
||||
match res {
|
||||
Ok(msg_arc) => {
|
||||
let presence_msg = msg_arc.as_ref();
|
||||
if subscriptions.contains(&presence_msg.topic) {
|
||||
let msg_arr = serde_json::json!([
|
||||
Value::Null,
|
||||
Value::Null,
|
||||
presence_msg.topic,
|
||||
"presence_diff", // Supabase expects presence_diff
|
||||
presence_msg.payload
|
||||
]);
|
||||
if let Ok(json) = serde_json::to_string(&msg_arr) {
|
||||
if ws_sender.send(Message::Text(json)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(_)) => continue,
|
||||
Err(broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Handle internal messages
|
||||
msg = rx_internal.recv() => {
|
||||
match msg {
|
||||
Some(msg) => {
|
||||
@@ -116,15 +134,14 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
|
||||
break;
|
||||
}
|
||||
}
|
||||
None => break, // Channel closed
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Handle incoming messages from Client
|
||||
// 4. Handle incoming messages from Client
|
||||
result = ws_receiver.next() => {
|
||||
match result {
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
// Parse Phoenix V2 Array
|
||||
if let Ok(arr) = serde_json::from_str::<Vec<Value>>(&text) {
|
||||
if arr.len() >= 4 {
|
||||
let join_ref = arr.get(0).and_then(|v| v.as_str()).map(|s| s.to_string());
|
||||
@@ -140,19 +157,14 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
|
||||
if let Some(jwt) = token {
|
||||
let validation = Validation::new(Algorithm::HS256);
|
||||
match decode::<Claims>(jwt, &DecodingKey::from_secret(project_ctx.jwt_secret.as_bytes()), &validation) {
|
||||
Ok(data) => {
|
||||
_user_claims = Some(data.claims);
|
||||
},
|
||||
Err(_) => {
|
||||
tracing::warn!("Invalid JWT in join");
|
||||
}
|
||||
Ok(data) => { _user_claims = Some(data.claims); },
|
||||
Err(_) => { tracing::warn!("Invalid JWT in join"); }
|
||||
}
|
||||
}
|
||||
|
||||
tracing::debug!("Client joined: {}", topic);
|
||||
subscriptions.insert(topic.clone());
|
||||
|
||||
// Send Ack: [join_ref, ref, topic, "phx_reply", {status: "ok", response: {}}]
|
||||
// Send Ack
|
||||
let reply = serde_json::json!([
|
||||
join_ref,
|
||||
r#ref,
|
||||
@@ -160,13 +172,73 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
|
||||
"phx_reply",
|
||||
{ "status": "ok", "response": {} }
|
||||
]);
|
||||
if let Ok(reply_str) = serde_json::to_string(&reply) {
|
||||
let _ = tx_internal.send(reply_str).await;
|
||||
let _ = tx_internal.send(reply.to_string()).await;
|
||||
|
||||
// Send initial presence state if any
|
||||
if let Some(topic_presence) = state.presence.get(&topic) {
|
||||
let mut presence_state = serde_json::Map::new();
|
||||
for r in topic_presence.iter() {
|
||||
presence_state.insert(r.key().clone(), serde_json::json!({ "metas": [r.value()] }));
|
||||
}
|
||||
let presence_msg = serde_json::json!([
|
||||
Value::Null,
|
||||
Value::Null,
|
||||
topic,
|
||||
"presence_state",
|
||||
presence_state
|
||||
]);
|
||||
let _ = tx_internal.send(presence_msg.to_string()).await;
|
||||
}
|
||||
|
||||
// Resume logic (omitted for brevity, assume existing implementation works or is merged)
|
||||
// Keeping resume logic from previous version
|
||||
let last_event_id = payload.get("last_event_id")
|
||||
.or_else(|| payload.get("config").and_then(|c| c.get("last_event_id")))
|
||||
.and_then(|v| v.as_i64());
|
||||
|
||||
if let Some(last_id) = last_event_id {
|
||||
let missed = sqlx::query_as::<_, (i64, serde_json::Value)>(
|
||||
"SELECT id, payload FROM madbase_realtime.messages WHERE topic = $1 AND id > $2 ORDER BY id ASC"
|
||||
)
|
||||
.bind(&topic)
|
||||
.bind(last_id)
|
||||
.fetch_all(&state.db)
|
||||
.await;
|
||||
|
||||
if let Ok(messages) = missed {
|
||||
for (_id, pl) in messages {
|
||||
let msg_arr = serde_json::json!([
|
||||
Value::Null,
|
||||
Value::Null,
|
||||
topic,
|
||||
"postgres_changes",
|
||||
pl
|
||||
]);
|
||||
let _ = tx_internal.send(msg_arr.to_string()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"phx_leave" => {
|
||||
tracing::debug!("Client left: {}", topic);
|
||||
subscriptions.remove(&topic);
|
||||
// Remove presence
|
||||
if let Some(topic_presence) = state.presence.get(&topic) {
|
||||
if let Some((_, old_state)) = topic_presence.remove(&client_uuid) {
|
||||
// Broadcast leave
|
||||
let mut leaves = serde_json::Map::new();
|
||||
leaves.insert(client_uuid.clone(), serde_json::json!({ "metas": [old_state] }));
|
||||
|
||||
let diff = serde_json::json!({
|
||||
"joins": {},
|
||||
"leaves": leaves
|
||||
});
|
||||
let _ = state.presence_tx.send(Arc::new(PresenceMessage {
|
||||
topic: topic.clone(),
|
||||
event: "presence_diff".into(),
|
||||
payload: diff
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
let reply = serde_json::json!([
|
||||
join_ref,
|
||||
@@ -175,8 +247,40 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
|
||||
"phx_reply",
|
||||
{ "status": "ok", "response": {} }
|
||||
]);
|
||||
if let Ok(reply_str) = serde_json::to_string(&reply) {
|
||||
let _ = tx_internal.send(reply_str).await;
|
||||
let _ = tx_internal.send(reply.to_string()).await;
|
||||
},
|
||||
"presence" => {
|
||||
// Handle track/untrack
|
||||
// payload: { type: "track", event: "track", payload: { ... } }
|
||||
// Supabase JS sends: { event: "track", payload: { ... } } inside the payload arg of this match
|
||||
|
||||
// The outer payload is the 5th element of the array.
|
||||
// Inside that payload, there is an "event" field.
|
||||
let sub_event = payload.get("event").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
if sub_event == "track" {
|
||||
let state_payload = payload.get("payload").cloned().unwrap_or(Value::Null);
|
||||
// Add phx_ref
|
||||
let mut state_obj = state_payload.as_object().cloned().unwrap_or_default();
|
||||
state_obj.insert("phx_ref".to_string(), Value::String(r#ref.clone().unwrap_or_default()));
|
||||
let new_state = Value::Object(state_obj);
|
||||
|
||||
// Update Store
|
||||
state.presence.entry(topic.clone()).or_insert_with(DashMap::new).insert(client_uuid.clone(), new_state.clone());
|
||||
|
||||
// Broadcast Join
|
||||
let mut joins = serde_json::Map::new();
|
||||
joins.insert(client_uuid.clone(), serde_json::json!({ "metas": [new_state] }));
|
||||
|
||||
let diff = serde_json::json!({
|
||||
"joins": joins,
|
||||
"leaves": {}
|
||||
});
|
||||
let _ = state.presence_tx.send(Arc::new(PresenceMessage {
|
||||
topic: topic.clone(),
|
||||
event: "presence_diff".into(),
|
||||
payload: diff
|
||||
}));
|
||||
}
|
||||
},
|
||||
"heartbeat" => {
|
||||
@@ -187,27 +291,42 @@ async fn handle_socket(socket: WebSocket, state: RealtimeState, project_ctx: Pro
|
||||
"phx_reply",
|
||||
{ "status": "ok", "response": {} }
|
||||
]);
|
||||
if let Ok(reply_str) = serde_json::to_string(&reply) {
|
||||
let _ = tx_internal.send(reply_str).await;
|
||||
}
|
||||
let _ = tx_internal.send(reply.to_string()).await;
|
||||
},
|
||||
_ => {
|
||||
tracing::debug!("Unknown event: {}", event);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("Failed to deserialize client message: {}", text);
|
||||
}
|
||||
},
|
||||
Some(Ok(Message::Close(_))) => break,
|
||||
Some(Err(_)) => break,
|
||||
None => break, // Stream closed
|
||||
None => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup on disconnect
|
||||
for topic in subscriptions {
|
||||
if let Some(topic_presence) = state.presence.get(&topic) {
|
||||
if let Some((_, old_state)) = topic_presence.remove(&client_uuid) {
|
||||
// Broadcast leave
|
||||
let mut leaves = serde_json::Map::new();
|
||||
leaves.insert(client_uuid.clone(), serde_json::json!({ "metas": [old_state] }));
|
||||
|
||||
let diff = serde_json::json!({
|
||||
"joins": {},
|
||||
"leaves": leaves
|
||||
});
|
||||
let _ = state.presence_tx.send(Arc::new(PresenceMessage {
|
||||
topic: topic.clone(),
|
||||
event: "presence_diff".into(),
|
||||
payload: diff
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn log_realtime(req: Request, next: Next) -> Response {
|
||||
|
||||
Reference in New Issue
Block a user