11import test .echotestrpc .model as model
22import test .echotestrpc .client as client
3- from awsiot .eventstreamrpc import (Connection , Header , LifecycleHandler ,
4- MessageAmendment , SerializeError , StreamResponseHandler )
3+ from awsiot .eventstreamrpc import (
4+ AccessDeniedError ,
5+ Connection ,
6+ ConnectionClosedError ,
7+ EventStreamError ,
8+ LifecycleHandler ,
9+ MessageAmendment ,
10+ SerializeError ,
11+ StreamClosedError ,
12+ StreamResponseHandler )
513from awscrt .io import (ClientBootstrap , DefaultHostResolver , EventLoopGroup ,
614 init_logging , LogLevel )
15+ from awscrt .eventstream import Header , HeaderType
16+ from awscrt .eventstream .rpc import MessageType
717from datetime import datetime , timezone
818import logging
919import os
1020from queue import Queue
1121from sys import stderr
1222from threading import Event
23+ from time import sleep
1324from typing import Optional , Sequence
1425from unittest import skipUnless , TestCase
1526
@@ -29,6 +40,7 @@ def __init__(self, on_freakout):
2940 self .disconnect_event = Event ()
3041 self .disconnect_reason = None
3142 self .errors = Queue ()
43+ self .error_callback_return_val = True
3244 self .pings = Queue ()
3345 # if something happens out of order, call this
3446 self ._freakout = on_freakout
@@ -54,6 +66,7 @@ def on_disconnect(self, reason: Optional[Exception]):
5466
5567 def on_error (self , error : Exception ) -> bool :
5668 self .errors .put (error )
69+ return self .error_callback_return_val
5770
5871 def on_ping (self , headers : Sequence [Header ], payload : bytes ):
5972 self .pings .put ({'headers' : headers , 'payload' : payload })
@@ -99,11 +112,6 @@ def connect_amender():
99112 return MessageAmendment (headers = headers )
100113
101114
102- def bad_connect_amender ():
103- headers = [Header .from_string ('client-name' , 'rejected.testy_mc_failureson' )]
104- return MessageAmendment (headers = headers )
105-
106-
107115@skipUnless (EVENTSTREAM_ECHO_TEST , "Skipping until we have permanent echo server" )
108116class RpcTest (TestCase ):
109117 def _on_handler_freakout (self , msg ):
@@ -114,7 +122,7 @@ def _on_handler_freakout(self, msg):
114122 def _assertNoHandlerFreakout (self ):
115123 self .assertIsNone (getattr (self , 'freakout_msg' , None ))
116124
117- def _connect (self ):
125+ def setUp (self ):
118126 elg = EventLoopGroup ()
119127 resolver = DefaultHostResolver (elg )
120128 bootstrap = ClientBootstrap (elg , resolver )
@@ -123,23 +131,21 @@ def _connect(self):
123131 port = 8033 ,
124132 bootstrap = bootstrap ,
125133 connect_message_amender = connect_amender )
126- self .lifecycle_handler = ConnectionLifecycleHandler (self ._on_handler_freakout )
134+
135+ def _connect (self , lifecycle_handler = None ):
136+ if lifecycle_handler :
137+ self .lifecycle_handler = lifecycle_handler
138+ else :
139+ self .lifecycle_handler = ConnectionLifecycleHandler (self ._on_handler_freakout )
127140 connect_future = self .connection .connect (self .lifecycle_handler )
128141 connect_future .result (TIMEOUT )
129-
130142 self .echo_client = client .EchoTestRPCClient (self .connection )
131143
132- def _bad_connect (self , bad_host = False , bad_client_name = False ):
133- elg = EventLoopGroup ()
134- resolver = DefaultHostResolver (elg )
135- bootstrap = ClientBootstrap (elg , resolver )
136- host_name = 'badhostname' if bad_host else '127.0.0.1'
137- amender = bad_connect_amender if bad_client_name else connect_amender
138- self .connection = Connection (
139- host_name = host_name ,
140- port = 8033 ,
141- bootstrap = bootstrap ,
142- connect_message_amender = amender )
144+ def _bad_connect (self , bad_host = False , amender = None ):
145+ if bad_host :
146+ self .connection .host_name = 'badhostname'
147+ if amender :
148+ self .connection ._connect_message_amender = amender
143149 self .lifecycle_handler = ConnectionLifecycleHandler (self ._on_handler_freakout )
144150 connect_future = self .connection .connect (self .lifecycle_handler )
145151 connect_exception = connect_future .exception (TIMEOUT )
@@ -154,13 +160,44 @@ def _bad_connect(self, bad_host=False, bad_client_name=False):
154160
155161 self ._assertNoHandlerFreakout ()
156162
163+ return connect_exception
164+
165+ def _close_connection (self ):
166+ # helper to do normal close of healthy connection
167+ close_future = self .connection .close ()
168+ close_exception = close_future .exception (TIMEOUT )
169+ self .assertIsNone (close_exception )
170+ self .assertTrue (self .lifecycle_handler .disconnect_event .wait (TIMEOUT ))
171+ self .assertIsNone (self .lifecycle_handler .disconnect_reason )
172+ self .assertTrue (self .lifecycle_handler .errors .empty ())
173+ self ._assertNoHandlerFreakout ()
174+
157175 def test_connect_failed_socket (self ):
158176 # test failure from the CONNECTING_TO_SOCKET phase
159177 self ._bad_connect (bad_host = True )
160178
161179 def test_connect_failed_connack (self ):
162- # test failure from the WAITING_FOR_CONNECT_ACK phse
163- self ._bad_connect (bad_client_name = True )
180+ # test failure from the WAITING_FOR_CONNECT_ACK phase
181+ def _amender ():
182+ headers = [Header .from_string ('client-name' , 'rejected.testy_mc_failureson' )]
183+ return MessageAmendment (headers = headers )
184+ exception = self ._bad_connect (amender = _amender )
185+ self .assertIsInstance (exception , AccessDeniedError )
186+
187+ def test_connect_failed_amender_exception (self ):
188+ # test failure due to connect_amender exception
189+ error = RuntimeError ('Purposefully raising error in amender callback' )
190+
191+ def _amender ():
192+ raise error
193+ exception = self ._bad_connect (amender = _amender )
194+ self .assertIs (exception , error )
195+
196+ def test_connect_failed_amender_bad_return (self ):
197+ # test failure due to amender returning bad data
198+ def _amender ():
199+ return 'a string is not a MessageAmendment'
200+ self ._bad_connect (amender = _amender )
164201
165202 def test_echo_message (self ):
166203 self ._connect ()
@@ -214,11 +251,7 @@ def test_echo_message(self):
214251 # and timezone info due to datetime->timestamp->datetime conversion
215252 self .assertEqual (request .message , response .message )
216253
217- # must close connection
218- close_future = self .connection .close ()
219- self .assertIsNone (close_future .exception (TIMEOUT ))
220-
221- self ._assertNoHandlerFreakout ()
254+ self ._close_connection ()
222255
223256 def test_bad_activate (self ):
224257 self ._connect ()
@@ -231,18 +264,14 @@ def test_bad_activate(self):
231264 with self .assertRaises (SerializeError ):
232265 operation .activate (bad_request )
233266
234- # must close connection
235- close_future = self .connection .close ()
236- self .assertIsNone (close_future .exception (TIMEOUT ))
237-
238- self ._assertNoHandlerFreakout ()
267+ self ._close_connection ()
239268
240- def test_echo_streaming_message (self ):
269+ def test_echo_stream_messages (self ):
241270 self ._connect ()
242271
243- handler = StreamHandler (self ._on_handler_freakout )
244- operation = self .echo_client .new_echo_stream_messages (handler )
245- handler .operation = operation
272+ stream_handler = StreamHandler (self ._on_handler_freakout )
273+ operation = self .echo_client .new_echo_stream_messages (stream_handler )
274+ stream_handler .operation = operation
246275
247276 # send initial request
248277 flush = operation .activate (model .EchoStreamingRequest ())
@@ -254,22 +283,16 @@ def test_echo_streaming_message(self):
254283 flush .result (TIMEOUT )
255284
256285 # recv streaming response
257- response_event = handler .events .get (timeout = TIMEOUT )
286+ response_event = stream_handler .events .get (timeout = TIMEOUT )
258287 self .assertEqual (request_event , response_event )
259288
260- # must close connection
261- close_future = self .connection .close ()
262- self .assertIsNone (close_future .exception (TIMEOUT ))
263- self .assertTrue (handler .closed .is_set ())
264-
265- # make sure nothing went wrong that we didn't expect to go wrong
266- self .assertTrue (handler .errors .empty ())
267- self ._assertNoHandlerFreakout ()
289+ self ._close_connection ()
290+ self .assertTrue (stream_handler .closed .is_set ())
291+ self .assertTrue (stream_handler .errors .empty ())
268292
269293 def test_cause_service_error (self ):
270294 # test the CauseServiceError operation,
271295 # which always responds with a ServiceError
272- # and then terminates the connection
273296 self ._connect ()
274297
275298 operation = self .echo_client .new_cause_service_error ()
@@ -281,7 +304,155 @@ def test_cause_service_error(self):
281304 response_exception = operation .get_response ().exception (TIMEOUT )
282305 self .assertIsInstance (response_exception , model .ServiceError )
283306
284- # close connection
285- close_future = self .connection .close ()
286- self .assertIsNone (close_future .exception (TIMEOUT ))
287- self ._assertNoHandlerFreakout ()
307+ self ._close_connection ()
308+
309+ def test_cause_stream_service_to_error (self ):
310+ # test CauseStreamServiceToError operation,
311+ # Responds to initial request normally then throws a ServiceError on stream response
312+ self ._connect ()
313+
314+ # set up operation
315+ stream_handler = StreamHandler (self ._on_handler_freakout )
316+ stream_handler .error_callback_return_val = False
317+ op = self .echo_client .new_cause_stream_service_to_error (stream_handler )
318+ stream_handler .operation = op
319+
320+ # send initial request, normal response should come back
321+ request = model .EchoStreamingRequest ()
322+ op .activate (request )
323+ op .get_response ().result ()
324+
325+ # send subsequent streaming message, streaming error should come back
326+ msg_to_send = model .EchoStreamingMessage (stream_message = model .MessageData ())
327+ op .send_stream_event (msg_to_send )
328+
329+ stream_error = stream_handler .errors .get (timeout = TIMEOUT )
330+ self .assertIsInstance (stream_error , model .ServiceError )
331+
332+ self ._close_connection ()
333+ self .assertTrue (stream_handler .closed .is_set ())
334+ self .assertTrue (stream_handler .errors .empty ())
335+
336+ def test_connection_error (self ):
337+ # test that everything acts as expected if server sends
338+ # connection-level error
339+ self ._connect ()
340+
341+ # reach deep into private inner workings of the connection to manually
342+ # send a bad message to the server.
343+ self .connection ._synced .current_connection .send_protocol_message (
344+ headers = [Header .from_int32 (':stream-id' , - 999 )],
345+ message_type = MessageType .APPLICATION_MESSAGE ,
346+ )
347+
348+ # should receive PROTOCOL_ERROR in response to bad message
349+ error = self .lifecycle_handler .errors .get (timeout = TIMEOUT )
350+ self .assertIsInstance (error , EventStreamError )
351+
352+ # server kills connection after PROTOCOL_ERROR
353+ self .assertTrue (self .lifecycle_handler .disconnect_event .wait (TIMEOUT ))
354+ self .assertIsInstance (self .lifecycle_handler .disconnect_reason , EventStreamError )
355+
356+ def test_close_with_reason (self ):
357+ # test that, if an error is passed to connection.close(err),
358+ # it carries through
359+ self ._connect ()
360+
361+ my_error = RuntimeError ('my close reason' )
362+ close_future = self .connection .close (my_error )
363+ close_reason = close_future .exception (TIMEOUT )
364+
365+ self .assertIs (my_error , close_reason )
366+ self .assertTrue (self .lifecycle_handler .disconnect_event .wait (TIMEOUT ))
367+ self .assertIs (my_error , self .lifecycle_handler .disconnect_reason )
368+
369+ def test_reconnect (self ):
370+ # test that a Connection can connect and disconnect multiple times
371+ self ._connect ()
372+ self ._close_connection ()
373+
374+ self ._connect ()
375+ self ._close_connection ()
376+
377+ def test_close_during_setup (self ):
378+ # Test that it's safe to call close() while the connection is still setting up.
379+
380+ # There are multiple stages to the async connect() and we'd like
381+ # to stress close() being called in each of these phases.
382+ # Hacky strategy to achieve this to, in a loop:
383+ # - call async connect()
384+ # - after some delay, call async close()
385+ # - with each loop, the delay gets slightly longer
386+ # - break out of loop loop once the delay is long enough that the
387+ # connect() is completing before we ever call close()
388+ delay_increment_sec = 0.005
389+ stop_after_n_successful_connections = 2
390+
391+ delay_sec = 0.0
392+ successful_connections = 0
393+ while successful_connections < stop_after_n_successful_connections :
394+ # not using helper _connect() call because it blocks until async connect() completes
395+ self .lifecycle_handler = ConnectionLifecycleHandler (self ._on_handler_freakout )
396+ connect_future = self .connection .connect (self .lifecycle_handler )
397+
398+ if delay_sec > 0.0 :
399+ sleep (delay_sec )
400+ close_future = self .connection .close ()
401+
402+ # wait for connect and close to complete
403+ connect_exception = connect_future .exception (TIMEOUT )
404+ close_exception = close_future .exception (TIMEOUT )
405+
406+ # close should have been clean
407+ self .assertIsNone (close_exception )
408+
409+ # connect might have succeeded, or might have failed,
410+ # depending on the timing of this thread's close() call
411+ if connect_exception :
412+ self .assertIsInstance (connect_exception , ConnectionClosedError )
413+ # lifecycle handlers should NOT fire if connect setup failed
414+ # wait a tiny bit to be 100% sure these never fire
415+ self .assertFalse (self .lifecycle_handler .connect_event .wait (0.1 ))
416+ self .assertFalse (self .lifecycle_handler .disconnect_event .wait (0.1 ))
417+ else :
418+ self .assertTrue (self .lifecycle_handler .connect_event .wait (TIMEOUT ))
419+ self .assertTrue (self .lifecycle_handler .disconnect_event .wait (TIMEOUT ))
420+ successful_connections += 1
421+
422+ delay_sec += delay_increment_sec
423+ self ._assertNoHandlerFreakout ()
424+
425+ def test_operation_response_completes_if_connection_closed (self ):
426+ # test that response future completes if connection is closed
427+ # before actual response is received.
428+
429+ # this test is timing dependent, the response could theoretically
430+ # come on another thread before this thread can close the connection,
431+ # so run test in a loop till we get the timing we want
432+ closed_before_response = False
433+
434+ # give up after a reasonably high number of tries
435+ # (note: first try always passes on my 2019 macbook pro, with localhost server)
436+ tries = 0
437+ max_tries = 100
438+
439+ while not closed_before_response :
440+ self .assertLess (tries , max_tries , "Test couldn't get result it wanted after many tries" )
441+ tries += 1
442+
443+ self ._connect ()
444+
445+ stream_handler = StreamHandler (self ._on_handler_freakout )
446+ operation = self .echo_client .new_echo_stream_messages (stream_handler )
447+ stream_handler .operation = operation
448+
449+ operation .activate (model .EchoStreamingRequest ())
450+ close_future = self .connection .close ()
451+
452+ try :
453+ response = operation .get_response ().result (TIMEOUT )
454+ except StreamClosedError :
455+ closed_before_response = True
456+
457+ # wait for close to complete before attempting reconnect
458+ close_future .result (TIMEOUT )
0 commit comments