Skip to content

Commit

Permalink
feat(torii): add SSE support for mcp
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo committed Feb 28, 2025
1 parent 60e0a6b commit 1724395
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 16 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/torii/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ async-trait = "0.1.83"
tokio-tungstenite = "0.20.0"
hyper-tungstenite = "0.11.1"
futures-util.workspace = true
tokio-stream = "0.1.17"
83 changes: 69 additions & 14 deletions crates/torii/server/src/handlers/mcp.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
use std::sync::Arc;

use futures_util::{SinkExt, StreamExt};
use http::Method;
use hyper::{Body, Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use sqlx::{Row, SqlitePool};
use tokio_tungstenite::tungstenite::Message;
use futures_util::stream::{self, Stream};
use hyper::header;
use std::pin::Pin;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

use super::sql::map_row_to_json;
use super::Handler;
Expand Down Expand Up @@ -78,11 +84,15 @@ struct ResourceCapabilities {
#[derive(Clone)]
pub struct McpHandler {
pool: Arc<SqlitePool>,
sse_tx: Arc<tokio::sync::Mutex<Vec<mpsc::Sender<String>>>>,
}

impl McpHandler {
pub fn new(pool: Arc<SqlitePool>) -> Self {
Self { pool }
Self {
pool,
sse_tx: Arc::new(tokio::sync::Mutex::new(Vec::new())),
}
}

async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
Expand Down Expand Up @@ -348,6 +358,41 @@ impl McpHandler {
}
}
}

async fn handle_sse_connection(&self) -> impl Stream<Item = Result<String, std::io::Error>> {
let (tx, rx) = mpsc::channel(100);

// Store the sender for broadcasting events
{
let mut senders = self.sse_tx.lock().await;
senders.push(tx);
}

// Create a stream that first sends an initial message, then forwards from the channel
let initial_message = Ok("event: connected\ndata: {\"status\":\"connected\"}\n\n".to_string());
let initial_stream = stream::once(async move { initial_message });

// Convert the receiver to a stream and handle errors
let receiver_stream = ReceiverStream::new(rx).map(Ok);

// Combine the initial message with the ongoing stream
Box::pin(initial_stream.chain(receiver_stream)) as Pin<Box<dyn Stream<Item = Result<String, std::io::Error>> + Send>>
}

// Method to broadcast an event to all SSE clients
pub async fn broadcast_event(&self, event_type: &str, data: &str) {
let message = format!("event: {}\ndata: {}\n\n", event_type, data);
let mut senders = self.sse_tx.lock().await;

// Send to all clients, removing any that have been closed
senders.retain_mut(|tx| {
match tx.try_send(message.clone()) {
Ok(_) => true,
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => false,
Err(_) => true, // Keep if channel is full but not closed
}
});
}
}

impl JsonRpcResponse {
Expand Down Expand Up @@ -385,16 +430,12 @@ impl JsonRpcResponse {
impl Handler for McpHandler {
fn should_handle(&self, req: &Request<Body>) -> bool {
req.uri().path().starts_with("/mcp")
&& req
.headers()
.get("upgrade")
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
if hyper_tungstenite::is_upgrade_request(&req) {
// Handle WebSocket connections at /mcp/ws
if req.uri().path() == "/mcp/ws" &&
hyper_tungstenite::is_upgrade_request(&req) {
let (response, websocket) = hyper_tungstenite::upgrade(req, None)
.expect("Failed to upgrade WebSocket connection");

Expand All @@ -405,12 +446,26 @@ impl Handler for McpHandler {
}
});

response
} else {
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Not a WebSocket upgrade request"))
.unwrap()
return response;
}
// Handle SSE connections at /mcp or /mcp/sse
else if (req.uri().path() == "/mcp" || req.uri().path() == "/mcp/sse") &&
req.method() == Method::GET {
let stream = self.handle_sse_connection().await;

return Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/event-stream")
.header(header::CACHE_CONTROL, "no-cache")
.header(header::CONNECTION, "keep-alive")
.body(Body::wrap_stream(stream))
.unwrap();
}

// Default response for other requests
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid MCP request. Use /mcp for SSE or /mcp/ws for WebSocket connections."))
.unwrap()
}
}

0 comments on commit 1724395

Please sign in to comment.