@@ -58,6 +58,8 @@ pub struct Builder<E> {
58
58
http1 : http1:: Builder ,
59
59
#[ cfg( feature = "http2" ) ]
60
60
http2 : http2:: Builder < E > ,
61
+ #[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
62
+ version : Option < Version > ,
61
63
#[ cfg( not( feature = "http2" ) ) ]
62
64
_executor : E ,
63
65
}
@@ -84,6 +86,8 @@ impl<E> Builder<E> {
84
86
http1 : http1:: Builder :: new ( ) ,
85
87
#[ cfg( feature = "http2" ) ]
86
88
http2 : http2:: Builder :: new ( executor) ,
89
+ #[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
90
+ version : None ,
87
91
#[ cfg( not( feature = "http2" ) ) ]
88
92
_executor : executor,
89
93
}
@@ -101,6 +105,26 @@ impl<E> Builder<E> {
101
105
Http2Builder { inner : self }
102
106
}
103
107
108
+ /// Only accepts HTTP/2
109
+ ///
110
+ /// Does not do anything if used with [`serve_connection_with_upgrades`]
111
+ #[ cfg( feature = "http2" ) ]
112
+ pub fn http2_only ( mut self ) -> Self {
113
+ assert ! ( self . version. is_none( ) ) ;
114
+ self . version = Some ( Version :: H2 ) ;
115
+ self
116
+ }
117
+
118
+ /// Only accepts HTTP/1
119
+ ///
120
+ /// Does not do anything if used with [`serve_connection_with_upgrades`]
121
+ #[ cfg( feature = "http1" ) ]
122
+ pub fn http1_only ( mut self ) -> Self {
123
+ assert ! ( self . version. is_none( ) ) ;
124
+ self . version = Some ( Version :: H1 ) ;
125
+ self
126
+ }
127
+
104
128
/// Bind a connection together with a [`Service`].
105
129
pub fn serve_connection < I , S , B > ( & self , io : I , service : S ) -> Connection < ' _ , I , S , E >
106
130
where
@@ -112,13 +136,28 @@ impl<E> Builder<E> {
112
136
I : Read + Write + Unpin + ' static ,
113
137
E : HttpServerConnExec < S :: Future , B > ,
114
138
{
115
- Connection {
116
- state : ConnState :: ReadVersion {
139
+ let state = match self . version {
140
+ #[ cfg( feature = "http1" ) ]
141
+ Some ( Version :: H1 ) => {
142
+ let io = Rewind :: new_buffered ( io, Bytes :: new ( ) ) ;
143
+ let conn = self . http1 . serve_connection ( io, service) ;
144
+ ConnState :: H1 { conn }
145
+ }
146
+ #[ cfg( feature = "http2" ) ]
147
+ Some ( Version :: H2 ) => {
148
+ let io = Rewind :: new_buffered ( io, Bytes :: new ( ) ) ;
149
+ let conn = self . http2 . serve_connection ( io, service) ;
150
+ ConnState :: H2 { conn }
151
+ }
152
+ #[ cfg( any( feature = "http1" , feature = "http2" ) ) ]
153
+ _ => ConnState :: ReadVersion {
117
154
read_version : read_version ( io) ,
118
155
builder : self ,
119
156
service : Some ( service) ,
120
157
} ,
121
- }
158
+ } ;
159
+
160
+ Connection { state }
122
161
}
123
162
124
163
/// Bind a connection together with a [`Service`], with the ability to
@@ -148,7 +187,7 @@ impl<E> Builder<E> {
148
187
}
149
188
}
150
189
151
- #[ derive( Copy , Clone ) ]
190
+ #[ derive( Copy , Clone , Debug ) ]
152
191
enum Version {
153
192
H1 ,
154
193
H2 ,
@@ -906,7 +945,7 @@ mod tests {
906
945
#[ cfg( not( miri) ) ]
907
946
#[ tokio:: test]
908
947
async fn http1 ( ) {
909
- let addr = start_server ( ) . await ;
948
+ let addr = start_server ( false , false ) . await ;
910
949
let mut sender = connect_h1 ( addr) . await ;
911
950
912
951
let response = sender
@@ -922,7 +961,23 @@ mod tests {
922
961
#[ cfg( not( miri) ) ]
923
962
#[ tokio:: test]
924
963
async fn http2 ( ) {
925
- let addr = start_server ( ) . await ;
964
+ let addr = start_server ( false , false ) . await ;
965
+ let mut sender = connect_h2 ( addr) . await ;
966
+
967
+ let response = sender
968
+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
969
+ . await
970
+ . unwrap ( ) ;
971
+
972
+ let body = response. into_body ( ) . collect ( ) . await . unwrap ( ) . to_bytes ( ) ;
973
+
974
+ assert_eq ! ( body, BODY ) ;
975
+ }
976
+
977
+ #[ cfg( not( miri) ) ]
978
+ #[ tokio:: test]
979
+ async fn http2_only ( ) {
980
+ let addr = start_server ( false , true ) . await ;
926
981
let mut sender = connect_h2 ( addr) . await ;
927
982
928
983
let response = sender
@@ -935,6 +990,46 @@ mod tests {
935
990
assert_eq ! ( body, BODY ) ;
936
991
}
937
992
993
+ #[ cfg( not( miri) ) ]
994
+ #[ tokio:: test]
995
+ async fn http2_only_fail_if_client_is_http1 ( ) {
996
+ let addr = start_server ( false , true ) . await ;
997
+ let mut sender = connect_h1 ( addr) . await ;
998
+
999
+ let _ = sender
1000
+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
1001
+ . await
1002
+ . expect_err ( "should fail" ) ;
1003
+ }
1004
+
1005
+ #[ cfg( not( miri) ) ]
1006
+ #[ tokio:: test]
1007
+ async fn http1_only ( ) {
1008
+ let addr = start_server ( true , false ) . await ;
1009
+ let mut sender = connect_h1 ( addr) . await ;
1010
+
1011
+ let response = sender
1012
+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
1013
+ . await
1014
+ . unwrap ( ) ;
1015
+
1016
+ let body = response. into_body ( ) . collect ( ) . await . unwrap ( ) . to_bytes ( ) ;
1017
+
1018
+ assert_eq ! ( body, BODY ) ;
1019
+ }
1020
+
1021
+ #[ cfg( not( miri) ) ]
1022
+ #[ tokio:: test]
1023
+ async fn http1_only_fail_if_client_is_http2 ( ) {
1024
+ let addr = start_server ( true , false ) . await ;
1025
+ let mut sender = connect_h2 ( addr) . await ;
1026
+
1027
+ let _ = sender
1028
+ . send_request ( Request :: new ( Empty :: < Bytes > :: new ( ) ) )
1029
+ . await
1030
+ . expect_err ( "should fail" ) ;
1031
+ }
1032
+
938
1033
#[ cfg( not( miri) ) ]
939
1034
#[ tokio:: test]
940
1035
async fn graceful_shutdown ( ) {
@@ -1000,7 +1095,7 @@ mod tests {
1000
1095
sender
1001
1096
}
1002
1097
1003
- async fn start_server ( ) -> SocketAddr {
1098
+ async fn start_server ( h1_only : bool , h2_only : bool ) -> SocketAddr {
1004
1099
let addr: SocketAddr = ( [ 127 , 0 , 0 , 1 ] , 0 ) . into ( ) ;
1005
1100
let listener = TcpListener :: bind ( addr) . await . unwrap ( ) ;
1006
1101
@@ -1011,11 +1106,20 @@ mod tests {
1011
1106
let ( stream, _) = listener. accept ( ) . await . unwrap ( ) ;
1012
1107
let stream = TokioIo :: new ( stream) ;
1013
1108
tokio:: task:: spawn ( async move {
1014
- let _ = auto:: Builder :: new ( TokioExecutor :: new ( ) )
1015
- . http2 ( )
1016
- . max_header_list_size ( 4096 )
1017
- . serve_connection_with_upgrades ( stream, service_fn ( hello) )
1018
- . await ;
1109
+ let mut builder = auto:: Builder :: new ( TokioExecutor :: new ( ) ) ;
1110
+ if h1_only {
1111
+ builder = builder. http1_only ( ) ;
1112
+ builder. serve_connection ( stream, service_fn ( hello) ) . await ;
1113
+ } else if h2_only {
1114
+ builder = builder. http2_only ( ) ;
1115
+ builder. serve_connection ( stream, service_fn ( hello) ) . await ;
1116
+ } else {
1117
+ builder
1118
+ . http2 ( )
1119
+ . max_header_list_size ( 4096 )
1120
+ . serve_connection_with_upgrades ( stream, service_fn ( hello) )
1121
+ . await ;
1122
+ }
1019
1123
} ) ;
1020
1124
}
1021
1125
} ) ;
0 commit comments