From 806b429bb9b15f6c24781f13a1e230dce196051f Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 14 Sep 2025 15:00:40 +0200 Subject: [PATCH] Define HTTPProtocol class --- tests/protocols/test_http.py | 142 +++++++++++------------ uvicorn/_types.py | 9 +- uvicorn/logging.py | 8 +- uvicorn/protocols/http/base.py | 82 +++++++++++++ uvicorn/protocols/http/h11_impl.py | 54 ++------- uvicorn/protocols/http/httptools_impl.py | 47 ++------ uvicorn/protocols/utils.py | 6 +- 7 files changed, 176 insertions(+), 172 deletions(-) create mode 100644 uvicorn/protocols/http/base.py diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index 57804e799..6d62f9c80 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import logging import socket import threading @@ -14,6 +15,7 @@ from uvicorn.config import WS_PROTOCOLS, Config from uvicorn.lifespan.off import LifespanOff from uvicorn.lifespan.on import LifespanOn +from uvicorn.protocols.http.base import HTTPProtocol from uvicorn.protocols.http.h11_impl import H11Protocol from uvicorn.server import ServerState @@ -36,7 +38,6 @@ else: # pragma: no cover from typing_extensions import TypeAlias - HTTPProtocol: TypeAlias = "type[HttpToolsProtocol | H11Protocol]" WSProtocol: TypeAlias = "type[WebSocketProtocol | _WSProtocol]" pytestmark = pytest.mark.anyio @@ -173,7 +174,9 @@ class MockTransport: - def __init__(self, sockname=None, peername=None, sslcontext=False): + def __init__( + self, sockname: tuple[str, int] | None = None, peername: tuple[str, int] | None = None, sslcontext: bool = False + ): self.sockname = ("127.0.0.1", 8000) if sockname is None else sockname self.peername = ("127.0.0.1", 8001) if peername is None else peername self.sslcontext = sslcontext @@ -181,14 +184,10 @@ def __init__(self, sockname=None, peername=None, sslcontext=False): self.buffer = b"" self.read_paused = False - def get_extra_info(self, key): - return { - "sockname": self.sockname, - "peername": self.peername, - "sslcontext": self.sslcontext, - }.get(key) + def get_extra_info(self, key: Any): + return {"sockname": self.sockname, "peername": self.peername, "sslcontext": self.sslcontext}.get(key) - def write(self, data): + def write(self, data: bytes): assert not self.closed self.buffer += data @@ -208,7 +207,7 @@ def is_closing(self): def clear_buffer(self): self.buffer = b"" - def set_protocol(self, protocol): + def set_protocol(self, protocol: asyncio.Protocol): pass @@ -258,12 +257,17 @@ def add_done_callback(self, callback): pass +class MockProtocol(HTTPProtocol): + loop: MockLoop # type: ignore[assignment] + transport: MockTransport # type: ignore[assignment] + + def get_connected_protocol( app: ASGIApplication, - http_protocol_cls: HTTPProtocol, + http_protocol_cls: type[HTTPProtocol], lifespan: LifespanOff | LifespanOn | None = None, **kwargs: Any, -): +) -> MockProtocol: loop = MockLoop() transport = MockTransport() config = Config(app=app, **kwargs) @@ -273,13 +277,13 @@ def get_connected_protocol( config=config, server_state=server_state, app_state=lifespan.state, - _loop=loop, # type: ignore + _loop=loop, # type: ignore[arg-type] ) - protocol.connection_made(transport) # type: ignore - return protocol + protocol.connection_made(transport) # type: ignore[arg-type] + return protocol # type: ignore[return-value] -async def test_get_request(http_protocol_cls: HTTPProtocol): +async def test_get_request(http_protocol_cls: type[HTTPProtocol]): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) @@ -298,7 +302,7 @@ async def test_get_request(http_protocol_cls: HTTPProtocol): pytest.param("ยต", id="allow_non_ascii_char"), ], ) -async def test_header_value_allowed_characters(http_protocol_cls: HTTPProtocol, char: str): +async def test_header_value_allowed_characters(http_protocol_cls: type[HTTPProtocol], char: str): app = Response("Hello, world", media_type="text/plain", headers={"key": f"<{char}>"}) protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) @@ -309,7 +313,7 @@ async def test_header_value_allowed_characters(http_protocol_cls: HTTPProtocol, @pytest.mark.parametrize("path", ["/", "/?foo", "/?foo=bar", "/?foo=bar&baz=1"]) -async def test_request_logging(path: str, http_protocol_cls: HTTPProtocol, caplog: pytest.LogCaptureFixture): +async def test_request_logging(path: str, http_protocol_cls: type[HTTPProtocol], caplog: pytest.LogCaptureFixture): get_request_with_query_string = b"\r\n".join( [f"GET {path} HTTP/1.1".encode("ascii"), b"Host: example.org", b"", b""] ) @@ -324,7 +328,7 @@ async def test_request_logging(path: str, http_protocol_cls: HTTPProtocol, caplo assert f'"GET {path} HTTP/1.1" 200' in caplog.records[0].message -async def test_head_request(http_protocol_cls: HTTPProtocol): +async def test_head_request(http_protocol_cls: type[HTTPProtocol]): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) @@ -334,7 +338,7 @@ async def test_head_request(http_protocol_cls: HTTPProtocol): assert b"Hello, world" not in protocol.transport.buffer -async def test_post_request(http_protocol_cls: HTTPProtocol): +async def test_post_request(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): body = b"" more_body = True @@ -353,7 +357,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert b'Body: {"hello": "world"}' in protocol.transport.buffer -async def test_keepalive(http_protocol_cls: HTTPProtocol): +async def test_keepalive(http_protocol_cls: type[HTTPProtocol]): app = Response(b"", status_code=204) protocol = get_connected_protocol(app, http_protocol_cls) @@ -364,7 +368,7 @@ async def test_keepalive(http_protocol_cls: HTTPProtocol): assert not protocol.transport.is_closing() -async def test_keepalive_timeout(http_protocol_cls: HTTPProtocol): +async def test_keepalive_timeout(http_protocol_cls: type[HTTPProtocol]): app = Response(b"", status_code=204) protocol = get_connected_protocol(app, http_protocol_cls) @@ -378,9 +382,7 @@ async def test_keepalive_timeout(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -async def test_keepalive_timeout_with_pipelined_requests( - http_protocol_cls: HTTPProtocol, -): +async def test_keepalive_timeout_with_pipelined_requests(http_protocol_cls: type[HTTPProtocol]): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) @@ -403,7 +405,7 @@ async def test_keepalive_timeout_with_pipelined_requests( assert protocol.timeout_keep_alive_task is not None -async def test_close(http_protocol_cls: HTTPProtocol): +async def test_close(http_protocol_cls: type[HTTPProtocol]): app = Response(b"", status_code=204, headers={"connection": "close"}) protocol = get_connected_protocol(app, http_protocol_cls) @@ -413,7 +415,7 @@ async def test_close(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -async def test_chunked_encoding(http_protocol_cls: HTTPProtocol): +async def test_chunked_encoding(http_protocol_cls: type[HTTPProtocol]): app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}) protocol = get_connected_protocol(app, http_protocol_cls) @@ -424,7 +426,7 @@ async def test_chunked_encoding(http_protocol_cls: HTTPProtocol): assert not protocol.transport.is_closing() -async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol): +async def test_chunked_encoding_empty_body(http_protocol_cls: type[HTTPProtocol]): app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}) protocol = get_connected_protocol(app, http_protocol_cls) @@ -435,9 +437,7 @@ async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol): assert not protocol.transport.is_closing() -async def test_chunked_encoding_head_request( - http_protocol_cls: HTTPProtocol, -): +async def test_chunked_encoding_head_request(http_protocol_cls: type[HTTPProtocol]): app = Response(b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}) protocol = get_connected_protocol(app, http_protocol_cls) @@ -447,7 +447,7 @@ async def test_chunked_encoding_head_request( assert not protocol.transport.is_closing() -async def test_pipelined_requests(http_protocol_cls: HTTPProtocol): +async def test_pipelined_requests(http_protocol_cls: type[HTTPProtocol]): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) @@ -468,7 +468,7 @@ async def test_pipelined_requests(http_protocol_cls: HTTPProtocol): protocol.transport.clear_buffer() -async def test_undersized_request(http_protocol_cls: HTTPProtocol): +async def test_undersized_request(http_protocol_cls: type[HTTPProtocol]): app = Response(b"xxx", headers={"content-length": "10"}) protocol = get_connected_protocol(app, http_protocol_cls) @@ -477,7 +477,7 @@ async def test_undersized_request(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -async def test_oversized_request(http_protocol_cls: HTTPProtocol): +async def test_oversized_request(http_protocol_cls: type[HTTPProtocol]): app = Response(b"xxx" * 20, headers={"content-length": "10"}) protocol = get_connected_protocol(app, http_protocol_cls) @@ -486,7 +486,7 @@ async def test_oversized_request(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -async def test_large_post_request(http_protocol_cls: HTTPProtocol): +async def test_large_post_request(http_protocol_cls: type[HTTPProtocol]): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) @@ -496,7 +496,7 @@ async def test_large_post_request(http_protocol_cls: HTTPProtocol): assert not protocol.transport.read_paused -async def test_invalid_http(http_protocol_cls: HTTPProtocol): +async def test_invalid_http(http_protocol_cls: type[HTTPProtocol]): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) @@ -504,7 +504,7 @@ async def test_invalid_http(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -async def test_app_exception(http_protocol_cls: HTTPProtocol): +async def test_app_exception(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): raise Exception() @@ -515,7 +515,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -async def test_exception_during_response(http_protocol_cls: HTTPProtocol): +async def test_exception_during_response(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.body", "body": b"1", "more_body": True}) @@ -528,7 +528,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -async def test_no_response_returned(http_protocol_cls: HTTPProtocol): +async def test_no_response_returned(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): ... protocol = get_connected_protocol(app, http_protocol_cls) @@ -538,7 +538,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -async def test_partial_response_returned(http_protocol_cls: HTTPProtocol): +async def test_partial_response_returned(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) @@ -549,7 +549,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -async def test_response_header_splitting(http_protocol_cls: HTTPProtocol): +async def test_response_header_splitting(http_protocol_cls: type[HTTPProtocol]): app = Response(b"", headers={"key": "value\r\nCookie: smuggled=value"}) protocol = get_connected_protocol(app, http_protocol_cls) @@ -560,7 +560,7 @@ async def test_response_header_splitting(http_protocol_cls: HTTPProtocol): assert protocol.transport.is_closing() -async def test_duplicate_start_message(http_protocol_cls: HTTPProtocol): +async def test_duplicate_start_message(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.start", "status": 200}) @@ -572,7 +572,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -async def test_missing_start_message(http_protocol_cls: HTTPProtocol): +async def test_missing_start_message(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.body", "body": b""}) @@ -583,7 +583,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -async def test_message_after_body_complete(http_protocol_cls: HTTPProtocol): +async def test_message_after_body_complete(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.body", "body": b""}) @@ -596,7 +596,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -async def test_value_returned(http_protocol_cls: HTTPProtocol): +async def test_value_returned(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.body", "body": b""}) @@ -609,7 +609,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert protocol.transport.is_closing() -async def test_early_disconnect(http_protocol_cls: HTTPProtocol): +async def test_early_disconnect(http_protocol_cls: type[HTTPProtocol]): got_disconnect_event = False async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): @@ -630,7 +630,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert got_disconnect_event -async def test_early_response(http_protocol_cls: HTTPProtocol): +async def test_early_response(http_protocol_cls: type[HTTPProtocol]): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) @@ -641,7 +641,7 @@ async def test_early_response(http_protocol_cls: HTTPProtocol): assert not protocol.transport.is_closing() -async def test_read_after_response(http_protocol_cls: HTTPProtocol): +async def test_read_after_response(http_protocol_cls: type[HTTPProtocol]): message_after_response = None async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): @@ -658,7 +658,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert message_after_response == {"type": "http.disconnect"} -async def test_http10_request(http_protocol_cls: HTTPProtocol): +async def test_http10_request(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "http" content = "Version: %s" % scope["http_version"] @@ -672,7 +672,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert b"Version: 1.0" in protocol.transport.buffer -async def test_root_path(http_protocol_cls: HTTPProtocol): +async def test_root_path(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "http" root_path = scope.get("root_path", "") @@ -687,7 +687,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert b"root_path=/app path=/app/" in protocol.transport.buffer -async def test_raw_path(http_protocol_cls: HTTPProtocol): +async def test_raw_path(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "http" path = scope["path"] @@ -704,7 +704,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert b"Done" in protocol.transport.buffer -async def test_max_concurrency(http_protocol_cls: HTTPProtocol): +async def test_max_concurrency(http_protocol_cls: type[HTTPProtocol]): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls, limit_concurrency=1) @@ -725,27 +725,27 @@ async def test_max_concurrency(http_protocol_cls: HTTPProtocol): ) -async def test_shutdown_during_request(http_protocol_cls: HTTPProtocol): +async def test_shutdown_during_request(http_protocol_cls: type[HTTPProtocol]): app = Response(b"", status_code=204) protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) - protocol.shutdown() + protocol.shutdown() # type: ignore[attr-defined] await protocol.loop.run_one() assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer assert protocol.transport.is_closing() -async def test_shutdown_during_idle(http_protocol_cls: HTTPProtocol): +async def test_shutdown_during_idle(http_protocol_cls: type[HTTPProtocol]): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) - protocol.shutdown() + protocol.shutdown() # type: ignore[attr-defined] assert protocol.transport.buffer == b"" assert protocol.transport.is_closing() -async def test_100_continue_sent_when_body_consumed(http_protocol_cls: HTTPProtocol): +async def test_100_continue_sent_when_body_consumed(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): body = b"" more_body = True @@ -777,7 +777,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable async def test_100_continue_not_sent_when_body_not_consumed( - http_protocol_cls: HTTPProtocol, + http_protocol_cls: type[HTTPProtocol], ): app = Response(b"", status_code=204) @@ -799,7 +799,7 @@ async def test_100_continue_not_sent_when_body_not_consumed( assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer -async def test_supported_upgrade_request(http_protocol_cls: HTTPProtocol): +async def test_supported_upgrade_request(http_protocol_cls: type[HTTPProtocol]): pytest.importorskip("wsproto") app = Response("Hello, world", media_type="text/plain") @@ -809,7 +809,7 @@ async def test_supported_upgrade_request(http_protocol_cls: HTTPProtocol): assert b"HTTP/1.1 426 " in protocol.transport.buffer -async def test_unsupported_ws_upgrade_request(http_protocol_cls: HTTPProtocol): +async def test_unsupported_ws_upgrade_request(http_protocol_cls: type[HTTPProtocol]): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls, ws="none") @@ -820,7 +820,7 @@ async def test_unsupported_ws_upgrade_request(http_protocol_cls: HTTPProtocol): async def test_unsupported_ws_upgrade_request_warn_on_auto( - caplog: pytest.LogCaptureFixture, http_protocol_cls: HTTPProtocol + caplog: pytest.LogCaptureFixture, http_protocol_cls: type[HTTPProtocol] ): app = Response("Hello, world", media_type="text/plain") @@ -836,7 +836,7 @@ async def test_unsupported_ws_upgrade_request_warn_on_auto( assert msg in warnings -async def test_http2_upgrade_request(http_protocol_cls: HTTPProtocol, ws_protocol_cls: WSProtocol): +async def test_http2_upgrade_request(http_protocol_cls: type[HTTPProtocol], ws_protocol_cls: WSProtocol): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls, ws=ws_protocol_cls) @@ -867,7 +867,7 @@ async def asgi(receive: ASGIReceiveCallable, send: ASGISendCallable): async def test_scopes( asgi2or3_app: ASGIApplication, expected_scopes: dict[str, str], - http_protocol_cls: HTTPProtocol, + http_protocol_cls: type[HTTPProtocol], ): protocol = get_connected_protocol(asgi2or3_app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) @@ -884,7 +884,7 @@ async def test_scopes( ], ) async def test_invalid_http_request( - request_line: str, http_protocol_cls: HTTPProtocol, caplog: pytest.LogCaptureFixture + request_line: str, http_protocol_cls: type[HTTPProtocol], caplog: pytest.LogCaptureFixture ): app = Response("Hello, world", media_type="text/plain") request = INVALID_REQUEST_TEMPLATE % request_line @@ -1007,7 +1007,7 @@ async def test_huge_headers_h11_max_incomplete(): assert b"Hello, world" in protocol.transport.buffer -async def test_return_close_header(http_protocol_cls: HTTPProtocol): +async def test_return_close_header(http_protocol_cls: type[HTTPProtocol]): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) @@ -1021,7 +1021,7 @@ async def test_return_close_header(http_protocol_cls: HTTPProtocol): assert b"connection: close" in protocol.transport.buffer.lower() -async def test_close_connection_with_multiple_requests(http_protocol_cls: HTTPProtocol): +async def test_close_connection_with_multiple_requests(http_protocol_cls: type[HTTPProtocol]): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) @@ -1035,7 +1035,7 @@ async def test_close_connection_with_multiple_requests(http_protocol_cls: HTTPPr assert b"connection: close" in protocol.transport.buffer.lower() -async def test_close_connection_with_post_request(http_protocol_cls: HTTPProtocol): +async def test_close_connection_with_post_request(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): body = b"" more_body = True @@ -1054,7 +1054,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert b"Body: {'hello': 'world'}" in protocol.transport.buffer -async def test_iterator_headers(http_protocol_cls: HTTPProtocol): +async def test_iterator_headers(http_protocol_cls: type[HTTPProtocol]): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): headers = iter([(b"x-test-header", b"test value")]) await send({"type": "http.response.start", "status": 200, "headers": headers}) @@ -1066,7 +1066,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert b"x-test-header: test value" in protocol.transport.buffer -async def test_lifespan_state(http_protocol_cls: HTTPProtocol): +async def test_lifespan_state(http_protocol_cls: type[HTTPProtocol]): expected_states = [{"a": 123, "b": [1]}, {"a": 123, "b": [1, 2]}] async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): @@ -1095,7 +1095,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable async def test_header_upgrade_is_not_websocket_depend_installed( - caplog: pytest.LogCaptureFixture, http_protocol_cls: HTTPProtocol + caplog: pytest.LogCaptureFixture, http_protocol_cls: type[HTTPProtocol] ): caplog.set_level(logging.WARNING, logger="uvicorn.error") app = Response("Hello, world", media_type="text/plain") @@ -1111,7 +1111,7 @@ async def test_header_upgrade_is_not_websocket_depend_installed( async def test_header_upgrade_is_websocket_depend_not_installed( - caplog: pytest.LogCaptureFixture, http_protocol_cls: HTTPProtocol + caplog: pytest.LogCaptureFixture, http_protocol_cls: type[HTTPProtocol] ): caplog.set_level(logging.WARNING, logger="uvicorn.error") app = Response("Hello, world", media_type="text/plain") diff --git a/uvicorn/_types.py b/uvicorn/_types.py index c927cc11d..3efc21ae0 100644 --- a/uvicorn/_types.py +++ b/uvicorn/_types.py @@ -270,12 +270,5 @@ async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) - ASGI2Application = type[ASGI2Protocol] -ASGI3Application = Callable[ - [ - Scope, - ASGIReceiveCallable, - ASGISendCallable, - ], - Awaitable[None], -] +ASGI3Application = Callable[[Scope, ASGIReceiveCallable, ASGISendCallable], Awaitable[None]] ASGIApplication = Union[ASGI2Application, ASGI3Application] diff --git a/uvicorn/logging.py b/uvicorn/logging.py index 02f455d8e..ba4913a94 100644 --- a/uvicorn/logging.py +++ b/uvicorn/logging.py @@ -96,13 +96,7 @@ def default(code: int) -> str: def formatMessage(self, record: logging.LogRecord) -> str: recordcopy = copy(record) - ( - client_addr, - method, - full_path, - http_version, - status_code, - ) = recordcopy.args # type: ignore[misc] + (client_addr, method, full_path, http_version, status_code) = recordcopy.args # type: ignore[misc] status_code = self.get_status_code(int(status_code)) # type: ignore[arg-type] request_line = f"{method} {full_path} HTTP/{http_version}" if self.use_colors: diff --git a/uvicorn/protocols/http/base.py b/uvicorn/protocols/http/base.py new file mode 100644 index 000000000..e8a0def73 --- /dev/null +++ b/uvicorn/protocols/http/base.py @@ -0,0 +1,82 @@ +from __future__ import annotations as _annotations + +import asyncio +import logging +from typing import Any + +from uvicorn._types import HTTPScope +from uvicorn.config import Config +from uvicorn.protocols.http.flow_control import FlowControl +from uvicorn.server import ServerState + + +class HTTPProtocol(asyncio.Protocol): + __slots__ = ( + "config", + "app", + "loop", + "logger", + "access_logger", + "access_log", + "ws_protocol_class", + "root_path", + "limit_concurrency", + "app_state", + # Timeouts + "timeout_keep_alive_task", + "timeout_keep_alive", + # Global state + "server_state", + "connections", + "tasks", + # Per-connection state + "transport", + "flow", + "server", + "client", + # Per-request state + "scope", + "headers", + ) + + def __init__( + self, + config: Config, + server_state: ServerState, + app_state: dict[str, Any], + _loop: asyncio.AbstractEventLoop | None = None, + ) -> None: + if not config.loaded: + config.load() + + self.config = config + self.app = config.loaded_app + self.loop = _loop or asyncio.get_event_loop() + + self.logger = logging.getLogger("uvicorn.error") + self.access_logger = logging.getLogger("uvicorn.access") + self.access_log = self.access_logger.hasHandlers() + + self.ws_protocol_class = config.ws_protocol_class + self.root_path = config.root_path + self.limit_concurrency = config.limit_concurrency + self.app_state = app_state + + # Timeouts + self.timeout_keep_alive_task: asyncio.TimerHandle | None = None + self.timeout_keep_alive = config.timeout_keep_alive + + # Global state + self.server_state = server_state + self.connections = server_state.connections + self.tasks = server_state.tasks + + # Per-connection state + self.transport: asyncio.Transport = None # type: ignore[assignment] + self.flow: FlowControl = None # type: ignore[assignment] + self.server: tuple[str, int] | None = None + self.client: tuple[str, int] | None = None + + # Per-request state + self.scope: HTTPScope = None # type: ignore[assignment] + self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment] diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index b8cdde3ab..b05ef5f55 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -3,7 +3,7 @@ import asyncio import http import logging -from typing import Any, Callable, Literal, cast +from typing import Any, Callable, cast from urllib.parse import unquote import h11 @@ -20,6 +20,7 @@ ) from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL +from uvicorn.protocols.http.base import HTTPProtocol from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl from uvicorn.server import ServerState @@ -35,7 +36,7 @@ def _get_status_phrase(status_code: int) -> bytes: STATUS_PHRASES = {status_code: _get_status_phrase(status_code) for status_code in range(100, 600)} -class H11Protocol(asyncio.Protocol): +class H11Protocol(HTTPProtocol): def __init__( self, config: Config, @@ -43,55 +44,24 @@ def __init__( app_state: dict[str, Any], _loop: asyncio.AbstractEventLoop | None = None, ) -> None: - if not config.loaded: - config.load() - - self.config = config - self.app = config.loaded_app - self.loop = _loop or asyncio.get_event_loop() - self.logger = logging.getLogger("uvicorn.error") - self.access_logger = logging.getLogger("uvicorn.access") - self.access_log = self.access_logger.hasHandlers() + super().__init__(config, server_state, app_state, _loop) + self.conn = h11.Connection( h11.SERVER, config.h11_max_incomplete_event_size if config.h11_max_incomplete_event_size is not None else DEFAULT_MAX_INCOMPLETE_EVENT_SIZE, ) - self.ws_protocol_class = config.ws_protocol_class - self.root_path = config.root_path - self.limit_concurrency = config.limit_concurrency - self.app_state = app_state - - # Timeouts - self.timeout_keep_alive_task: asyncio.TimerHandle | None = None - self.timeout_keep_alive = config.timeout_keep_alive - - # Shared server state - self.server_state = server_state - self.connections = server_state.connections - self.tasks = server_state.tasks - - # Per-connection state - self.transport: asyncio.Transport = None # type: ignore[assignment] - self.flow: FlowControl = None # type: ignore[assignment] - self.server: tuple[str, int] | None = None - self.client: tuple[str, int] | None = None - self.scheme: Literal["http", "https"] | None = None # Per-request state - self.scope: HTTPScope = None # type: ignore[assignment] - self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment] self.cycle: RequestResponseCycle = None # type: ignore[assignment] # Protocol interface - def connection_made( # type: ignore[override] - self, transport: asyncio.Transport - ) -> None: + def connection_made(self, transport: asyncio.BaseTransport) -> None: self.connections.add(self) - self.transport = transport - self.flow = FlowControl(transport) + self.transport = cast(asyncio.Transport, transport) + self.flow = FlowControl(self.transport) self.server = get_local_addr(transport) self.client = get_remote_addr(transport) self.scheme = "https" if is_ssl(transport) else "http" @@ -204,7 +174,7 @@ def handle_events(self) -> None: "http_version": event.http_version.decode("ascii"), "server": self.server, "client": self.client, - "scheme": self.scheme, # type: ignore[typeddict-item] + "scheme": self.scheme, "method": event.method.decode("ascii"), "root_path": self.root_path, "path": full_path, @@ -534,10 +504,6 @@ async def receive(self) -> ASGIReceiveEvent: if self.disconnected or self.response_complete: return {"type": "http.disconnect"} - message: HTTPRequestEvent = { - "type": "http.request", - "body": self.body, - "more_body": self.more_body, - } + message: HTTPRequestEvent = {"type": "http.request", "body": self.body, "more_body": self.more_body} self.body = b"" return message diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index e8795ed35..5096153f6 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -5,9 +5,8 @@ import logging import re import urllib -from asyncio.events import TimerHandle from collections import deque -from typing import Any, Callable, Literal, cast +from typing import Any, Callable, cast import httptools @@ -21,6 +20,7 @@ ) from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL +from uvicorn.protocols.http.base import HTTPProtocol from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl from uvicorn.server import ServerState @@ -40,7 +40,7 @@ def _get_status_line(status_code: int) -> bytes: STATUS_LINE = {status_code: _get_status_line(status_code) for status_code in range(100, 600)} -class HttpToolsProtocol(asyncio.Protocol): +class HttpToolsProtocol(HTTPProtocol): def __init__( self, config: Config, @@ -48,15 +48,7 @@ def __init__( app_state: dict[str, Any], _loop: asyncio.AbstractEventLoop | None = None, ) -> None: - if not config.loaded: - config.load() - - self.config = config - self.app = config.loaded_app - self.loop = _loop or asyncio.get_event_loop() - self.logger = logging.getLogger("uvicorn.error") - self.access_logger = logging.getLogger("uvicorn.access") - self.access_log = self.access_logger.hasHandlers() + super().__init__(config, server_state, app_state, _loop) self.parser = httptools.HttpRequestParser(self) try: @@ -66,42 +58,19 @@ def __init__( # httptools < 0.6.3 pass - self.ws_protocol_class = config.ws_protocol_class - self.root_path = config.root_path - self.limit_concurrency = config.limit_concurrency - self.app_state = app_state - - # Timeouts - self.timeout_keep_alive_task: TimerHandle | None = None - self.timeout_keep_alive = config.timeout_keep_alive - - # Global state - self.server_state = server_state - self.connections = server_state.connections - self.tasks = server_state.tasks - # Per-connection state - self.transport: asyncio.Transport = None # type: ignore[assignment] - self.flow: FlowControl = None # type: ignore[assignment] - self.server: tuple[str, int] | None = None - self.client: tuple[str, int] | None = None - self.scheme: Literal["http", "https"] | None = None self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque() # Per-request state - self.scope: HTTPScope = None # type: ignore[assignment] - self.headers: list[tuple[bytes, bytes]] = None # type: ignore[assignment] self.expect_100_continue = False self.cycle: RequestResponseCycle = None # type: ignore[assignment] # Protocol interface - def connection_made( # type: ignore[override] - self, transport: asyncio.Transport - ) -> None: + def connection_made(self, transport: asyncio.BaseTransport) -> None: self.connections.add(self) - self.transport = transport - self.flow = FlowControl(transport) + self.transport = cast(asyncio.Transport, transport) + self.flow = FlowControl(self.transport) self.server = get_local_addr(transport) self.client = get_remote_addr(transport) self.scheme = "https" if is_ssl(transport) else "http" @@ -226,7 +195,7 @@ def on_message_begin(self) -> None: "http_version": "1.1", "server": self.server, "client": self.client, - "scheme": self.scheme, # type: ignore[typeddict-item] + "scheme": self.scheme, "root_path": self.root_path, "headers": self.headers, "state": self.app_state.copy(), diff --git a/uvicorn/protocols/utils.py b/uvicorn/protocols/utils.py index e1d6f01d5..efb0cb0d9 100644 --- a/uvicorn/protocols/utils.py +++ b/uvicorn/protocols/utils.py @@ -9,7 +9,7 @@ class ClientDisconnected(OSError): ... -def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None: +def get_remote_addr(transport: asyncio.BaseTransport) -> tuple[str, int] | None: socket_info = transport.get_extra_info("socket") if socket_info is not None: try: @@ -26,7 +26,7 @@ def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None: return None -def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None: +def get_local_addr(transport: asyncio.BaseTransport) -> tuple[str, int] | None: socket_info = transport.get_extra_info("socket") if socket_info is not None: info = socket_info.getsockname() @@ -38,7 +38,7 @@ def get_local_addr(transport: asyncio.Transport) -> tuple[str, int] | None: return None -def is_ssl(transport: asyncio.Transport) -> bool: +def is_ssl(transport: asyncio.BaseTransport) -> bool: return bool(transport.get_extra_info("sslcontext"))