diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a85f07a..c72cfb9 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -5,8 +5,10 @@ Changelog ------------------ - **BREAKING**: Drop Python 3.8 support -- Bugfix: Run `socket.getfqdn` in thread to avoid blocking event loop - if `local_hostname` not provided (thanks @Raidzin) +- Feature: Added `SMTP.wait_closed` method to wait for the connection to close +- Change: Refactored and updated protocol logic +- Bugfix: Run `socket.getfqdn` in thread to avoid blocking event loop if + `local_hostname` not provided (thanks @Raidzin) 3.0.2 ----- diff --git a/src/aiosmtplib/protocol.py b/src/aiosmtplib/protocol.py index 25c7069..2641e82 100644 --- a/src/aiosmtplib/protocol.py +++ b/src/aiosmtplib/protocol.py @@ -6,11 +6,11 @@ import collections import re import ssl -from typing import Optional, cast +import weakref +from typing import Any, Callable, Optional, cast from .errors import ( SMTPDataError, - SMTPReadTimeoutError, SMTPResponseException, SMTPServerDisconnected, SMTPTimeoutError, @@ -27,6 +27,60 @@ PERIOD_REGEX = re.compile(rb"(?m)^\.") +def format_data_message(message: bytes) -> bytes: + message = LINE_ENDINGS_REGEX.sub(b"\r\n", message) + message = PERIOD_REGEX.sub(b"..", message) + if not message.endswith(b"\r\n"): + message += b"\r\n" + message += b".\r\n" + + return message + + +def read_response_from_buffer(data: bytearray) -> Optional[SMTPResponse]: + """Parse the actual SMTP response (if any) from the data buffer""" + code = -1 + message = bytearray() + offset = 0 + message_complete = False + + while True: + line_end_index = data.find(b"\n", offset) + if line_end_index == -1: + break + + line = bytes(data[offset : line_end_index + 1]) + + if len(line) > MAX_LINE_LENGTH: + raise SMTPResponseException( + SMTPStatus.unrecognized_command, "Response too long" + ) + + try: + code = int(line[:3]) + except ValueError: + error_text = line.decode("utf-8", errors="ignore") + raise SMTPResponseException( + SMTPStatus.invalid_response.value, + f"Malformed SMTP response line: {error_text}", + ) from None + + offset += len(line) + if len(message): + message.extend(b"\n") + message.extend(line[4:].strip(b" \t\r\n")) + if line[3:4] != b"-": + message_complete = True + break + + if message_complete: + response = SMTPResponse(code, bytes(message).decode("utf-8", "surrogateescape")) + del data[:offset] + return response + + return None + + class FlowControlMixin(asyncio.Protocol): """ Reusable flow control logic for StreamWriter.drain(). @@ -86,157 +140,220 @@ async def _drain_helper(self) -> None: finally: self._drain_waiters.remove(waiter) - def _get_close_waiter(self, stream: asyncio.StreamWriter) -> "asyncio.Future[None]": + def _get_close_waiter( + self, stream: Optional[asyncio.StreamReader] + ) -> "asyncio.Future[None]": raise NotImplementedError -class SMTPProtocol(FlowControlMixin, asyncio.BaseProtocol): +class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol): + """Helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + + Copied from stdlib, with some simplifications. + """ + def __init__( self, + stream_reader: asyncio.StreamReader, loop: Optional[asyncio.AbstractEventLoop] = None, - ) -> None: + ): super().__init__(loop=loop) + + self._stream_reader_wr = weakref.ref(stream_reader) + self._transport: Optional[asyncio.Transport] = None self._over_ssl = False - self._buffer = bytearray() - self._response_waiter: Optional[asyncio.Future[SMTPResponse]] = None + self._closed = self._loop.create_future() - self.transport: Optional[asyncio.BaseTransport] = None - self._command_lock: Optional[asyncio.Lock] = None - self._closed_future: "asyncio.Future[None]" = self._loop.create_future() - self._quit_sent = False + @property + def _stream_reader(self): + if self._stream_reader_wr is None: + return None + return self._stream_reader_wr() - def _get_close_waiter(self, stream: asyncio.StreamWriter) -> "asyncio.Future[None]": - return self._closed_future + def _replace_transport(self, transport: asyncio.Transport) -> None: + self._transport = transport + self._over_ssl = transport.get_extra_info("sslcontext") is not None + self._stream_reader._transport = transport # type: ignore - def __del__(self) -> None: - # Avoid 'Future exception was never retrieved' warnings - # Some unknown race conditions can sometimes trigger these :( - self._retrieve_response_exception() + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self._transport = cast(asyncio.Transport, transport) + reader = self._stream_reader + if reader is not None: + reader.set_transport(transport) + self._over_ssl = transport.get_extra_info("sslcontext") is not None + + def connection_lost(self, exc: Optional[Exception]) -> None: + reader = self._stream_reader + if reader is not None: + if exc is None: + reader.feed_eof() + else: + reader.set_exception(exc) + if not self._closed.done(): + if exc is None: + self._closed.set_result(None) + else: + self._closed.set_exception(exc) + + super().connection_lost(exc) + + self._stream_reader_wr = None + self._transport = None + + def data_received(self, data: bytes) -> None: + reader = self._stream_reader + if reader is not None: + reader.feed_data(data) + + def eof_received(self): + reader = self._stream_reader + if reader is not None: + reader.feed_eof() + if self._over_ssl: + # Prevent a warning in SSLProtocol.eof_received: + # "returning true from eof_received() + # has no effect when using ssl" + return False + return True + + def _get_close_waiter( + self, stream: Optional[asyncio.StreamReader] + ) -> "asyncio.Future[None]": + return self._closed + + def __del__(self): + # Prevent reports about unhandled exceptions. + # Better than self._closed._log_traceback = False hack + try: + closed = self._closed + except AttributeError: + pass # failed constructor + else: + if closed.done() and not closed.cancelled(): + closed.exception() + + +class SMTPStreamWriter(asyncio.StreamWriter): + """A StreamWriter subclass for SMTP connections. + + Adds our own `start_tls` method, which is used to upgrade the connection on older + Python versions. + """ + + _loop: asyncio.AbstractEventLoop + _protocol: "SMTPProtocol" + + async def start_tls( + self, + sslcontext: ssl.SSLContext, + *, + server_hostname: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Upgrade an existing stream-based connection to TLS.""" + protocol = self._protocol + await self.drain() + new_transport = await self._loop.start_tls( + self._transport, # type: ignore + protocol, + sslcontext, + server_side=False, + server_hostname=server_hostname, + **kwargs, + ) + self._transport = new_transport + protocol._replace_transport(new_transport) # type: ignore - @property - def is_connected(self) -> bool: - """ - Check if our transport is still connected. - """ - return bool(self.transport is not None and not self.transport.is_closing()) + +class SMTPProtocol(StreamReaderProtocol): + def __init__( + self, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + reader = asyncio.StreamReader(limit=MAX_LINE_LENGTH, loop=loop) + + super().__init__(reader, loop=loop) + + self._reader = reader + self._writer = None + + self._command_lock: Optional[asyncio.Lock] = None + self._quit_sent: Optional[bool] = None def connection_made(self, transport: asyncio.BaseTransport) -> None: - self.transport = cast(asyncio.Transport, transport) - self._over_ssl = transport.get_extra_info("sslcontext") is not None - self._response_waiter = self._loop.create_future() + super().connection_made(transport) + self._writer = SMTPStreamWriter( + cast(asyncio.Transport, transport), + self, + self._stream_reader, + loop=self._loop, + ) + self._command_lock = asyncio.Lock() self._quit_sent = False def connection_lost(self, exc: Optional[Exception]) -> None: - super().connection_lost(exc) - - if not self._quit_sent: + smtp_exc = None + if exc: smtp_exc = SMTPServerDisconnected("Connection lost") if exc: smtp_exc.__cause__ = exc - if self._response_waiter and not self._response_waiter.done(): - self._response_waiter.set_exception(smtp_exc) + super().connection_lost(smtp_exc) - self.transport = None + command_lock = self._command_lock self._command_lock = None + if command_lock is not None and command_lock.locked(): + command_lock.release() - def data_received(self, data: bytes) -> None: - if self._response_waiter is None: - raise RuntimeError( - f"data_received called without a response waiter set: {data!r}" - ) - elif self._response_waiter.done(): - # We got a response without issuing a command; ignore it. - return + if self._writer: + self._writer.close() + self._writer = None + self._transport = None - self._buffer.extend(data) + def eof_received(self): + super().eof_received() - # If we got an obvious partial message, don't try to parse the buffer - last_linebreak = data.rfind(b"\n") - if ( - last_linebreak == -1 - or data[last_linebreak + 3 : last_linebreak + 4] == b"-" - ): - return + # Close the connection + return False - try: - response = self._read_response_from_buffer() - except Exception as exc: - self._response_waiter.set_exception(exc) - else: - if response is not None: - self._response_waiter.set_result(response) + def get_transport_info(self, key: str) -> Any: + if self._transport is None: + return None + return self._transport.get_extra_info(key) - def eof_received(self) -> bool: - exc = SMTPServerDisconnected("Unexpected EOF received") - if self._response_waiter and not self._response_waiter.done(): - self._response_waiter.set_exception(exc) + def _replace_transport(self, transport: asyncio.Transport) -> None: + super()._replace_transport(transport) + self._writer._transport = transport # type: ignore - # Returning false closes the transport - return False + def close(self, callback: Optional[Callable[[asyncio.Future[None]], Any]]) -> None: + if self._writer is not None: + self._writer.close() - def _retrieve_response_exception(self) -> Optional[BaseException]: - """ - Return any exception that has been set on the response waiter. + if callback: + self._get_close_waiter(None).add_done_callback(callback) - Used to avoid 'Future exception was never retrieved' warnings - """ - if ( - self._response_waiter - and self._response_waiter.done() - and not self._response_waiter.cancelled() - ): - return self._response_waiter.exception() - - return None - - def _read_response_from_buffer(self) -> Optional[SMTPResponse]: - """Parse the actual response (if any) from the data buffer""" - code = -1 - message = bytearray() - offset = 0 - message_complete = False - - while True: - line_end_index = self._buffer.find(b"\n", offset) - if line_end_index == -1: - break - - line = bytes(self._buffer[offset : line_end_index + 1]) - - if len(line) > MAX_LINE_LENGTH: - raise SMTPResponseException( - SMTPStatus.unrecognized_command, "Response too long" - ) + async def wait_closed(self) -> None: + if self._writer is not None: + await self._writer.wait_closed() + await self._get_close_waiter(None) - try: - code = int(line[:3]) - except ValueError: - error_text = line.decode("utf-8", errors="ignore") - raise SMTPResponseException( - SMTPStatus.invalid_response.value, - f"Malformed SMTP response line: {error_text}", - ) from None - - offset += len(line) - if len(message): - message.extend(b"\n") - message.extend(line[4:].strip(b" \t\r\n")) - if line[3:4] != b"-": - message_complete = True - break - - if message_complete: - response = SMTPResponse( - code, bytes(message).decode("utf-8", "surrogateescape") - ) - del self._buffer[:offset] - return response - else: - return None + def is_closing(self) -> bool: + return self._transport is None or self._transport.is_closing() + + @property + def is_connected(self) -> bool: + """ + Check if our transport is still connected. + """ + return not self.is_closing() - async def read_response(self, timeout: Optional[float] = None) -> SMTPResponse: + async def read_response(self) -> SMTPResponse: """ Get a status response from the server. @@ -248,58 +365,61 @@ async def read_response(self, timeout: Optional[float] = None) -> SMTPResponse: - server response string (multiline responses are converted to a single, multiline string). """ - if self._response_waiter is None: - raise SMTPServerDisconnected("Connection lost") - try: - result = await asyncio.wait_for(self._response_waiter, timeout) - except (TimeoutError, asyncio.TimeoutError) as exc: - raise SMTPReadTimeoutError("Timed out waiting for server response") from exc - finally: - # If we were disconnected, don't create a new waiter - if self.transport is None: - self._response_waiter = None + buffer = bytearray() + + response = None + while response is None: + try: + data = await self._reader.readuntil(b"\n") + except asyncio.IncompleteReadError as exc: + buffer.extend(exc.partial) else: - self._response_waiter = self._loop.create_future() + buffer.extend(data) - return result + response = read_response_from_buffer(buffer) - def write(self, data: bytes) -> None: - if self.transport is None or self.transport.is_closing(): - raise SMTPServerDisconnected("Connection lost") + if response: + return response - try: - cast(asyncio.WriteTransport, self.transport).write(data) - # uvloop raises NotImplementedError, asyncio doesn't have a write method - except (AttributeError, NotImplementedError): - raise RuntimeError( - f"Transport {self.transport!r} does not support writing." - ) from None + if self.is_closing() or self._reader.at_eof(): + raise SMTPServerDisconnected("Server disconnected") - async def execute_command( - self, *args: bytes, timeout: Optional[float] = None - ) -> SMTPResponse: + raise RuntimeError("No response from server") + + async def write(self, data: bytes) -> None: + if self._writer is None: + raise RuntimeError("Writer not initialized") + + self._writer.write(data) + await self._writer.drain() + + async def execute_command(self, *args: bytes, quit: bool = False) -> SMTPResponse: """ Sends an SMTP command along with any args to the server, and returns a response. """ + if self._writer is None or self._writer.is_closing(): + raise SMTPServerDisconnected("Connection lost") if self._command_lock is None: - raise SMTPServerDisconnected("Server not connected") + raise RuntimeError("Command lock not initialized") + command = b" ".join(args) + b"\r\n" async with self._command_lock: - self.write(command) + try: + await self.write(command) + except ConnectionResetError as exc: + raise SMTPServerDisconnected("Connection lost") from exc - if command == b"QUIT\r\n": + if quit: self._quit_sent = True - response = await self.read_response(timeout=timeout) + response = await self.read_response() return response - async def execute_data_command( - self, message: bytes, timeout: Optional[float] = None - ) -> SMTPResponse: + async def execute_data_command(self, message: bytes) -> SMTPResponse: """ Sends an SMTP DATA command to the server, followed by encoded message content. @@ -307,23 +427,23 @@ async def execute_data_command( Lone \\\\r and \\\\n characters are converted to \\\\r\\\\n characters. """ + if self._writer is None or self._writer.is_closing(): + raise SMTPServerDisconnected("Connection lost") if self._command_lock is None: - raise SMTPServerDisconnected("Server not connected") + raise RuntimeError("Command lock not initialized") - message = LINE_ENDINGS_REGEX.sub(b"\r\n", message) - message = PERIOD_REGEX.sub(b"..", message) - if not message.endswith(b"\r\n"): - message += b"\r\n" - message += b".\r\n" + formatted_message = format_data_message(message) async with self._command_lock: - self.write(b"DATA\r\n") - start_response = await self.read_response(timeout=timeout) + await self.write(b"DATA\r\n") + + start_response = await self.read_response() if start_response.code != SMTPStatus.start_input: raise SMTPDataError(start_response.code, start_response.message) - self.write(message) - response = await self.read_response(timeout=timeout) + await self.write(formatted_message) + + response = await self.read_response() if response.code != SMTPStatus.completed: raise SMTPDataError(response.code, response.message) @@ -338,27 +458,26 @@ async def start_tls( """ Puts the connection to the SMTP server into TLS mode. """ + if self._writer is None or self._writer.is_closing(): + raise SMTPServerDisconnected("Connection lost") if self._over_ssl: - raise RuntimeError("Already using TLS.") + raise RuntimeError("Already using TLS") if self._command_lock is None: - raise SMTPServerDisconnected("Server not connected") + raise RuntimeError("Command lock not initialized") async with self._command_lock: - self.write(b"STARTTLS\r\n") - response = await self.read_response(timeout=timeout) + await self.write(b"STARTTLS\r\n") + response = await self.read_response() if response.code != SMTPStatus.ready: raise SMTPResponseException(response.code, response.message) # Check for disconnect after response - if self.transport is None or self.transport.is_closing(): + if self._writer.is_closing(): raise SMTPServerDisconnected("Connection lost") try: - tls_transport = await self._loop.start_tls( - cast(asyncio.WriteTransport, self.transport), - self, + await self._writer.start_tls( tls_context, - server_side=False, server_hostname=server_hostname, ssl_handshake_timeout=timeout, ) @@ -374,6 +493,4 @@ async def start_tls( "Connection reset while upgrading transport" ) from exc - self.transport = tls_transport - return response diff --git a/src/aiosmtplib/smtp.py b/src/aiosmtplib/smtp.py index c9cf076..e1eb61e 100644 --- a/src/aiosmtplib/smtp.py +++ b/src/aiosmtplib/smtp.py @@ -141,7 +141,6 @@ def __init__( :raises ValueError: mutually exclusive options provided """ self.protocol: Optional[SMTPProtocol] = None - self.transport: Optional[asyncio.BaseTransport] = None # Kwarg defaults are provided here, and saved for connect. self.hostname = hostname @@ -179,16 +178,20 @@ async def __aenter__(self) -> "SMTP": return self async def __aexit__( - self, exc_type: type[BaseException], exc: BaseException, traceback: Any + self, + exc_type: type[BaseException], + exc: Optional[BaseException], + traceback: Any, ) -> None: if isinstance(exc, (ConnectionError, TimeoutError)): self.close() return - try: - await self.quit() - except (SMTPServerDisconnected, SMTPResponseException, SMTPTimeoutError): - pass + if not (self.protocol is None or self.protocol.is_closing()): + try: + await self.quit() + except (SMTPServerDisconnected, SMTPResponseException, SMTPTimeoutError): + pass @property def is_connected(self) -> bool: @@ -443,7 +446,8 @@ async def connect( await self._maybe_start_tls_on_connect() await self._maybe_login_on_connect() except Exception as exc: - self.close() # Reset our state to disconnected + # Reset our state to disconnected + self.close() raise exc return response @@ -490,7 +494,7 @@ async def _create_connection(self, timeout: Optional[float]) -> SMTPResponse: ) try: - transport, _ = await asyncio.wait_for(connect_coro, timeout=timeout) + await asyncio.wait_for(connect_coro, timeout=timeout) except (TimeoutError, asyncio.TimeoutError) as exc: raise SMTPConnectTimeoutError( f"Timed out connecting to {self.hostname} on port {self.port}" @@ -501,18 +505,17 @@ async def _create_connection(self, timeout: Optional[float]) -> SMTPResponse: ) from exc self.protocol = protocol - self.transport = transport try: - response = await protocol.read_response(timeout=timeout) + response = await asyncio.wait_for(protocol.read_response(), timeout) + except (TimeoutError, asyncio.TimeoutError) as exc: + raise SMTPConnectTimeoutError( + "Timed out waiting for server ready message" + ) from exc except SMTPServerDisconnected as exc: raise SMTPConnectError( f"Error connecting to {self.hostname} on port {self.port}: {exc}" ) from exc - except SMTPTimeoutError as exc: - raise SMTPConnectTimeoutError( - "Timed out waiting for server ready message" - ) from exc if response.code != SMTPStatus.ready: raise SMTPConnectResponseError(response.code, response.message) @@ -558,9 +561,14 @@ async def execute_command( if self.protocol is None: raise SMTPServerDisconnected("Server not connected") - response = await self.protocol.execute_command( - *args, timeout=self.timeout if timeout is Default.token else timeout - ) + timeout = self.timeout if timeout is Default.token else timeout + + try: + response = await asyncio.wait_for( + self.protocol.execute_command(*args), timeout + ) + except (TimeoutError, asyncio.TimeoutError) as exc: + raise SMTPTimeoutError("Timed out waiting for response") from exc # If the server is unavailable, be nice and close the connection if response.code == SMTPStatus.domain_unavailable: @@ -595,18 +603,25 @@ def close(self) -> None: """ Closes the connection. """ - if self.transport is not None and not self.transport.is_closing(): - self.transport.close() + if self.protocol is not None and not self.protocol.is_closing(): + self.protocol.close(self._unset_protocol) if self._connect_lock is not None and self._connect_lock.locked(): self._connect_lock.release() - self.protocol = None - self.transport = None - # Reset ESMTP state self._reset_server_state() + def _unset_protocol(self, future: asyncio.Future[None]) -> None: + self.protocol = None + + async def wait_closed(self) -> None: + """ + Wait for the connection to finish closing. To be awaited after calling `close`. + """ + if self.protocol is not None: + await self.protocol.wait_closed() + def get_transport_info(self, key: str) -> Any: """ Get extra info from the transport. @@ -623,10 +638,10 @@ def get_transport_info(self, key: str) -> Any: :raises SMTPServerDisconnected: connection lost """ - if not (self.is_connected and self.transport): + if not (self.is_connected and self.protocol): raise SMTPServerDisconnected("Server not connected") - return self.transport.get_extra_info(key) + return self.protocol.get_transport_info(key) # Base SMTP commands # @@ -799,11 +814,24 @@ async def quit( :raises SMTPResponseException: on unexpected server response code """ - response = await self.execute_command(b"QUIT", timeout=timeout) + + if self.protocol is None: + raise SMTPServerDisconnected("Server not connected") + + timeout = self.timeout if timeout is Default.token else timeout + + try: + response = await asyncio.wait_for( + self.protocol.execute_command(b"QUIT", quit=True), timeout + ) + except (TimeoutError, asyncio.TimeoutError) as exc: + raise SMTPTimeoutError("Timed out waiting for response") from exc + if response.code != SMTPStatus.closing: raise SMTPResponseException(response.code, response.message) self.close() + await self.wait_closed() return response @@ -899,7 +927,14 @@ async def data( if isinstance(message, str): message = message.encode("ascii") - return await self.protocol.execute_data_command(message, timeout=timeout) + try: + response = await asyncio.wait_for( + self.protocol.execute_data_command(message), timeout + ) + except (TimeoutError, asyncio.TimeoutError) as exc: + raise SMTPTimeoutError("Timed out waiting for response") from exc + + return response # ESMTP commands # @@ -1024,12 +1059,15 @@ async def starttls( if not self.supports_extension("starttls"): raise SMTPException("SMTP STARTTLS extension not supported by server.") - response = await self.protocol.start_tls( - tls_context, server_hostname=server_hostname, timeout=timeout - ) - - # Update our transport reference - self.transport = self.protocol.transport + try: + response = await asyncio.wait_for( + self.protocol.start_tls( + tls_context, server_hostname=server_hostname, timeout=timeout + ), + timeout, + ) + except (TimeoutError, asyncio.TimeoutError) as exc: + raise SMTPTimeoutError("Timed out waiting for response") from exc # RFC 3207 part 4.2: # The client MUST discard any knowledge obtained from the server, such diff --git a/tests/smtpd.py b/tests/smtpd.py index bde1908..f10fadb 100644 --- a/tests/smtpd.py +++ b/tests/smtpd.py @@ -104,15 +104,33 @@ async def smtp_STARTTLS(self, arg: str) -> None: self.event_handler.record_command("STARTTLS", arg) -async def mock_response_delayed_ok(smtpd: SMTPD, *args: Any, **kwargs: Any) -> None: +async def mock_response_delayed_ok_with_cleanup( + smtpd: SMTPD, *args: Any, **kwargs: Any +) -> None: await asyncio.sleep(1.0) await smtpd.push("250 all done") + if smtpd._handler_coroutine: + handler = smtpd._handler_coroutine + handler.cancel() + await handler + if smtpd.transport: + smtpd.transport.close() + -async def mock_response_delayed_read(smtpd: SMTPD, *args: Any, **kwargs: Any) -> None: +async def mock_response_delayed_read_with_cleanup( + smtpd: SMTPD, *args: Any, **kwargs: Any +) -> None: await smtpd.push("220-hi") await asyncio.sleep(1.0) + if smtpd._handler_coroutine: + handler = smtpd._handler_coroutine + handler.cancel() + await handler + if smtpd.transport: + smtpd.transport.close() + async def mock_response_done(smtpd: SMTPD, *args: Any, **kwargs: Any) -> None: if args and args[0]: @@ -126,7 +144,6 @@ async def mock_response_done_then_close( if args and args[0]: smtpd.session.host_name = args[0] await smtpd.push("250 done") - await smtpd.push("221 bye now") smtpd.transport.close() @@ -239,6 +256,8 @@ async def mock_response_syntax_error_and_cleanup( await smtpd.push(f"{SMTPStatus.syntax_error} error") if smtpd._handler_coroutine: - smtpd._handler_coroutine.cancel() + handler = smtpd._handler_coroutine + handler.cancel() + await handler if smtpd.transport: smtpd.transport.close() diff --git a/tests/test_api.py b/tests/test_api.py index 953250e..cebaca8 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -26,6 +26,7 @@ async def test_send( hostname=hostname, port=smtpd_server_port, tls_context=client_tls_context, + timeout=1.0, ) assert not errors @@ -47,6 +48,7 @@ async def test_send_with_str( sender=sender_str, recipients=[recipient_str], start_tls=False, + timeout=1.0, ) assert not errors @@ -68,6 +70,7 @@ async def test_send_with_bytes( sender=sender_str, recipients=[recipient_str], start_tls=False, + timeout=1.0, ) assert not errors @@ -107,6 +110,7 @@ async def test_send_without_recipients( sender=sender_str, recipients=[], start_tls=False, + timeout=1.0, ) @@ -124,6 +128,7 @@ async def test_send_with_start_tls( port=smtpd_server_port, start_tls=True, tls_context=client_tls_context, + timeout=1.0, ) assert not errors @@ -149,6 +154,7 @@ async def test_send_with_login( tls_context=client_tls_context, username=auth_username, password=auth_password, + timeout=1.0, ) assert not errors @@ -171,13 +177,12 @@ async def test_send_via_socket( port=None, sock=sock, start_tls=False, + timeout=1.0, ) assert not errors assert len(received_messages) == 1 - assert sock.fileno() > 0, "Socket unexpectedly closed" - async def test_send_via_socket_path( smtpd_server_socket_path: asyncio.AbstractServer, @@ -191,6 +196,7 @@ async def test_send_via_socket_path( port=None, socket_path=socket_path, start_tls=False, + timeout=1.0, ) assert not errors @@ -213,6 +219,7 @@ async def test_send_with_mail_options( recipients=[recipient_str], mail_options=["BODY=8BITMIME"], start_tls=False, + timeout=1.0, ) assert not errors @@ -236,6 +243,7 @@ async def test_send_with_rcpt_options( # RCPT params are not supported by the server; just check that the kwarg works rcpt_options=[], start_tls=False, + timeout=1.0, ) assert not errors diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 4073b3b..fad26e6 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -177,7 +177,7 @@ async def test_close_works_on_stopped_loop( await client.connect() assert client.is_connected - assert client.transport is not None + assert client.protocol is not None event_loop.stop() diff --git a/tests/test_connect.py b/tests/test_connect.py index 8a946e4..73148c4 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -47,9 +47,7 @@ async def test_plain_smtp_connect( assert not smtp_client.is_connected -async def test_quit_then_connect_ok( - smtp_client: SMTP, smtpd_server: asyncio.AbstractServer -) -> None: +async def test_quit_then_connect_ok(smtp_client: SMTP) -> None: async with smtp_client: response = await smtp_client.quit() assert response.code == SMTPStatus.closing @@ -70,8 +68,7 @@ async def test_bad_connect_response_raises_error(smtp_client: SMTP) -> None: with pytest.raises(SMTPConnectError): await smtp_client.connect() - assert smtp_client.transport is None - assert smtp_client.protocol is None + assert not smtp_client.is_connected @pytest.mark.smtpd_mocks(_handle_client=mock_response_eof) @@ -79,8 +76,7 @@ async def test_eof_on_connect_raises_connect_error(smtp_client: SMTP) -> None: with pytest.raises(SMTPConnectError): await smtp_client.connect() - assert smtp_client.transport is None - assert smtp_client.protocol is None + assert not smtp_client.is_connected @pytest.mark.smtpd_mocks(_handle_client=mock_response_disconnect) @@ -116,7 +112,7 @@ async def test_connect_error_with_no_server( async def test_disconnected_server_raises_on_client_read(smtp_client: SMTP) -> None: await smtp_client.connect() - with pytest.raises(SMTPServerDisconnected): + with pytest.raises(SMTPServerDisconnected, match="Server disconnected"): await smtp_client.execute_command(b"NOOP") assert not smtp_client.is_connected @@ -217,6 +213,7 @@ async def test_context_manager_exception_quits( async with smtp_client: 1 / 0 # noqa + assert len(received_commands) >= 1 assert received_commands[-1][0] == "QUIT" @@ -308,7 +305,7 @@ async def test_connect_with_no_starttls_support(smtp_client: SMTP) -> None: await smtp_client.connect() assert smtp_client.is_connected - assert not smtp_client.protocol._over_ssl + assert smtp_client.get_transport_info("sslcontext") is None await smtp_client.quit() diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 08e9754..896e0af 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -10,7 +10,7 @@ import pytest -from aiosmtplib import SMTPResponseException, SMTPServerDisconnected, SMTPTimeoutError +from aiosmtplib import SMTPResponseException, SMTPServerDisconnected from aiosmtplib.protocol import FlowControlMixin, SMTPProtocol from .compat import cleanup_server @@ -23,7 +23,7 @@ async def test_protocol_connect(hostname: str, echo_server_port: int) -> None: ) transport, protocol = await asyncio.wait_for(connect_future, timeout=1.0) - assert getattr(protocol, "transport", None) is transport + assert getattr(protocol, "_transport", None) is transport assert not transport.is_closing() transport.close() @@ -61,7 +61,7 @@ async def client_connected( monkeypatch.setattr("aiosmtplib.protocol.MAX_LINE_LENGTH", 128) with pytest.raises(SMTPResponseException) as exc_info: - await protocol.execute_command(b"TEST\n", timeout=1.0) # type: ignore + await protocol.execute_command(b"TEST\n") assert exc_info.value.code == 500 assert "Response too long" in exc_info.value.message @@ -70,16 +70,6 @@ async def client_connected( await cleanup_server(server) -async def test_protocol_connected_check_on_read_response( - monkeypatch: pytest.MonkeyPatch, -) -> None: - protocol = SMTPProtocol() - monkeypatch.setattr(protocol, "transport", None) - - with pytest.raises(SMTPServerDisconnected): - await protocol.read_response(timeout=1.0) - - async def test_protocol_read_only_transport_error() -> None: event_loop = asyncio.get_running_loop() read_descriptor, _ = os.pipe() @@ -87,10 +77,10 @@ async def test_protocol_read_only_transport_error() -> None: connect_future = event_loop.connect_read_pipe(SMTPProtocol, read_pipe) transport, protocol = await asyncio.wait_for(connect_future, timeout=1.0) - assert getattr(protocol, "transport", None) is transport + assert getattr(protocol, "_transport", None) is transport - with pytest.raises(RuntimeError, match="does not support writing"): - protocol.write(b"TEST\n") + with pytest.raises(AttributeError, match="object has no attribute 'write'"): + await protocol.write(b"TEST\n") transport.close() @@ -104,16 +94,6 @@ async def test_protocol_connected_check_on_start_tls( await smtp_protocol.start_tls(client_tls_context, timeout=1.0) -async def test_protocol_already_over_tls_check_on_start_tls( - client_tls_context: ssl.SSLContext, -) -> None: - smtp_protocol = SMTPProtocol() - smtp_protocol._over_ssl = True - - with pytest.raises(RuntimeError, match="Already using TLS"): - await smtp_protocol.start_tls(client_tls_context) - - async def test_protocol_connection_reset_on_starttls( hostname: str, smtpd_server_port: int, @@ -138,30 +118,6 @@ def mock_start_tls(*args, **kwargs) -> None: transport.close() -async def test_protocol_timeout_on_starttls( - hostname: str, - smtpd_server_port: int, - client_tls_context: ssl.SSLContext, - monkeypatch: pytest.MonkeyPatch, -) -> None: - event_loop = asyncio.get_running_loop() - - connect_future = event_loop.create_connection( - SMTPProtocol, host=hostname, port=smtpd_server_port - ) - transport, protocol = await asyncio.wait_for(connect_future, timeout=1.0) - - def mock_start_tls(*args, **kwargs) -> None: - raise TimeoutError("Timed out") - - monkeypatch.setattr(event_loop, "start_tls", mock_start_tls) - - with pytest.raises(SMTPTimeoutError, match="Timed out while upgrading transport"): - await protocol.start_tls(client_tls_context) - - transport.close() - - async def test_error_on_readline_with_partial_line( bind_address: str, hostname: str ) -> None: @@ -187,7 +143,7 @@ async def client_connected( _, protocol = await asyncio.wait_for(connect_future, timeout=1.0) with pytest.raises(SMTPServerDisconnected): - await protocol.read_response(timeout=1.0) # type: ignore + await asyncio.wait_for(protocol.read_response(), 1.0) server.close() await cleanup_server(server) @@ -220,41 +176,7 @@ async def client_connected( with pytest.raises( SMTPResponseException, match="Malformed SMTP response line: ERROR" ): - await protocol.read_response(timeout=1.0) # type: ignore - - server.close() - await cleanup_server(server) - - -async def test_protocol_response_waiter_unset( - bind_address: str, - hostname: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - event_loop = asyncio.get_running_loop() - - async def client_connected( - reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: - await reader.read(1000) - writer.write(b"220 Hi\r\n") - await writer.drain() - - server = await asyncio.start_server( - client_connected, host=bind_address, port=0, family=socket.AF_INET - ) - server_port = server.sockets[0].getsockname()[1] if server.sockets else 0 - - connect_future = event_loop.create_connection( - SMTPProtocol, host=hostname, port=server_port - ) - - _, protocol = await asyncio.wait_for(connect_future, timeout=1.0) - - monkeypatch.setattr(protocol, "_response_waiter", None) - - with pytest.raises(SMTPServerDisconnected): - await protocol.execute_command(b"TEST\n", timeout=1.0) # type: ignore + await asyncio.wait_for(protocol.read_response(), 1.0) server.close() await cleanup_server(server) @@ -287,7 +209,7 @@ async def client_connected( _, protocol = await asyncio.wait_for(connect_future, timeout=1.0) - response = await protocol.execute_command(b"TEST\n", timeout=1.0) # type: ignore + response = await protocol.execute_command(b"TEST\n") assert response.code == 220 assert response.message == "Hi" @@ -347,8 +269,8 @@ async def client_connected( ) _, protocol = await asyncio.wait_for(connect_future, timeout=1.0) - await protocol.execute_command(b"HELO\n", timeout=1.0) - await protocol.execute_command(b"QUIT\n", timeout=1.0) + await protocol.execute_command(b"HELO\n") + await protocol.execute_command(b"QUIT\n") del protocol # Force garbage collection diff --git a/tests/test_timeouts.py b/tests/test_timeouts.py index 4a86f36..7b65331 100644 --- a/tests/test_timeouts.py +++ b/tests/test_timeouts.py @@ -3,7 +3,6 @@ """ import asyncio -import socket import ssl import pytest @@ -16,11 +15,13 @@ ) from aiosmtplib.protocol import SMTPProtocol -from .compat import cleanup_server -from .smtpd import mock_response_delayed_ok, mock_response_delayed_read +from .smtpd import ( + mock_response_delayed_ok_with_cleanup, + mock_response_delayed_read_with_cleanup, +) -@pytest.mark.smtpd_mocks(smtp_EHLO=mock_response_delayed_ok) +@pytest.mark.smtpd_mocks(smtp_EHLO=mock_response_delayed_ok_with_cleanup) async def test_command_timeout_error(smtp_client: SMTP) -> None: await smtp_client.connect() @@ -28,7 +29,7 @@ async def test_command_timeout_error(smtp_client: SMTP) -> None: await smtp_client.ehlo(hostname="example.com", timeout=0.0) -@pytest.mark.smtpd_mocks(smtp_DATA=mock_response_delayed_ok) +@pytest.mark.smtpd_mocks(smtp_DATA=mock_response_delayed_ok_with_cleanup) async def test_data_timeout_error(smtp_client: SMTP) -> None: await smtp_client.connect() await smtp_client.ehlo() @@ -38,23 +39,22 @@ async def test_data_timeout_error(smtp_client: SMTP) -> None: await smtp_client.data("HELLO WORLD", timeout=0.0) -@pytest.mark.smtpd_mocks(_handle_client=mock_response_delayed_ok) +@pytest.mark.smtpd_mocks(_handle_client=mock_response_delayed_ok_with_cleanup) async def test_timeout_error_on_connect(smtp_client: SMTP) -> None: with pytest.raises(SMTPTimeoutError): await smtp_client.connect(timeout=0.0) - assert smtp_client.transport is None assert smtp_client.protocol is None -@pytest.mark.smtpd_mocks(_handle_client=mock_response_delayed_read) +@pytest.mark.smtpd_mocks(_handle_client=mock_response_delayed_read_with_cleanup) async def test_timeout_on_initial_read(smtp_client: SMTP) -> None: with pytest.raises(SMTPTimeoutError): # We need to use a timeout > 0 here to avoid timing out on connect await smtp_client.connect(timeout=0.01) -@pytest.mark.smtpd_mocks(smtp_STARTTLS=mock_response_delayed_ok) +@pytest.mark.smtpd_mocks(smtp_STARTTLS=mock_response_delayed_ok_with_cleanup) async def test_timeout_on_starttls(smtp_client: SMTP) -> None: await smtp_client.connect() await smtp_client.ehlo() @@ -63,27 +63,6 @@ async def test_timeout_on_starttls(smtp_client: SMTP) -> None: await smtp_client.starttls(timeout=0.0) -async def test_protocol_read_response_with_timeout_times_out( - echo_server: asyncio.AbstractServer, - hostname: str, - echo_server_port: int, -) -> None: - event_loop = asyncio.get_running_loop() - - connect_future = event_loop.create_connection( - SMTPProtocol, host=hostname, port=echo_server_port - ) - - transport, protocol = await asyncio.wait_for(connect_future, timeout=1.0) - - with pytest.raises(SMTPTimeoutError) as exc: - await protocol.read_response(timeout=0.0) # type: ignore - - transport.close() - - assert str(exc.value) == "Timed out waiting for server response" - - async def test_connect_timeout_error(hostname: str, unused_tcp_port: int) -> None: client = SMTP(hostname=hostname, port=unused_tcp_port, timeout=0.0) @@ -110,37 +89,6 @@ async def test_server_disconnected_error_after_connect_timeout( await client.sendmail(sender_str, [recipient_str], message_str) -async def test_protocol_timeout_on_starttls( - bind_address: str, - hostname: str, - client_tls_context: ssl.SSLContext, -) -> None: - event_loop = asyncio.get_running_loop() - - async def client_connected( - reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: - await asyncio.sleep(1.0) - - server = await asyncio.start_server( - client_connected, host=bind_address, port=0, family=socket.AF_INET - ) - server_port = server.sockets[0].getsockname()[1] if server.sockets else 0 - - connect_future = event_loop.create_connection( - SMTPProtocol, host=hostname, port=server_port - ) - - _, protocol = await asyncio.wait_for(connect_future, timeout=1.0) - - with pytest.raises(SMTPTimeoutError): - # STARTTLS timeout must be > 0 - await protocol.start_tls(client_tls_context, timeout=0.00001) # type: ignore - - server.close() - await cleanup_server(server) - - async def test_protocol_connection_aborted_on_starttls( hostname: str, smtpd_server_port: int, diff --git a/tests/test_tls.py b/tests/test_tls.py index 037020e..2031140 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -48,9 +48,7 @@ async def test_starttls(smtp_client: SMTP) -> None: assert not smtp_client.supported_auth_methods assert not smtp_client.supports_esmtp - # Make sure our connection was actually upgraded. ssl protocol transport is - # private in UVloop, so just check the class name. - assert "SSL" in type(smtp_client.transport).__name__ + assert smtp_client.get_transport_info("sslcontext") is not None response = await smtp_client.ehlo() assert response.code == SMTPStatus.completed @@ -68,18 +66,14 @@ async def test_starttls_init_kwarg( ) async with smtp_client: - # Make sure our connection was actually upgraded. ssl protocol transport is - # private in UVloop, so just check the class name. - assert "SSL" in type(smtp_client.transport).__name__ + assert smtp_client.get_transport_info("sslcontext") is not None @pytest.mark.smtpd_options(tls=False) async def test_starttls_connect_kwarg(smtp_client: SMTP) -> None: await smtp_client.connect(start_tls=True) - # Make sure our connection was actually upgraded. ssl protocol transport is - # private in UVloop, so just check the class name. - assert "SSL" in type(smtp_client.transport).__name__ + assert smtp_client.get_transport_info("sslcontext") is not None await smtp_client.quit() @@ -96,9 +90,7 @@ async def test_starttls_auto( ) async with smtp_client: - # Make sure our connection was actually upgraded. ssl protocol transport is - # private in UVloop, so just check the class name. - assert "SSL" in type(smtp_client.transport).__name__ + assert smtp_client.get_transport_info("sslcontext") is not None async def test_starttls_auto_connect_kwarg( @@ -116,9 +108,7 @@ async def test_starttls_auto_connect_kwarg( await smtp_client.connect(start_tls=None) - # Make sure our connection was actually upgraded. ssl protocol transport is - # private in UVloop, so just check the class name. - assert "SSL" in type(smtp_client.transport).__name__ + assert smtp_client.get_transport_info("sslcontext") is not None await smtp_client.quit() @@ -130,9 +120,7 @@ async def test_starttls_auto_connect_not_supported(smtp_client: SMTP) -> None: async with smtp_client: await smtp_client.ehlo() - # Make sure our connection was nul upgraded. ssl protocol transport is - # private in UVloop, so just check the class name. - assert "SSL" not in type(smtp_client.transport).__name__ + assert smtp_client.get_transport_info("sslcontext") is None @pytest.mark.smtpd_options(tls=False) @@ -190,9 +178,7 @@ async def test_starttls_invalid_responses(smtp_client: SMTP) -> None: assert smtp_client.esmtp_extensions == old_extensions assert smtp_client.supports_esmtp is True - # Make sure our connection was not upgraded. ssl protocol transport is - # private in UVloop, so just check the class name. - assert "SSL" not in type(smtp_client.transport).__name__ + assert smtp_client.get_transport_info("sslcontext") is None @pytest.mark.smtpd_options(tls=False)