Skip to content

Commit

Permalink
use rwlock
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo committed Mar 3, 2025
1 parent 18dd3ee commit b8a969e
Showing 1 changed file with 81 additions and 98 deletions.
179 changes: 81 additions & 98 deletions crates/torii/server/src/handlers/mcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use std::sync::Arc;
use futures_util::{SinkExt, StreamExt};
use hyper::{Body, Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use serde_json::{json, Number, Value};
use sqlx::{Row, SqlitePool};
use tokio::sync::broadcast;
use tokio::sync::{broadcast, RwLock};
use tokio_tungstenite::tungstenite::Message;
use uuid::Uuid;

Expand Down Expand Up @@ -88,14 +88,14 @@ struct SseSession {
pub struct McpHandler {
pool: Arc<SqlitePool>,
// Map of session IDs to SSE sessions
sse_sessions: Arc<tokio::sync::Mutex<std::collections::HashMap<String, SseSession>>>,
sse_sessions: Arc<RwLock<std::collections::HashMap<String, SseSession>>>,
}

impl McpHandler {
pub fn new(pool: Arc<SqlitePool>) -> Self {
Self {
Self {
pool,
sse_sessions: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
sse_sessions: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
}
}

Expand Down Expand Up @@ -216,34 +216,28 @@ impl McpHandler {
}

// New method to handle SSE connections
async fn handle_sse_connection(
&self,
response_builder: hyper::http::response::Builder,
) -> Response<Body> {
async fn handle_sse_connection(&self) -> Response<Body> {
// Create a new session ID
let session_id = Uuid::new_v4().to_string();

// Create a broadcast channel for SSE messages
let (tx, rx) = broadcast::channel::<JsonRpcResponse>(SSE_CHANNEL_CAPACITY);

// Store the session
{
let mut sessions = self.sse_sessions.lock().await;
sessions.insert(session_id.clone(), SseSession { tx: tx.clone(), session_id: session_id.clone() });
let mut sessions = self.sse_sessions.write().await;
sessions.insert(
session_id.clone(),
SseSession { tx: tx.clone(), session_id: session_id.clone() },
);
}

// Create the message endpoint path
let message_endpoint = format!("/mcp/message?sessionId={}", session_id);

// Create initial endpoint info event - using full URL format
let endpoint_info = format!(
"event: endpoint\ndata: {}\n\n",
message_endpoint
);

// Log the endpoint creation
eprintln!("Created SSE session {} with endpoint {}", session_id, message_endpoint);

let endpoint_info = format!("event: endpoint\ndata: {}\n\n", message_endpoint);

// Create the streaming body with the endpoint information followed by event data
let stream = futures_util::stream::once(async move {
Ok::<_, hyper::Error>(hyper::body::Bytes::from(endpoint_info))
Expand All @@ -256,76 +250,64 @@ impl McpHandler {
Ok(json) => {
// Format SSE data with event name and proper line breaks
let sse_data = format!("event: message\ndata: {}\n\n", json);
Some((Ok::<_, hyper::Error>(hyper::body::Bytes::from(sse_data)), rx))
},
Some((
Ok::<_, hyper::Error>(hyper::body::Bytes::from(sse_data)),
rx,
))
}
Err(e) => {
eprintln!("Error serializing message: {}", e);
// Format error event with proper SSE format
Some((Ok::<_, hyper::Error>(hyper::body::Bytes::from(
format!("event: error\ndata: {{\n \"error\": \"{}\" }}\n\n", e)
)), rx))
Some((
Ok::<_, hyper::Error>(hyper::body::Bytes::from(format!(
"event: error\ndata: {{\n \"error\": \"{}\" }}\n\n",
e
))),
rx,
))
}
}
},
}
Err(_) => None,
}
}
}));

// Return the SSE response
response_builder
Response::builder()
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.header("Access-Control-Allow-Origin", "*")
.header("X-Session-Id", session_id)
.body(Body::wrap_stream(stream))
.unwrap()
}

