diff --git a/Cargo.lock b/Cargo.lock index 11649f9882..75712b9fc4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14609,9 +14609,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" dependencies = [ "futures-core", "pin-project-lite", @@ -15134,6 +15134,7 @@ dependencies = [ "serde_json", "sqlx", "tokio", + "tokio-stream", "tokio-tungstenite 0.20.1", "tokio-util", "torii-sqlite", diff --git a/crates/torii/server/Cargo.toml b/crates/torii/server/Cargo.toml index 9d01849f4f..06227103a5 100644 --- a/crates/torii/server/Cargo.toml +++ b/crates/torii/server/Cargo.toml @@ -34,3 +34,4 @@ async-trait = "0.1.83" tokio-tungstenite = "0.20.0" hyper-tungstenite = "0.11.1" futures-util.workspace = true +tokio-stream = "0.1.17" diff --git a/crates/torii/server/src/handlers/mcp.rs b/crates/torii/server/src/handlers/mcp.rs index 22c9e4a64b..90c81caf35 100644 --- a/crates/torii/server/src/handlers/mcp.rs +++ b/crates/torii/server/src/handlers/mcp.rs @@ -1,11 +1,17 @@ use std::sync::Arc; use futures_util::{SinkExt, StreamExt}; +use http::Method; use hyper::{Body, Request, Response, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use sqlx::{Row, SqlitePool}; use tokio_tungstenite::tungstenite::Message; +use futures_util::stream::{self, Stream}; +use hyper::header; +use std::pin::Pin; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; use super::sql::map_row_to_json; use super::Handler; @@ -78,11 +84,15 @@ struct ResourceCapabilities { #[derive(Clone)] pub struct McpHandler { pool: Arc, + sse_tx: Arc>>>, } impl McpHandler { pub fn new(pool: Arc) -> Self { - Self { pool } + Self { + pool, + sse_tx: Arc::new(tokio::sync::Mutex::new(Vec::new())), + } } async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse { @@ -348,6 +358,41 @@ impl McpHandler { } } } + + async fn handle_sse_connection(&self) -> impl Stream> { + let (tx, rx) = mpsc::channel(100); + + // Store the sender for broadcasting events + { + let mut senders = self.sse_tx.lock().await; + senders.push(tx); + } + + // Create a stream that first sends an initial message, then forwards from the channel + let initial_message = Ok("event: connected\ndata: {\"status\":\"connected\"}\n\n".to_string()); + let initial_stream = stream::once(async move { initial_message }); + + // Convert the receiver to a stream and handle errors + let receiver_stream = ReceiverStream::new(rx).map(Ok); + + // Combine the initial message with the ongoing stream + Box::pin(initial_stream.chain(receiver_stream)) as Pin> + Send>> + } + + // Method to broadcast an event to all SSE clients + pub async fn broadcast_event(&self, event_type: &str, data: &str) { + let message = format!("event: {}\ndata: {}\n\n", event_type, data); + let mut senders = self.sse_tx.lock().await; + + // Send to all clients, removing any that have been closed + senders.retain_mut(|tx| { + match tx.try_send(message.clone()) { + Ok(_) => true, + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => false, + Err(_) => true, // Keep if channel is full but not closed + } + }); + } } impl JsonRpcResponse { @@ -385,16 +430,12 @@ impl JsonRpcResponse { impl Handler for McpHandler { fn should_handle(&self, req: &Request) -> bool { req.uri().path().starts_with("/mcp") - && req - .headers() - .get("upgrade") - .and_then(|h| h.to_str().ok()) - .map(|h| h.eq_ignore_ascii_case("websocket")) - .unwrap_or(false) } async fn handle(&self, req: Request) -> Response { - if hyper_tungstenite::is_upgrade_request(&req) { + // Handle WebSocket connections at /mcp/ws + if req.uri().path() == "/mcp/ws" && + hyper_tungstenite::is_upgrade_request(&req) { let (response, websocket) = hyper_tungstenite::upgrade(req, None) .expect("Failed to upgrade WebSocket connection"); @@ -405,12 +446,26 @@ impl Handler for McpHandler { } }); - response - } else { - Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::from("Not a WebSocket upgrade request")) - .unwrap() + return response; + } + // Handle SSE connections at /mcp or /mcp/sse + else if (req.uri().path() == "/mcp" || req.uri().path() == "/mcp/sse") && + req.method() == Method::GET { + let stream = self.handle_sse_connection().await; + + return Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "text/event-stream") + .header(header::CACHE_CONTROL, "no-cache") + .header(header::CONNECTION, "keep-alive") + .body(Body::wrap_stream(stream)) + .unwrap(); } + + // Default response for other requests + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from("Invalid MCP request. Use /mcp for SSE or /mcp/ws for WebSocket connections.")) + .unwrap() } }