diff --git a/src/server/conn/auto/mod.rs b/src/server/conn/auto/mod.rs index 7b887ce2..9290aeb3 100644 --- a/src/server/conn/auto/mod.rs +++ b/src/server/conn/auto/mod.rs @@ -114,10 +114,6 @@ impl Builder { } /// Only accepts HTTP/2 - /// - /// Does not do anything if used with [`serve_connection_with_upgrades`] - /// - /// [`serve_connection_with_upgrades`]: Builder::serve_connection_with_upgrades #[cfg(feature = "http2")] pub fn http2_only(mut self) -> Self { assert!(self.version.is_none()); @@ -126,10 +122,6 @@ impl Builder { } /// Only accepts HTTP/1 - /// - /// Does not do anything if used with [`serve_connection_with_upgrades`] - /// - /// [`serve_connection_with_upgrades`]: Builder::serve_connection_with_upgrades #[cfg(feature = "http1")] pub fn http1_only(mut self) -> Self { assert!(self.version.is_none()); @@ -268,13 +260,28 @@ impl Builder { I: Read + Write + Unpin + Send + 'static, E: HttpServerConnExec, { - UpgradeableConnection { - state: UpgradeableConnState::ReadVersion { + let state = match self.version { + #[cfg(feature = "http1")] + Some(Version::H1) => { + let io = Rewind::new_buffered(io, Bytes::new()); + let conn = self.http1.serve_connection(io, service).with_upgrades(); + UpgradeableConnState::H1 { conn } + } + #[cfg(feature = "http2")] + Some(Version::H2) => { + let io = Rewind::new_buffered(io, Bytes::new()); + let conn = self.http2.serve_connection(io, service); + UpgradeableConnState::H2 { conn } + } + #[cfg(any(feature = "http1", feature = "http2"))] + _ => UpgradeableConnState::ReadVersion { read_version: read_version(io), builder: Cow::Borrowed(self), service: Some(service), }, - } + }; + + UpgradeableConnection { state } } } @@ -1188,7 +1195,7 @@ mod tests { #[cfg(not(miri))] #[tokio::test] async fn http1() { - let addr = start_server(false, false).await; + let addr = start_server(HttpVersion::Any).await; let mut sender = connect_h1(addr).await; let response = sender @@ -1204,7 +1211,7 @@ mod tests { #[cfg(not(miri))] #[tokio::test] async fn http2() { - let addr = start_server(false, false).await; + let addr = start_server(HttpVersion::Any).await; let mut sender = connect_h2(addr).await; let response = sender @@ -1220,7 +1227,7 @@ mod tests { #[cfg(not(miri))] #[tokio::test] async fn http2_only() { - let addr = start_server(false, true).await; + let addr = start_server(HttpVersion::H2Only).await; let mut sender = connect_h2(addr).await; let response = sender @@ -1236,7 +1243,7 @@ mod tests { #[cfg(not(miri))] #[tokio::test] async fn http2_only_fail_if_client_is_http1() { - let addr = start_server(false, true).await; + let addr = start_server(HttpVersion::H2Only).await; let mut sender = connect_h1(addr).await; let _ = sender @@ -1248,7 +1255,7 @@ mod tests { #[cfg(not(miri))] #[tokio::test] async fn http1_only() { - let addr = start_server(true, false).await; + let addr = start_server(HttpVersion::H1Only).await; let mut sender = connect_h1(addr).await; let response = sender @@ -1264,7 +1271,7 @@ mod tests { #[cfg(not(miri))] #[tokio::test] async fn http1_only_fail_if_client_is_http2() { - let addr = start_server(true, false).await; + let addr = start_server(HttpVersion::H1Only).await; let mut sender = connect_h2(addr).await; let _ = sender @@ -1338,7 +1345,14 @@ mod tests { sender } - async fn start_server(h1_only: bool, h2_only: bool) -> SocketAddr { + #[derive(Clone, Copy)] + enum HttpVersion { + H1Only, + H2Only, + Any, + } + + async fn start_server(version: HttpVersion) -> SocketAddr { let addr: SocketAddr = ([127, 0, 0, 1], 0).into(); let listener = TcpListener::bind(addr).await.unwrap(); @@ -1350,20 +1364,22 @@ mod tests { let stream = TokioIo::new(stream); tokio::task::spawn(async move { let mut builder = auto::Builder::new(TokioExecutor::new()); - if h1_only { - builder = builder.http1_only(); - builder.serve_connection(stream, service_fn(hello)).await - } else if h2_only { - builder = builder.http2_only(); - builder.serve_connection(stream, service_fn(hello)).await - } else { - builder - .http2() - .max_header_list_size(4096) - .serve_connection_with_upgrades(stream, service_fn(hello)) - .await - } - .unwrap(); + match version { + HttpVersion::H1Only => { + builder = builder.http1_only(); + } + HttpVersion::H2Only => { + builder = builder.http2_only(); + } + HttpVersion::Any => (), + }; + + builder + .http2() + .max_header_list_size(4096) + .serve_connection_with_upgrades(stream, service_fn(hello)) + .await + .unwrap(); }); } });