Skip to content

Commit ca3f400

Browse files
authored
Use More Semantic HTTP (payjoin#346)
Close payjoin#345
2 parents cc97788 + 11353f9 commit ca3f400

File tree

9 files changed

+129
-82
lines changed

9 files changed

+129
-82
lines changed

Cargo.lock

+9-57
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

payjoin-cli/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ clap = { version = "~4.0.32", features = ["derive"] }
3434
config = "0.13.3"
3535
env_logger = "0.9.0"
3636
http-body-util = { version = "0.1", optional = true }
37-
hyper = { version = "1", features = ["full"], optional = true }
37+
hyper = { version = "1", features = ["http1", "server"], optional = true }
3838
hyper-rustls = { version = "0.26", optional = true }
3939
hyper-util = { version = "0.1", optional = true }
4040
log = "0.4.7"

payjoin-cli/tests/e2e.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,14 @@ mod e2e {
457457
let db = docker.run(Redis::default());
458458
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
459459
println!("Database running on {}", db.get_host_port_ipv4(6379));
460-
payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await
460+
payjoin_directory::listen_tcp_with_tls(
461+
format!("http://localhost:{}", port),
462+
port,
463+
db_host,
464+
timeout,
465+
local_cert_key,
466+
)
467+
.await
461468
}
462469

463470
// generates or gets a DER encoded localhost cert and key.

payjoin-directory/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ bitcoin = { version = "0.32.2", features = ["base64"] }
2222
bhttp = { version = "=0.5.1", features = ["http"] }
2323
futures = "0.3.17"
2424
http-body-util = "0.1.2"
25-
hyper = { version = "1" }
25+
hyper = { version = "1", features = ["http1", "server"] }
2626
hyper-rustls = { version = "0.26", optional = true }
2727
hyper-util = "0.1"
2828
ohttp = "0.5.1"

payjoin-directory/src/lib.rs

+53-10
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use bitcoin::base64::Engine;
88
use http_body_util::combinators::BoxBody;
99
use http_body_util::{BodyExt, Empty, Full};
1010
use hyper::body::{Body, Bytes, Incoming};
11-
use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE};
11+
use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE, LOCATION};
1212
use hyper::server::conn::http1;
1313
use hyper::service::service_fn;
1414
use hyper::{Method, Request, Response, StatusCode, Uri};
@@ -20,6 +20,7 @@ use tracing::{debug, error, info, trace};
2020
pub const DEFAULT_DIR_PORT: u16 = 8080;
2121
pub const DEFAULT_DB_HOST: &str = "localhost:6379";
2222
pub const DEFAULT_TIMEOUT_SECS: u64 = 30;
23+
pub const DEFAULT_BASE_URL: &str = "https://localhost";
2324

2425
const MAX_BUFFER_SIZE: usize = 65536;
2526

@@ -31,6 +32,7 @@ mod db;
3132
use crate::db::DbPool;
3233

