Skip to content

Commit 1635bcc

Browse files
authored
feat: add {http1,http2}_only for auto conn (#111)
1 parent 4b24573 commit 1635bcc

File tree

1 file changed

+116
-12
lines changed

1 file changed

+116
-12
lines changed

src/server/conn/auto.rs

+116-12
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ pub struct Builder<E> {
5858
http1: http1::Builder,
5959
#[cfg(feature = "http2")]
6060
http2: http2::Builder<E>,
61+
#[cfg(any(feature = "http1", feature = "http2"))]
62+
version: Option<Version>,
6163
#[cfg(not(feature = "http2"))]
6264
_executor: E,
6365
}
@@ -84,6 +86,8 @@ impl<E> Builder<E> {
8486
http1: http1::Builder::new(),
8587
#[cfg(feature = "http2")]
8688
http2: http2::Builder::new(executor),
89+
#[cfg(any(feature = "http1", feature = "http2"))]
90+
version: None,
8791
#[cfg(not(feature = "http2"))]
8892
_executor: executor,
8993
}
@@ -101,6 +105,26 @@ impl<E> Builder<E> {
101105
Http2Builder { inner: self }
102106
}
103107

108+
/// Only accepts HTTP/2
109+
///
110+
/// Does not do anything if used with [`serve_connection_with_upgrades`]
111+
#[cfg(feature = "http2")]
112+
pub fn http2_only(mut self) -> Self {
113+
assert!(self.version.is_none());
114+
self.version = Some(Version::H2);
115+
self
116+
}
117+
118+
/// Only accepts HTTP/1
119+
///
120+
/// Does not do anything if used with [`serve_connection_with_upgrades`]
121+
#[cfg(feature = "http1")]
122+
pub fn http1_only(mut self) -> Self {
123+
assert!(self.version.is_none());
124+
self.version = Some(Version::H1);
125+
self
126+
}
127+
104128
/// Bind a connection together with a [`Service`].
105129
pub fn serve_connection<I, S, B>(&self, io: I, service: S) -> Connection<'_, I, S, E>
106130
where
@@ -112,13 +136,28 @@ impl<E> Builder<E> {
112136
I: Read + Write + Unpin + 'static,
113137
E: HttpServerConnExec<S::Future, B>,
114138
{
115-
Connection {
116-
state: ConnState::ReadVersion {
139+
let state = match self.version {
140+
#[cfg(feature = "http1")]
141+
Some(Version::H1) => {
142+
let io = Rewind::new_buffered(io, Bytes::new());
143+
let conn = self.http1.serve_connection(io, service);
144+
ConnState::H1 { conn }
145+
}
146+
#[cfg(feature = "http2")]
147+
Some(Version::H2) => {
148+
let io = Rewind::new_buffered(io, Bytes::new());
149+
let conn = self.http2.serve_connection(io, service);
150+
ConnState::H2 { conn }
151+
}
152+
#[cfg(any(feature = "http1", feature = "http2"))]
153+
_ => ConnState::ReadVersion {
117154
read_version: read_version(io),
118155
builder: self,
119156
service: Some(service),
120157
},
121-
}
158+
};
159+
160+
Connection { state }
122161
}
123162

124163
/// Bind a connection together with a [`Service`], with the ability to
@@ -148,7 +187,7 @@ impl<E> Builder<E> {
148187
}
149188
}
150189

151-
#[derive(Copy, Clone)]
190+
#[derive(Copy, Clone, Debug)]
152191
enum Version {
153192
H1,
154193
H2,
@@ -906,7 +945,7 @@ mod tests {
906945
#[cfg(not(miri))]
907946
#[tokio::test]
908947
async fn http1() {
909-
let addr = start_server().await;
948+
let addr = start_server(false, false).await;
910949
let mut sender = connect_h1(addr).await;
911950

912951
let response = sender
@@ -922,7 +961,23 @@ mod tests {
922961
#[cfg(not(miri))]
923962
#[tokio::test]
924963
async fn http2() {
925-
let addr = start_server().await;
964+
let addr = start_server(false, false).await;
965+
let mut sender = connect_h2(addr).await;
966+
967+
let response = sender
968+
.send_request(Request::new(Empty::<Bytes>::new()))
969+
.await
970+
.unwrap();
971+
972+
let body = response.into_body().collect().await.unwrap().to_bytes();
973+
974+
assert_eq!(body, BODY);
975+
}
976+
977+
#[cfg(not(miri))]
978+
#[tokio::test]
979+
async fn http2_only() {
980+
let addr = start_server(false, true).await;
926981
let mut sender = connect_h2(addr).await;
927982

928983
let response = sender
@@ -935,6 +990,46 @@ mod tests {
935990
assert_eq!(body, BODY);
936991
}
937992

993+
#[cfg(not(miri))]
994+
#[tokio::test]
995+
async fn http2_only_fail_if_client_is_http1() {
996+
let addr = start_server(false, true).await;
997+
let mut sender = connect_h1(addr).await;
998+
999+
let _ = sender
1000+
.send_request(Request::new(Empty::<Bytes>::new()))
1001+
.await
1002+
.expect_err("should fail");
1003+
}
1004+
1005+
#[cfg(not(miri))]
1006+
#[tokio::test]
1007+
async fn http1_only() {
1008+
let addr = start_server(true, false).await;
1009+
let mut sender = connect_h1(addr).await;
1010+
1011+
let response = sender
1012+
.send_request(Request::new(Empty::<Bytes>::new()))
1013+
.await
1014+
.unwrap();
1015+
1016+
let body = response.into_body().collect().await.unwrap().to_bytes();
1017+
1018+
assert_eq!(body, BODY);
1019+
}
1020+
1021+
#[cfg(not(miri))]
1022+
#[tokio::test]
1023+
async fn http1_only_fail_if_client_is_http2() {
1024+
let addr = start_server(true, false).await;
1025+
let mut sender = connect_h2(addr).await;
1026+
1027+
let _ = sender
1028+
.send_request(Request::new(Empty::<Bytes>::new()))
1029+
.await
1030+
.expect_err("should fail");
1031+
}
1032+
9381033
#[cfg(not(miri))]
9391034
#[tokio::test]
9401035
async fn graceful_shutdown() {
@@ -1000,7 +1095,7 @@ mod tests {
10001095
sender
10011096
}
10021097

1003-
async fn start_server() -> SocketAddr {
1098+
async fn start_server(h1_only: bool, h2_only: bool) -> SocketAddr {
10041099
let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
10051100
let listener = TcpListener::bind(addr).await.unwrap();
10061101

@@ -1011,11 +1106,20 @@ mod tests {
10111106
let (stream, _) = listener.accept().await.unwrap();
10121107
let stream = TokioIo::new(stream);
10131108
tokio::task::spawn(async move {
1014-
let _ = auto::Builder::new(TokioExecutor::new())
1015-
.http2()
1016-
.max_header_list_size(4096)
1017-
.serve_connection_with_upgrades(stream, service_fn(hello))
1018-
.await;
1109+
let mut builder = auto::Builder::new(TokioExecutor::new());
1110+
if h1_only {
1111+
builder = builder.http1_only();
1112+
builder.serve_connection(stream, service_fn(hello)).await;
1113+
} else if h2_only {
1114+
builder = builder.http2_only();
1115+
builder.serve_connection(stream, service_fn(hello)).await;
1116+
} else {
1117+
builder
1118+
.http2()
1119+
.max_header_list_size(4096)
1120+
.serve_connection_with_upgrades(stream, service_fn(hello))
1121+
.await;
1122+
}
10191123
});
10201124
}
10211125
});

0 commit comments

Comments
 (0)