Skip to content

Commit 2f85ae9

Browse files
committedJan 26, 2024
feat: use tower service as impl of authenticate layer
1 parent 6cbb540 commit 2f85ae9

File tree

9 files changed

+120
-48
lines changed

9 files changed

+120
-48
lines changed
 

‎Cargo.lock

+3-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ anyhow = { version = "1" }
1313
clap = { version = "4.4" }
1414
chrono = { version = "0.4.31" }
1515
feed-rs = { version = "1.4" }
16-
futures = { version = "0.3" }
16+
futures-util = { version = "0.3.30" }
1717
graphql_client = { version = "0.13.0", default-features = false }
1818
moka = { version = "0.12.4", features = ["future"] }
1919
reqwest = { version = "0.11.23", default-features = false, features = [

‎synd/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ anyhow = { workspace = true }
1212
async-trait = { workspace = true }
1313
chrono = { workspace = true }
1414
feed-rs = { workspace = true }
15-
futures = { workspace = true }
15+
futures-util = { workspace = true }
1616
moka = { workspace = true, features = ["future"] }
1717
reqwest = { workspace = true, features = ["stream"] }
1818
thiserror = "1.0.56"

‎synd/src/feed/parser.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ pub struct FeedService {
4949
#[async_trait]
5050
impl FetchFeed for FeedService {
5151
async fn fetch_feed(&self, url: String) -> ParseResult<Feed> {
52-
use futures::StreamExt;
52+
use futures_util::StreamExt;
5353
let mut stream = self
5454
.http
5555
.get(&url)

‎syndapi/Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ axum = { version = "0.7.4" }
1616
clap = { workspace = true, features = ["derive"] }
1717
chrono = { workspace = true }
1818
feed-rs = { workspace = true }
19+
futures-util = { workspace = true }
1920
graphql_client = { workspace = true }
2021
kvsd = { version = "0.1.2" }
22+
moka = { workspace = true, features = ["future"]}
2123
reqwest = { workspace = true }
2224
serde = { workspace = true }
2325
serde_json = "1.0.111"

‎syndapi/src/serve/auth.rs

+23-38
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
use axum::{
2-
extract::{Request, State},
3-
http::{self, StatusCode},
4-
middleware::Next,
5-
response::Response,
6-
};
1+
use std::time::Duration;
2+
3+
use moka::future::Cache;
74
use tracing::warn;
85

96
use crate::{
@@ -14,60 +11,48 @@ use crate::{
1411
#[derive(Clone)]
1512
pub struct Authenticator {
1613
github: GithubClient,
14+
cache: Cache<String, Principal>,
1715
}
1816

1917
impl Authenticator {
2018
pub fn new() -> anyhow::Result<Self> {
19+
let cache = Cache::builder()
20+
.max_capacity(1024 * 1024)
21+
.time_to_live(Duration::from_secs(60 * 60))
22+
.build();
23+
2124
Ok(Self {
2225
github: GithubClient::new()?,
26+
cache,
2327
})
2428
}
2529

2630
/// Authenticate from given token
27-
async fn authenticate(&self, token: &str) -> Result<Principal, ()> {
31+
pub async fn authenticate(&self, token: impl AsRef<str>) -> Result<Principal, ()> {
32+
let token = token.as_ref();
2833
let mut split = token.splitn(2, ' ');
2934
match (split.next(), split.next()) {
3035
(Some("github"), Some(access_token)) => {
31-
// TODO: configure cache to reduce api call
36+
if let Some(principal) = self.cache.get(token).await {
37+
tracing::info!("Principal cache hit");
38+
return Ok(principal);
39+
}
3240

3341
match self.github.authenticate(access_token).await {
34-
Ok(email) => Ok(Principal::User(User::from_email(email))),
42+
Ok(email) => {
43+
let principal = Principal::User(User::from_email(email));
44+
45+
self.cache.insert(token.to_owned(), principal.clone()).await;
46+
47+
Ok(principal)
48+
}
3549
Err(err) => {
3650
warn!("Failed to authenticate github {err}");
3751
Err(())
3852
}
3953
}
4054
}
41-
// TODO: remove
42-
(Some("me"), None) => Ok(Principal::User(User::from_email("me@ymgyt.io"))),
4355
_ => Err(()),
4456
}
4557
}
4658
}
47-
48-
/// Check authorization header and inject Authentication
49-
pub async fn authenticate(
50-
State(authenticator): State<Authenticator>,
51-
mut req: Request,
52-
next: Next,
53-
) -> Result<Response, StatusCode> {
54-
let header = req
55-
.headers()
56-
.get(http::header::AUTHORIZATION)
57-
.and_then(|header| header.to_str().ok());
58-
59-
let Some(token) = header else {
60-
return Err(StatusCode::UNAUTHORIZED);
61-
};
62-
let principal = match authenticator.authenticate(token).await {
63-
Ok(principal) => principal,
64-
Err(_) => {
65-
warn!("Invalid token");
66-
return Err(StatusCode::UNAUTHORIZED);
67-
}
68-
};
69-
70-
req.extensions_mut().insert(principal);
71-
72-
Ok(next.run(req).await)
73-
}
+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
use std::task::{Context, Poll};
2+
3+
use axum::http::{self, StatusCode};
4+
use axum::response::IntoResponse;
5+
use futures_util::future::BoxFuture;
6+
use tower::{Layer, Service};
7+
8+
use crate::serve::auth::Authenticator;
9+
10+
#[derive(Clone)]
11+
pub struct AuthenticateLayer {
12+
authenticator: Authenticator,
13+
}
14+
15+
impl AuthenticateLayer {
16+
pub fn new(authenticator: Authenticator) -> Self {
17+
Self { authenticator }
18+
}
19+
}
20+
21+
impl<S> Layer<S> for AuthenticateLayer {
22+
type Service = AuthenticateService<S>;
23+
24+
fn layer(&self, inner: S) -> Self::Service {
25+
AuthenticateService {
26+
inner,
27+
authenticator: self.authenticator.clone(),
28+
}
29+
}
30+
}
31+
32+
#[derive(Clone)]
33+
pub struct AuthenticateService<S> {
34+
authenticator: Authenticator,
35+
inner: S,
36+
}
37+
38+
impl<S> Service<axum::extract::Request> for AuthenticateService<S>
39+
where
40+
S: Service<axum::extract::Request, Response = axum::response::Response>
41+
+ Send
42+
+ 'static
43+
+ Clone,
44+
S::Future: Send + 'static,
45+
{
46+
type Response = S::Response;
47+
48+
type Error = S::Error;
49+
50+
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
51+
52+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
53+
self.inner.poll_ready(cx)
54+
}
55+
56+
fn call(&mut self, mut request: axum::extract::Request) -> Self::Future {
57+
let header = request
58+
.headers()
59+
.get(http::header::AUTHORIZATION)
60+
.and_then(|header| header.to_str().ok());
61+
62+
let Some(token) = header else {
63+
return Box::pin(async { Ok(StatusCode::UNAUTHORIZED.into_response()) });
64+
};
65+
66+
let Self {
67+
authenticator,
68+
mut inner,
69+
} = self.clone();
70+
let token = token.to_owned();
71+
72+
Box::pin(async move {
73+
let principal = match authenticator.authenticate(token).await {
74+
Ok(principal) => principal,
75+
Err(_) => {
76+
tracing::warn!("Invalid token");
77+
return Ok(StatusCode::UNAUTHORIZED.into_response());
78+
}
79+
};
80+
81+
request.extensions_mut().insert(principal);
82+
83+
inner.call(request).await
84+
})
85+
}
86+
}

‎syndapi/src/serve/layer/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pub mod audit;
2+
pub mod authenticate;
23
pub mod trace;

‎syndapi/src/serve/mod.rs

+2-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use async_graphql::{extensions::Tracing, EmptySubscription, Schema};
44
use axum::{
55
error_handling::HandleErrorLayer,
66
http::{header::AUTHORIZATION, StatusCode},
7-
middleware,
87
routing::{get, post},
98
BoxError, Extension, Router,
109
};
@@ -19,7 +18,7 @@ use crate::{
1918
config,
2019
dependency::Dependency,
2120
gql::{self, Mutation, Query},
22-
serve::layer::trace,
21+
serve::layer::{authenticate, trace},
2322
};
2423

2524
pub mod auth;
@@ -53,11 +52,8 @@ pub async fn serve(listener: TcpListener, dep: Dependency) -> anyhow::Result<()>
5352
let service = Router::new()
5453
.route("/graphql", post(gql::handler::graphql))
5554
.layer(Extension(schema))
56-
.route_layer(middleware::from_fn_with_state(
57-
authenticator,
58-
auth::authenticate,
59-
))
6055
.route("/graphql", get(gql::handler::graphiql))
56+
.layer(authenticate::AuthenticateLayer::new(authenticator))
6157
.layer(
6258
ServiceBuilder::new()
6359
.layer(SetSensitiveHeadersLayer::new(std::iter::once(

0 commit comments

Comments
 (0)