@@ -8,7 +8,7 @@ use bitcoin::base64::Engine;
8
8
use http_body_util:: combinators:: BoxBody ;
9
9
use http_body_util:: { BodyExt , Empty , Full } ;
10
10
use hyper:: body:: { Body , Bytes , Incoming } ;
11
- use hyper:: header:: { HeaderValue , ACCESS_CONTROL_ALLOW_ORIGIN , CONTENT_TYPE } ;
11
+ use hyper:: header:: { HeaderValue , ACCESS_CONTROL_ALLOW_ORIGIN , CONTENT_TYPE , LOCATION } ;
12
12
use hyper:: server:: conn:: http1;
13
13
use hyper:: service:: service_fn;
14
14
use hyper:: { Method , Request , Response , StatusCode , Uri } ;
@@ -20,6 +20,7 @@ use tracing::{debug, error, info, trace};
20
20
pub const DEFAULT_DIR_PORT : u16 = 8080 ;
21
21
pub const DEFAULT_DB_HOST : & str = "localhost:6379" ;
22
22
pub const DEFAULT_TIMEOUT_SECS : u64 = 30 ;
23
+ pub const DEFAULT_BASE_URL : & str = "https://localhost" ;
23
24
24
25
const MAX_BUFFER_SIZE : usize = 65536 ;
25
26
@@ -31,6 +32,7 @@ mod db;
31
32
use crate :: db:: DbPool ;
32
33
33
34
pub async fn listen_tcp (
35
+ base_url : String ,
34
36
port : u16 ,
35
37
db_host : String ,
36
38
timeout : Duration ,
@@ -42,13 +44,14 @@ pub async fn listen_tcp(
42
44
while let Ok ( ( stream, _) ) = listener. accept ( ) . await {
43
45
let pool = pool. clone ( ) ;
44
46
let ohttp = ohttp. clone ( ) ;
47
+ let base_url = base_url. clone ( ) ;
45
48
let io = TokioIo :: new ( stream) ;
46
49
tokio:: spawn ( async move {
47
50
if let Err ( err) = http1:: Builder :: new ( )
48
51
. serve_connection (
49
52
io,
50
53
service_fn ( move |req| {
51
- serve_payjoin_directory ( req, pool. clone ( ) , ohttp. clone ( ) )
54
+ serve_payjoin_directory ( req, pool. clone ( ) , ohttp. clone ( ) , base_url . clone ( ) )
52
55
} ) ,
53
56
)
54
57
. with_upgrades ( )
@@ -64,6 +67,7 @@ pub async fn listen_tcp(
64
67
65
68
#[ cfg( feature = "danger-local-https" ) ]
66
69
pub async fn listen_tcp_with_tls (
70
+ base_url : String ,
67
71
port : u16 ,
68
72
db_host : String ,
69
73
timeout : Duration ,
@@ -77,6 +81,7 @@ pub async fn listen_tcp_with_tls(
77
81
while let Ok ( ( stream, _) ) = listener. accept ( ) . await {
78
82
let pool = pool. clone ( ) ;
79
83
let ohttp = ohttp. clone ( ) ;
84
+ let base_url = base_url. clone ( ) ;
80
85
let tls_acceptor = tls_acceptor. clone ( ) ;
81
86
tokio:: spawn ( async move {
82
87
let tls_stream = match tls_acceptor. accept ( stream) . await {
@@ -90,7 +95,7 @@ pub async fn listen_tcp_with_tls(
90
95
. serve_connection (
91
96
TokioIo :: new ( tls_stream) ,
92
97
service_fn ( move |req| {
93
- serve_payjoin_directory ( req, pool. clone ( ) , ohttp. clone ( ) )
98
+ serve_payjoin_directory ( req, pool. clone ( ) , ohttp. clone ( ) , base_url . clone ( ) )
94
99
} ) ,
95
100
)
96
101
. with_upgrades ( )
@@ -143,6 +148,7 @@ async fn serve_payjoin_directory(
143
148
req : Request < Incoming > ,
144
149
pool : DbPool ,
145
150
ohttp : Arc < Mutex < ohttp:: Server > > ,
151
+ base_url : String ,
146
152
) -> Result < Response < BoxBody < Bytes , hyper:: Error > > > {
147
153
let path = req. uri ( ) . path ( ) . to_string ( ) ;
148
154
let query = req. uri ( ) . query ( ) . unwrap_or_default ( ) . to_string ( ) ;
@@ -151,7 +157,7 @@ async fn serve_payjoin_directory(
151
157
let path_segments: Vec < & str > = path. split ( '/' ) . collect ( ) ;
152
158
debug ! ( "serve_payjoin_directory: {:?}" , & path_segments) ;
153
159
let mut response = match ( parts. method , path_segments. as_slice ( ) ) {
154
- ( Method :: POST , [ "" , "" ] ) => handle_ohttp_gateway ( body, pool, ohttp) . await ,
160
+ ( Method :: POST , [ "" , "" ] ) => handle_ohttp_gateway ( body, pool, ohttp, base_url ) . await ,
155
161
( Method :: GET , [ "" , "ohttp-keys" ] ) => get_ohttp_keys ( & ohttp) . await ,
156
162
( Method :: POST , [ "" , id] ) => post_fallback_v1 ( id, query, body, pool) . await ,
157
163
( Method :: GET , [ "" , "health" ] ) => health_check ( ) . await ,
@@ -169,6 +175,7 @@ async fn handle_ohttp_gateway(
169
175
body : Incoming ,
170
176
pool : DbPool ,
171
177
ohttp : Arc < Mutex < ohttp:: Server > > ,
178
+ base_url : String ,
172
179
) -> Result < Response < BoxBody < Bytes , hyper:: Error > > , HandlerError > {
173
180
// decapsulate
174
181
let ohttp_body =
@@ -194,10 +201,13 @@ async fn handle_ohttp_gateway(
194
201
}
195
202
let request = http_req. body ( full ( body) ) ?;
196
203
197
- let response = handle_v2 ( pool, request) . await ?;
204
+ let response = handle_v2 ( pool, base_url , request) . await ?;
198
205
199
206
let ( parts, body) = response. into_parts ( ) ;
200
207
let mut bhttp_res = bhttp:: Message :: response ( parts. status . as_u16 ( ) ) ;
208
+ for ( name, value) in parts. headers . iter ( ) {
209
+ bhttp_res. put_header ( name. as_str ( ) , value. to_str ( ) . unwrap_or_default ( ) ) ;
210
+ }
201
211
let full_body =
202
212
body. collect ( ) . await . map_err ( |e| HandlerError :: InternalServerError ( e. into ( ) ) ) ?. to_bytes ( ) ;
203
213
bhttp_res. write_content ( & full_body) ;
@@ -213,6 +223,7 @@ async fn handle_ohttp_gateway(
213
223
214
224
async fn handle_v2 (
215
225
pool : DbPool ,
226
+ base_url : String ,
216
227
req : Request < BoxBody < Bytes , hyper:: Error > > ,
217
228
) -> Result < Response < BoxBody < Bytes , hyper:: Error > > , HandlerError > {
218
229
let path = req. uri ( ) . path ( ) . to_string ( ) ;
@@ -221,10 +232,10 @@ async fn handle_v2(
221
232
let path_segments: Vec < & str > = path. split ( '/' ) . collect ( ) ;
222
233
debug ! ( "handle_v2: {:?}" , & path_segments) ;
223
234
match ( parts. method , path_segments. as_slice ( ) ) {
224
- ( Method :: POST , & [ "" , "" ] ) => post_session ( body) . await ,
235
+ ( Method :: POST , & [ "" , "" ] ) => post_session ( base_url , body) . await ,
225
236
( Method :: POST , & [ "" , id] ) => post_fallback_v2 ( id, body, pool) . await ,
226
237
( Method :: GET , & [ "" , id] ) => get_fallback ( id, pool) . await ,
227
- ( Method :: POST , & [ "" , id, "payjoin" ] ) => post_payjoin ( id, body, pool) . await ,
238
+ ( Method :: PUT , & [ "" , id] ) => post_payjoin ( id, body, pool) . await ,
228
239
_ => Ok ( not_found ( ) ) ,
229
240
}
230
241
}
@@ -233,6 +244,7 @@ async fn health_check() -> Result<Response<BoxBody<Bytes, hyper::Error>>, Handle
233
244
Ok ( Response :: new ( empty ( ) ) )
234
245
}
235
246
247
+ #[ derive( Debug ) ]
236
248
enum HandlerError {
237
249
PayloadTooLarge ,
238
250
InternalServerError ( anyhow:: Error ) ,
@@ -273,6 +285,7 @@ impl From<hyper::http::Error> for HandlerError {
273
285
}
274
286
275
287
async fn post_session (
288
+ base_url : String ,
276
289
body : BoxBody < Bytes , hyper:: Error > ,
277
290
) -> Result < Response < BoxBody < Bytes , hyper:: Error > > , HandlerError > {
278
291
let bytes = body. collect ( ) . await . map_err ( |e| HandlerError :: BadRequest ( e. into ( ) ) ) ?. to_bytes ( ) ;
@@ -283,9 +296,10 @@ async fn post_session(
283
296
let pubkey = bitcoin:: secp256k1:: PublicKey :: from_slice ( & pubkey_bytes)
284
297
. map_err ( |e| HandlerError :: BadRequest ( e. into ( ) ) ) ?;
285
298
tracing:: info!( "Initialized session with pubkey: {:?}" , pubkey) ;
286
- let mut res = Response :: new ( empty ( ) ) ;
287
- * res. status_mut ( ) = StatusCode :: NO_CONTENT ;
288
- Ok ( res)
299
+ Ok ( Response :: builder ( )
300
+ . header ( LOCATION , format ! ( "{}/{}" , base_url, pubkey) )
301
+ . status ( StatusCode :: CREATED )
302
+ . body ( empty ( ) ) ?)
289
303
}
290
304
291
305
async fn post_fallback_v1 (
@@ -413,3 +427,32 @@ fn empty() -> BoxBody<Bytes, hyper::Error> {
413
427
fn full < T : Into < Bytes > > ( chunk : T ) -> BoxBody < Bytes , hyper:: Error > {
414
428
Full :: new ( chunk. into ( ) ) . map_err ( |never| match never { } ) . boxed ( )
415
429
}
430
+
431
+ #[ cfg( test) ]
432
+ mod tests {
433
+ use hyper:: Request ;
434
+
435
+ use super :: * ;
436
+
437
+ /// Ensure that the POST / endpoint returns a 201 Created with a Location header
438
+ /// as is semantically correct when creating a resource.
439
+ ///
440
+ /// https://datatracker.ietf.org/doc/html/rfc9110#name-post
441
+ #[ tokio:: test]
442
+ async fn test_post_session ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
443
+ let base_url = "https://localhost" . to_string ( ) ;
444
+ let body = full ( "some_base64_encoded_pubkey" ) ;
445
+
446
+ let request = Request :: builder ( ) . method ( Method :: POST ) . uri ( "/" ) . body ( body) ?;
447
+
448
+ let response = post_session ( base_url. clone ( ) , request. into_body ( ) )
449
+ . await
450
+ . map_err ( |e| format ! ( "{:?}" , e) ) ?;
451
+
452
+ assert_eq ! ( response. status( ) , StatusCode :: CREATED ) ;
453
+ assert ! ( response. headers( ) . contains_key( LOCATION ) ) ;
454
+ let location_header = response. headers ( ) . get ( LOCATION ) . ok_or ( "Missing LOCATION header" ) ?;
455
+ assert ! ( location_header. to_str( ) ?. starts_with( & base_url) ) ;
456
+ Ok ( ( ) )
457
+ }
458
+ }
0 commit comments