@@ -51,6 +51,7 @@ def connect(host="localhost", user=None, password="",
5151 read_default_file = None , conv = decoders , use_unicode = None ,
5252 client_flag = 0 , cursorclass = Cursor , init_command = None ,
5353 connect_timeout = None , read_default_group = None ,
54+ read_timeout = None ,
5455 autocommit = False , echo = False ,
5556 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
5657 program_name = '' , server_public_key = None ):
@@ -64,6 +65,7 @@ def connect(host="localhost", user=None, password="",
6465 init_command = init_command ,
6566 connect_timeout = connect_timeout ,
6667 read_default_group = read_default_group ,
68+ read_timeout = read_timeout ,
6769 autocommit = autocommit , echo = echo ,
6870 local_infile = local_infile , loop = loop , ssl = ssl ,
6971 auth_plugin = auth_plugin , program_name = program_name )
@@ -139,7 +141,7 @@ def __init__(self, host="localhost", user=None, password="",
139141 charset = '' , sql_mode = None ,
140142 read_default_file = None , conv = decoders , use_unicode = None ,
141143 client_flag = 0 , cursorclass = Cursor , init_command = None ,
142- connect_timeout = None , read_default_group = None ,
144+ connect_timeout = None , read_default_group = None , read_timeout = None ,
143145 autocommit = False , echo = False ,
144146 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
145147 program_name = '' , server_public_key = None ):
@@ -171,6 +173,8 @@ def __init__(self, host="localhost", user=None, password="",
171173 when connecting.
172174 :param read_default_group: Group to read from in the configuration
173175 file.
176+ :param read_timeout: The timeout for reading from the connection in seconds
177+ (default: None - no timeout)
174178 :param autocommit: Autocommit mode. None means use server default.
175179 (default: False)
176180 :param local_infile: boolean to enable the use of LOAD DATA LOCAL
@@ -257,6 +261,7 @@ def __init__(self, host="localhost", user=None, password="",
257261
258262 self .cursorclass = cursorclass
259263 self .connect_timeout = connect_timeout
264+ self .read_timeout = read_timeout
260265
261266 self ._result = None
262267 self ._affected_rows = 0
@@ -654,12 +659,25 @@ async def _read_packet(self, packet_type=MysqlPacket):
654659
655660 async def _read_bytes (self , num_bytes ):
656661 try :
657- data = await self ._reader .readexactly (num_bytes )
662+ if self .read_timeout :
663+ try :
664+ data = await asyncio .wait_for (
665+ self ._reader .readexactly (num_bytes ),
666+ self .read_timeout
667+ )
668+ except asyncio .TimeoutError as e :
669+ raise asyncio .TimeoutError ("Read timeout exceeded" ) from e
670+ else :
671+ data = await self ._reader .readexactly (num_bytes )
658672 except asyncio .IncompleteReadError as e :
659673 msg = "Lost connection to MySQL server during query"
660674 self .close ()
661675 raise OperationalError (CR .CR_SERVER_LOST , msg ) from e
662- except OSError as e :
676+ except (OSError , asyncio .TimeoutError ) as e :
677+ msg = f"Lost connection to MySQL server during query ({ e } )"
678+ self .close ()
679+ raise OperationalError (CR .CR_SERVER_LOST , msg ) from e
680+ except Exception as e :
663681 msg = f"Lost connection to MySQL server during query ({ e } )"
664682 self .close ()
665683 raise OperationalError (CR .CR_SERVER_LOST , msg ) from e
0 commit comments