1
1
use anyhow:: { bail, Context } ;
2
2
use duration_str:: deserialize_duration;
3
- use serde:: Deserialize ;
3
+ use serde:: { Deserialize , Deserializer } ;
4
4
use small_acme:: LetsEncrypt ;
5
5
use std:: time:: Duration ;
6
6
use std:: { env, fs} ;
7
+ use std:: collections:: HashSet ;
8
+ use headers:: { HeaderValue , Origin } ;
7
9
use tracing:: warn;
8
10
9
11
const CONFIG_PATH : & str = "/config/config.toml" ;
@@ -12,7 +14,7 @@ const CONFIG_PATH: &str = "/config/config.toml";
12
14
pub struct Config {
13
15
pub file_dir : String ,
14
16
#[ serde( default ) ]
15
- pub cors : bool ,
17
+ pub cors : Option < HashSet < OriginWrapper > > ,
16
18
pub admin_config : Option < AdminConfig > ,
17
19
pub http : Option < HttpConfig > ,
18
20
pub https : Option < HttpsConfig > ,
@@ -94,7 +96,7 @@ fn default_max_upload_size() -> u64 {
94
96
#[ derive( Deserialize , Debug , Clone , PartialEq ) ]
95
97
pub struct DomainConfig {
96
98
pub domain : String ,
97
- pub cors : Option < bool > ,
99
+ pub cors : Option < HashSet < OriginWrapper > > ,
98
100
pub cache : Option < DomainCacheConfig > ,
99
101
pub https : Option < DomainHttpsConfig > ,
100
102
pub alias : Option < Vec < String > > ,
@@ -226,6 +228,29 @@ pub fn get_host_path_from_domain(domain: &str) -> (&str, &str) {
226
228
}
227
229
}
228
230
231
+ #[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
232
+ pub struct OriginWrapper ( HeaderValue ) ;
233
+
234
+ pub ( crate ) fn extract_origin ( data : & Option < HashSet < OriginWrapper > > ) -> Option < HashSet < HeaderValue > > {
235
+ data. as_ref ( ) . map ( |set| set. iter ( ) . map ( |o| o. 0 . clone ( ) ) . collect ( ) )
236
+ }
237
+
238
+ impl < ' de > Deserialize < ' de > for OriginWrapper {
239
+ fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
240
+ where
241
+ D : Deserializer < ' de >
242
+ {
243
+ let data = String :: deserialize ( deserializer) ?;
244
+ let mut parts = data. splitn ( 2 , "://" ) ;
245
+ let scheme = parts. next ( ) . expect ( "missing scheme" ) ;
246
+ let rest = parts. next ( ) . expect ( "missing scheme" ) ;
247
+ let origin = Origin :: try_from_parts ( scheme, rest, None ) . expect ( "invalid Origin" ) ;
248
+
249
+ Ok ( OriginWrapper ( origin. to_string ( ) . parse ( )
250
+ . expect ( "Origin is always a valid HeaderValue" ) ) )
251
+ }
252
+ }
253
+
229
254
#[ cfg( test) ]
230
255
mod test {
231
256
use std:: env;
0 commit comments