@@ -53,7 +53,8 @@ def connect(host="localhost", user=None, password="",
5353 connect_timeout = None , read_default_group = None ,
5454 autocommit = False , echo = False ,
5555 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
56- program_name = '' , server_public_key = None ):
56+ program_name = '' , server_public_key = None ,
57+ read_timeout = None ):
5758 """See connections.Connection.__init__() for information about
5859 defaults."""
5960 coro = _connect (host = host , user = user , password = password , db = db ,
@@ -66,7 +67,8 @@ def connect(host="localhost", user=None, password="",
6667 read_default_group = read_default_group ,
6768 autocommit = autocommit , echo = echo ,
6869 local_infile = local_infile , loop = loop , ssl = ssl ,
69- auth_plugin = auth_plugin , program_name = program_name )
70+ auth_plugin = auth_plugin , program_name = program_name ,
71+ read_timeout = read_timeout )
7072 return _ConnectionContextManager (coro )
7173
7274
@@ -142,7 +144,7 @@ def __init__(self, host="localhost", user=None, password="",
142144 connect_timeout = None , read_default_group = None ,
143145 autocommit = False , echo = False ,
144146 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
145- program_name = '' , server_public_key = None ):
147+ program_name = '' , server_public_key = None , read_timeout = None ):
146148 """
147149 Establish a connection to the MySQL database. Accepts several
148150 arguments:
@@ -184,6 +186,8 @@ def __init__(self, host="localhost", user=None, password="",
184186 handshaking with MySQL. (omitted by default)
185187 :param server_public_key: SHA256 authentication plugin public
186188 key value.
189+ :param read_timeout: The timeout for reading from the connection in seconds
190+ (default: None - no timeout)
187191 :param loop: asyncio loop
188192 """
189193 self ._loop = loop or asyncio .get_event_loop ()
@@ -257,6 +261,7 @@ def __init__(self, host="localhost", user=None, password="",
257261
258262 self .cursorclass = cursorclass
259263 self .connect_timeout = connect_timeout
264+ self .read_timeout = read_timeout
260265
261266 self ._result = None
262267 self ._affected_rows = 0
@@ -654,12 +659,25 @@ async def _read_packet(self, packet_type=MysqlPacket):
654659
655660 async def _read_bytes (self , num_bytes ):
656661 try :
657- data = await self ._reader .readexactly (num_bytes )
662+ if self .read_timeout :
663+ try :
664+ data = await asyncio .wait_for (
665+ self ._reader .readexactly (num_bytes ),
666+ self .read_timeout
667+ )
668+ except asyncio .TimeoutError as e :
669+ raise asyncio .TimeoutError ("Read timeout exceeded" ) from e
670+ else :
671+ data = await self ._reader .readexactly (num_bytes )
658672 except asyncio .IncompleteReadError as e :
659673 msg = "Lost connection to MySQL server during query"
660674 self .close ()
661675 raise OperationalError (CR .CR_SERVER_LOST , msg ) from e
662- except OSError as e :
676+ except (OSError , asyncio .TimeoutError ) as e :
677+ msg = f"Lost connection to MySQL server during query ({ e } )"
678+ self .close ()
679+ raise OperationalError (CR .CR_SERVER_LOST , msg ) from e
680+ except Exception as e :
663681 msg = f"Lost connection to MySQL server during query ({ e } )"
664682 self .close ()
665683 raise OperationalError (CR .CR_SERVER_LOST , msg ) from e
0 commit comments