Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 48 additions & 32 deletions src/server/conn/auto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,6 @@ impl<E> Builder<E> {
}

/// Only accepts HTTP/2
///
/// Does not do anything if used with [`serve_connection_with_upgrades`]
///
/// [`serve_connection_with_upgrades`]: Builder::serve_connection_with_upgrades
#[cfg(feature = "http2")]
pub fn http2_only(mut self) -> Self {
assert!(self.version.is_none());
Expand All @@ -126,10 +122,6 @@ impl<E> Builder<E> {
}

/// Only accepts HTTP/1
///
/// Does not do anything if used with [`serve_connection_with_upgrades`]
///
/// [`serve_connection_with_upgrades`]: Builder::serve_connection_with_upgrades
#[cfg(feature = "http1")]
pub fn http1_only(mut self) -> Self {
assert!(self.version.is_none());
Expand Down Expand Up @@ -268,13 +260,28 @@ impl<E> Builder<E> {
I: Read + Write + Unpin + Send + 'static,
E: HttpServerConnExec<S::Future, B>,
{
UpgradeableConnection {
state: UpgradeableConnState::ReadVersion {
let state = match self.version {
#[cfg(feature = "http1")]
Some(Version::H1) => {
let io = Rewind::new_buffered(io, Bytes::new());
let conn = self.http1.serve_connection(io, service).with_upgrades();
UpgradeableConnState::H1 { conn }
}
#[cfg(feature = "http2")]
Some(Version::H2) => {
let io = Rewind::new_buffered(io, Bytes::new());
let conn = self.http2.serve_connection(io, service);
UpgradeableConnState::H2 { conn }
}
#[cfg(any(feature = "http1", feature = "http2"))]
_ => UpgradeableConnState::ReadVersion {
read_version: read_version(io),
builder: Cow::Borrowed(self),
service: Some(service),
},
}
};

UpgradeableConnection { state }
}
}

Expand Down Expand Up @@ -1188,7 +1195,7 @@ mod tests {
#[cfg(not(miri))]
#[tokio::test]
async fn http1() {
let addr = start_server(false, false).await;
let addr = start_server(HttpVersion::Any).await;
let mut sender = connect_h1(addr).await;

let response = sender
Expand All @@ -1204,7 +1211,7 @@ mod tests {
#[cfg(not(miri))]
#[tokio::test]
async fn http2() {
let addr = start_server(false, false).await;
let addr = start_server(HttpVersion::Any).await;
let mut sender = connect_h2(addr).await;

let response = sender
Expand All @@ -1220,7 +1227,7 @@ mod tests {
#[cfg(not(miri))]
#[tokio::test]
async fn http2_only() {
let addr = start_server(false, true).await;
let addr = start_server(HttpVersion::H2Only).await;
let mut sender = connect_h2(addr).await;

let response = sender
Expand All @@ -1236,7 +1243,7 @@ mod tests {
#[cfg(not(miri))]
#[tokio::test]
async fn http2_only_fail_if_client_is_http1() {
let addr = start_server(false, true).await;
let addr = start_server(HttpVersion::H2Only).await;
let mut sender = connect_h1(addr).await;

let _ = sender
Expand All @@ -1248,7 +1255,7 @@ mod tests {
#[cfg(not(miri))]
#[tokio::test]
async fn http1_only() {
let addr = start_server(true, false).await;
let addr = start_server(HttpVersion::H1Only).await;
let mut sender = connect_h1(addr).await;

let response = sender
Expand All @@ -1264,7 +1271,7 @@ mod tests {
#[cfg(not(miri))]
#[tokio::test]
async fn http1_only_fail_if_client_is_http2() {
let addr = start_server(true, false).await;
let addr = start_server(HttpVersion::H1Only).await;
let mut sender = connect_h2(addr).await;

let _ = sender
Expand Down Expand Up @@ -1338,7 +1345,14 @@ mod tests {
sender
}

async fn start_server(h1_only: bool, h2_only: bool) -> SocketAddr {
#[derive(Clone, Copy)]
enum HttpVersion {
H1Only,
H2Only,
Any,
}

async fn start_server(version: HttpVersion) -> SocketAddr {
let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
let listener = TcpListener::bind(addr).await.unwrap();

Expand All @@ -1350,20 +1364,22 @@ mod tests {
let stream = TokioIo::new(stream);
tokio::task::spawn(async move {
let mut builder = auto::Builder::new(TokioExecutor::new());
if h1_only {
builder = builder.http1_only();
builder.serve_connection(stream, service_fn(hello)).await
} else if h2_only {
builder = builder.http2_only();
builder.serve_connection(stream, service_fn(hello)).await
} else {
builder
.http2()
.max_header_list_size(4096)
.serve_connection_with_upgrades(stream, service_fn(hello))
.await
}
.unwrap();
match version {
HttpVersion::H1Only => {
builder = builder.http1_only();
}
HttpVersion::H2Only => {
builder = builder.http2_only();
}
HttpVersion::Any => (),
};

builder
.http2()
.max_header_list_size(4096)
.serve_connection_with_upgrades(stream, service_fn(hello))
.await
.unwrap();
});
}
});
Expand Down