// New method to handle JSON-RPC messages sent via HTTP POST
async fn handle_message_request(&self, req: Request<Body>) -> Response<Body> {
// Extract the session ID from the query parameters
let uri = req.uri();
let query = uri.query().unwrap_or("");

// Naively parse the session ID (in a real implementation, use a proper URL parser)
let mut session_id = None;
for pair in query.split('&') {
let mut parts = pair.split('=');
if let Some(key) = parts.next() {
if key == "sessionId" {
session_id = parts.next();
break;
}
}
}

let session_id = match session_id {
Some(id) => id,
None => {
// Return an error if no session ID was provided
return Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Missing sessionId parameter"))
.unwrap();
}
};

let session_id = uri.query().unwrap().split("=").collect::<Vec<_>>()[1];

// Check if the session exists
let tx = {
let sessions = self.sse_sessions.lock().await;
let sessions = self.sse_sessions.read().await;
match sessions.get(session_id) {
Some(session) => session.tx.clone(),
None => {
Some(s) => s.tx.clone(),
_ => {
return Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from(format!("Session {} not found", session_id)))
.unwrap();
.body(Body::from(
serde_json::to_string(&JsonRpcResponse::invalid_params(
Value::Number(Into::<Number>::into(-1)),
"missing sessionId query param",
))
.unwrap(),
))
.unwrap()
}
}
};

// Read the request body
let body_bytes = match hyper::body::to_bytes(req.into_body()).await {
Ok(bytes) => bytes,
Expand All @@ -336,7 +318,7 @@ impl McpHandler {
.unwrap();
}
};

let body_str = match String::from_utf8(body_bytes.to_vec()) {
Ok(s) => s,
Err(e) => {
Expand All @@ -346,62 +328,68 @@ impl McpHandler {
.unwrap();
}
};

// First try to parse as a raw JSON value to handle any valid JSON input
let parsed_json: Result<serde_json::Value, _> = serde_json::from_str(&body_str);

let response = match parsed_json {
Ok(json_value) => {
// Try to parse as a JsonRpcMessage
match serde_json::from_value::<JsonRpcMessage>(json_value.clone()) {
Ok(JsonRpcMessage::Request(request)) => {
let response = self.handle_request(request).await;

// Forward the response to the SSE channel
if let Err(e) = tx.send(response.clone()) {
eprintln!("Error forwarding response to SSE: {}", e);
} else {
eprintln!("Successfully sent response to SSE channel: {:?}", response.id);
eprintln!(
"Successfully sent response to SSE channel: {:?}",
response.id
);
}

Response::builder()
.status(StatusCode::ACCEPTED)
.header("Content-Type", "application/json")
.header("Access-Control-Allow-Origin", "*")
.body(Body::from(serde_json::to_string(&response).unwrap()))
.unwrap()
},
}
Ok(JsonRpcMessage::Notification(_)) => {
// For notifications, just send 202 Accepted with no body
Response::builder()
.status(StatusCode::ACCEPTED)
.header("Access-Control-Allow-Origin", "*")
.header("Access-Control-Allow-Origin", "*")
.body(Body::empty())
.unwrap()
},
}
Err(_) => {
// If not a valid JsonRpcMessage, try to interpret as a raw request
// This is more permissive and handles cases where the client sends simplified JSON
if let Some(method) = json_value.get("method").and_then(|m| m.as_str()) {
let id = json_value.get("id").cloned().unwrap_or(Value::Null);
let params = json_value.get("params").cloned();

let request = JsonRpcRequest {
jsonrpc: JSONRPC_VERSION.to_string(),
id,
method: method.to_string(),
params,
};

let response = self.handle_request(request).await;

// Forward the response to the SSE channel
if let Err(e) = tx.send(response.clone()) {
eprintln!("Error forwarding response to SSE: {}", e);
} else {
eprintln!("Successfully sent response to SSE channel: {:?}", response.id);
eprintln!(
"Successfully sent response to SSE channel: {:?}",
response.id
);
}

Response::builder()
.status(StatusCode::ACCEPTED)
.header("Content-Type", "application/json")
Expand All @@ -420,7 +408,7 @@ impl McpHandler {
}
}
}
},
}
Err(e) => {
let error_response = JsonRpcResponse::parse_error(Value::Null, &e.to_string());
Response::builder()
Expand All @@ -431,7 +419,7 @@ impl McpHandler {
.unwrap()
}
};

response
}

Expand Down Expand Up @@ -623,7 +611,7 @@ impl Handler for McpHandler {

async fn handle(&self, req: Request<Body>) -> Response<Body> {
let uri_path = req.uri().path();

// Handle CORS preflight requests
if req.method() == hyper::Method::OPTIONS {
return Response::builder()
Expand All @@ -635,20 +623,17 @@ impl Handler for McpHandler {
.body(Body::empty())
.unwrap();
}

// Handle message endpoint (for SSE clients)
if uri_path == "/mcp/message" {
return self.handle_message_request(req).await;
}

match req.method() {
// Handle GET requests for SSE connection
&hyper::Method::GET => {
return self.handle_sse_connection(
Response::builder()
.header("Access-Control-Allow-Origin", "*")
).await;
},
return self.handle_sse_connection().await;
}
// Handle WebSocket upgrade requests
_ if hyper_tungstenite::is_upgrade_request(&req) => {
let (response, websocket) = match hyper_tungstenite::upgrade(req, None) {
Expand All @@ -670,15 +655,13 @@ impl Handler for McpHandler {
});

response
},
// Return Method Not Allowed for other methods
_ => {
Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.header("Access-Control-Allow-Origin", "*")
.body(Body::from("Method not allowed"))
.unwrap()
}
// Return Method Not Allowed for other methods
_ => Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.header("Access-Control-Allow-Origin", "*")
.body(Body::from("Method not allowed"))
.unwrap(),
}
}
}
}

0 comments on commit b8a969e

Please sign in to comment.