3334
pub async fn listen_tcp(
35+
base_url: String,
3436
port: u16,
3537
db_host: String,
3638
timeout: Duration,
@@ -42,13 +44,14 @@ pub async fn listen_tcp(
4244
while let Ok((stream, _)) = listener.accept().await {
4345
let pool = pool.clone();
4446
let ohttp = ohttp.clone();
47+
let base_url = base_url.clone();
4548
let io = TokioIo::new(stream);
4649
tokio::spawn(async move {
4750
if let Err(err) = http1::Builder::new()
4851
.serve_connection(
4952
io,
5053
service_fn(move |req| {
51-
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
54+
serve_payjoin_directory(req, pool.clone(), ohttp.clone(), base_url.clone())
5255
}),
5356
)
5457
.with_upgrades()
@@ -64,6 +67,7 @@ pub async fn listen_tcp(
6467

6568
#[cfg(feature = "danger-local-https")]
6669
pub async fn listen_tcp_with_tls(
70+
base_url: String,
6771
port: u16,
6872
db_host: String,
6973
timeout: Duration,
@@ -77,6 +81,7 @@ pub async fn listen_tcp_with_tls(
7781
while let Ok((stream, _)) = listener.accept().await {
7882
let pool = pool.clone();
7983
let ohttp = ohttp.clone();
84+
let base_url = base_url.clone();
8085
let tls_acceptor = tls_acceptor.clone();
8186
tokio::spawn(async move {
8287
let tls_stream = match tls_acceptor.accept(stream).await {
@@ -90,7 +95,7 @@ pub async fn listen_tcp_with_tls(
9095
.serve_connection(
9196
TokioIo::new(tls_stream),
9297
service_fn(move |req| {
93-
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
98+
serve_payjoin_directory(req, pool.clone(), ohttp.clone(), base_url.clone())
9499
}),
95100
)
96101
.with_upgrades()
@@ -143,6 +148,7 @@ async fn serve_payjoin_directory(
143148
req: Request<Incoming>,
144149
pool: DbPool,
145150
ohttp: Arc<Mutex<ohttp::Server>>,
151+
base_url: String,
146152
) -> Result<Response<BoxBody<Bytes, hyper::Error>>> {
147153
let path = req.uri().path().to_string();
148154
let query = req.uri().query().unwrap_or_default().to_string();
@@ -151,7 +157,7 @@ async fn serve_payjoin_directory(
151157
let path_segments: Vec<&str> = path.split('/').collect();
152158
debug!("serve_payjoin_directory: {:?}", &path_segments);
153159
let mut response = match (parts.method, path_segments.as_slice()) {
154-
(Method::POST, ["", ""]) => handle_ohttp_gateway(body, pool, ohttp).await,
160+
(Method::POST, ["", ""]) => handle_ohttp_gateway(body, pool, ohttp, base_url).await,
155161
(Method::GET, ["", "ohttp-keys"]) => get_ohttp_keys(&ohttp).await,
156162
(Method::POST, ["", id]) => post_fallback_v1(id, query, body, pool).await,
157163
(Method::GET, ["", "health"]) => health_check().await,
@@ -169,6 +175,7 @@ async fn handle_ohttp_gateway(
169175
body: Incoming,
170176
pool: DbPool,
171177
ohttp: Arc<Mutex<ohttp::Server>>,
178+
base_url: String,
172179
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
173180
// decapsulate
174181
let ohttp_body =
@@ -194,10 +201,13 @@ async fn handle_ohttp_gateway(
194201
}
195202
let request = http_req.body(full(body))?;
196203

197-
let response = handle_v2(pool, request).await?;
204+
let response = handle_v2(pool, base_url, request).await?;
198205

199206
let (parts, body) = response.into_parts();
200207
let mut bhttp_res = bhttp::Message::response(parts.status.as_u16());
208+
for (name, value) in parts.headers.iter() {
209+
bhttp_res.put_header(name.as_str(), value.to_str().unwrap_or_default());
210+
}
201211
let full_body =
202212
body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes();
203213
bhttp_res.write_content(&full_body);
@@ -213,6 +223,7 @@ async fn handle_ohttp_gateway(
213223

214224
async fn handle_v2(
215225
pool: DbPool,
226+
base_url: String,
216227
req: Request<BoxBody<Bytes, hyper::Error>>,
217228
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
218229
let path = req.uri().path().to_string();
@@ -221,10 +232,10 @@ async fn handle_v2(
221232
let path_segments: Vec<&str> = path.split('/').collect();
222233
debug!("handle_v2: {:?}", &path_segments);
223234
match (parts.method, path_segments.as_slice()) {
224-
(Method::POST, &["", ""]) => post_session(body).await,
235+
(Method::POST, &["", ""]) => post_session(base_url, body).await,
225236
(Method::POST, &["", id]) => post_fallback_v2(id, body, pool).await,
226237
(Method::GET, &["", id]) => get_fallback(id, pool).await,
227-
(Method::POST, &["", id, "payjoin"]) => post_payjoin(id, body, pool).await,
238+
(Method::PUT, &["", id]) => post_payjoin(id, body, pool).await,
228239
_ => Ok(not_found()),
229240
}
230241
}
@@ -233,6 +244,7 @@ async fn health_check() -> Result<Response<BoxBody<Bytes, hyper::Error>>, Handle
233244
Ok(Response::new(empty()))
234245
}
235246

247+
#[derive(Debug)]
236248
enum HandlerError {
237249
PayloadTooLarge,
238250
InternalServerError(anyhow::Error),
@@ -273,6 +285,7 @@ impl From<hyper::http::Error> for HandlerError {
273285
}
274286

275287
async fn post_session(
288+
base_url: String,
276289
body: BoxBody<Bytes, hyper::Error>,
277290
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
278291
let bytes = body.collect().await.map_err(|e| HandlerError::BadRequest(e.into()))?.to_bytes();
@@ -283,9 +296,10 @@ async fn post_session(
283296
let pubkey = bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes)
284297
.map_err(|e| HandlerError::BadRequest(e.into()))?;
285298
tracing::info!("Initialized session with pubkey: {:?}", pubkey);
286-
let mut res = Response::new(empty());
287-
*res.status_mut() = StatusCode::NO_CONTENT;
288-
Ok(res)
299+
Ok(Response::builder()
300+
.header(LOCATION, format!("{}/{}", base_url, pubkey))
301+
.status(StatusCode::CREATED)
302+
.body(empty())?)
289303
}
290304

291305
async fn post_fallback_v1(
@@ -413,3 +427,32 @@ fn empty() -> BoxBody<Bytes, hyper::Error> {
413427
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
414428
Full::new(chunk.into()).map_err(|never| match never {}).boxed()
415429
}
430+
431+
#[cfg(test)]
432+
mod tests {
433+
use hyper::Request;
434+
435+
use super::*;
436+
437+
/// Ensure that the POST / endpoint returns a 201 Created with a Location header
438+
/// as is semantically correct when creating a resource.
439+
///
440+
/// https://datatracker.ietf.org/doc/html/rfc9110#name-post
441+
#[tokio::test]
442+
async fn test_post_session() -> Result<(), Box<dyn std::error::Error>> {
443+
let base_url = "https://localhost".to_string();
444+
let body = full("some_base64_encoded_pubkey");
445+
446+
let request = Request::builder().method(Method::POST).uri("/").body(body)?;
447+
448+
let response = post_session(base_url.clone(), request.into_body())
449+
.await
450+
.map_err(|e| format!("{:?}", e))?;
451+
452+
assert_eq!(response.status(), StatusCode::CREATED);
453+
assert!(response.headers().contains_key(LOCATION));
454+
let location_header = response.headers().get(LOCATION).ok_or("Missing LOCATION header")?;
455+
assert!(location_header.to_str()?.starts_with(&base_url));
456+
Ok(())
457+
}
458+
}

payjoin-directory/src/main.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
1717

1818
let db_host = env::var("PJ_DB_HOST").unwrap_or_else(|_| DEFAULT_DB_HOST.to_string());
1919

20-
payjoin_directory::listen_tcp(dir_port, db_host, timeout).await
20+
let base_url = env::var("PJ_DIR_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
21+
22+
payjoin_directory::listen_tcp(base_url, dir_port, db_host, timeout).await
2123
}
2224

2325
fn init_logging() {

0 commit comments

Comments
 (0)