Skip to content

Commit 1c0cf45

Browse files
authored
Add more eventstreamrpc tests (#137)
1 parent f0716c0 commit 1c0cf45

File tree

2 files changed

+223
-52
lines changed

2 files changed

+223
-52
lines changed

awsiot/eventstreamrpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def on_connection_shutdown(self, reason, **kwargs):
270270
# if user called close() without a reason,
271271
# set a reason that the setup_future has failed
272272
if reason is None:
273-
reason = RuntimeError("close() called during connection setup")
273+
reason = ConnectionClosedError("close() called during connection setup")
274274
logger.error("%r connect failed: %r", self.owner, reason)
275275
connect_future.set_exception(reason)
276276
else:

test/test_rpc.py

Lines changed: 222 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
import test.echotestrpc.model as model
22
import test.echotestrpc.client as client
3-
from awsiot.eventstreamrpc import (Connection, Header, LifecycleHandler,
4-
MessageAmendment, SerializeError, StreamResponseHandler)
3+
from awsiot.eventstreamrpc import (
4+
AccessDeniedError,
5+
Connection,
6+
ConnectionClosedError,
7+
EventStreamError,
8+
LifecycleHandler,
9+
MessageAmendment,
10+
SerializeError,
11+
StreamClosedError,
12+
StreamResponseHandler)
513
from awscrt.io import (ClientBootstrap, DefaultHostResolver, EventLoopGroup,
614
init_logging, LogLevel)
15+
from awscrt.eventstream import Header, HeaderType
16+
from awscrt.eventstream.rpc import MessageType
717
from datetime import datetime, timezone
818
import logging
919
import os
1020
from queue import Queue
1121
from sys import stderr
1222
from threading import Event
23+
from time import sleep
1324
from typing import Optional, Sequence
1425
from unittest import skipUnless, TestCase
1526

@@ -29,6 +40,7 @@ def __init__(self, on_freakout):
2940
self.disconnect_event = Event()
3041
self.disconnect_reason = None
3142
self.errors = Queue()
43+
self.error_callback_return_val = True
3244
self.pings = Queue()
3345
# if something happens out of order, call this
3446
self._freakout = on_freakout
@@ -54,6 +66,7 @@ def on_disconnect(self, reason: Optional[Exception]):
5466

5567
def on_error(self, error: Exception) -> bool:
5668
self.errors.put(error)
69+
return self.error_callback_return_val
5770

5871
def on_ping(self, headers: Sequence[Header], payload: bytes):
5972
self.pings.put({'headers': headers, 'payload': payload})
@@ -99,11 +112,6 @@ def connect_amender():
99112
return MessageAmendment(headers=headers)
100113

101114

102-
def bad_connect_amender():
103-
headers = [Header.from_string('client-name', 'rejected.testy_mc_failureson')]
104-
return MessageAmendment(headers=headers)
105-
106-
107115
@skipUnless(EVENTSTREAM_ECHO_TEST, "Skipping until we have permanent echo server")
108116
class RpcTest(TestCase):
109117
def _on_handler_freakout(self, msg):
@@ -114,7 +122,7 @@ def _on_handler_freakout(self, msg):
114122
def _assertNoHandlerFreakout(self):
115123
self.assertIsNone(getattr(self, 'freakout_msg', None))
116124

117-
def _connect(self):
125+
def setUp(self):
118126
elg = EventLoopGroup()
119127
resolver = DefaultHostResolver(elg)
120128
bootstrap = ClientBootstrap(elg, resolver)
@@ -123,23 +131,21 @@ def _connect(self):
123131
port=8033,
124132
bootstrap=bootstrap,
125133
connect_message_amender=connect_amender)
126-
self.lifecycle_handler = ConnectionLifecycleHandler(self._on_handler_freakout)
134+
135+
def _connect(self, lifecycle_handler=None):
136+
if lifecycle_handler:
137+
self.lifecycle_handler = lifecycle_handler
138+
else:
139+
self.lifecycle_handler = ConnectionLifecycleHandler(self._on_handler_freakout)
127140
connect_future = self.connection.connect(self.lifecycle_handler)
128141
connect_future.result(TIMEOUT)
129-
130142
self.echo_client = client.EchoTestRPCClient(self.connection)
131143

132-
def _bad_connect(self, bad_host=False, bad_client_name=False):
133-
elg = EventLoopGroup()
134-
resolver = DefaultHostResolver(elg)
135-
bootstrap = ClientBootstrap(elg, resolver)
136-
host_name = 'badhostname' if bad_host else '127.0.0.1'
137-
amender = bad_connect_amender if bad_client_name else connect_amender
138-
self.connection = Connection(
139-
host_name=host_name,
140-
port=8033,
141-
bootstrap=bootstrap,
142-
connect_message_amender=amender)
144+
def _bad_connect(self, bad_host=False, amender=None):
145+
if bad_host:
146+
self.connection.host_name = 'badhostname'
147+
if amender:
148+
self.connection._connect_message_amender = amender
143149
self.lifecycle_handler = ConnectionLifecycleHandler(self._on_handler_freakout)
144150
connect_future = self.connection.connect(self.lifecycle_handler)
145151
connect_exception = connect_future.exception(TIMEOUT)
@@ -154,13 +160,44 @@ def _bad_connect(self, bad_host=False, bad_client_name=False):
154160

155161
self._assertNoHandlerFreakout()
156162

163+
return connect_exception
164+
165+
def _close_connection(self):
166+
# helper to do normal close of healthy connection
167+
close_future = self.connection.close()
168+
close_exception = close_future.exception(TIMEOUT)
169+
self.assertIsNone(close_exception)
170+
self.assertTrue(self.lifecycle_handler.disconnect_event.wait(TIMEOUT))
171+
self.assertIsNone(self.lifecycle_handler.disconnect_reason)
172+
self.assertTrue(self.lifecycle_handler.errors.empty())
173+
self._assertNoHandlerFreakout()
174+
157175
def test_connect_failed_socket(self):
158176
# test failure from the CONNECTING_TO_SOCKET phase
159177
self._bad_connect(bad_host=True)
160178

161179
def test_connect_failed_connack(self):
162-
# test failure from the WAITING_FOR_CONNECT_ACK phse
163-
self._bad_connect(bad_client_name=True)
180+
# test failure from the WAITING_FOR_CONNECT_ACK phase
181+
def _amender():
182+
headers = [Header.from_string('client-name', 'rejected.testy_mc_failureson')]
183+
return MessageAmendment(headers=headers)
184+
exception = self._bad_connect(amender=_amender)
185+
self.assertIsInstance(exception, AccessDeniedError)
186+
187+
def test_connect_failed_amender_exception(self):
188+
# test failure due to connect_amender exception
189+
error = RuntimeError('Purposefully raising error in amender callback')
190+
191+
def _amender():
192+
raise error
193+
exception = self._bad_connect(amender=_amender)
194+
self.assertIs(exception, error)
195+
196+
def test_connect_failed_amender_bad_return(self):
197+
# test failure due to amender returning bad data
198+
def _amender():
199+
return 'a string is not a MessageAmendment'
200+
self._bad_connect(amender=_amender)
164201

165202
def test_echo_message(self):
166203
self._connect()
@@ -214,11 +251,7 @@ def test_echo_message(self):
214251
# and timezone info due to datetime->timestamp->datetime conversion
215252
self.assertEqual(request.message, response.message)
216253

217-
# must close connection
218-
close_future = self.connection.close()
219-
self.assertIsNone(close_future.exception(TIMEOUT))
220-
221-
self._assertNoHandlerFreakout()
254+
self._close_connection()
222255

223256
def test_bad_activate(self):
224257
self._connect()
@@ -231,18 +264,14 @@ def test_bad_activate(self):
231264
with self.assertRaises(SerializeError):
232265
operation.activate(bad_request)
233266

234-
# must close connection
235-
close_future = self.connection.close()
236-
self.assertIsNone(close_future.exception(TIMEOUT))
237-
238-
self._assertNoHandlerFreakout()
267+
self._close_connection()
239268

240-
def test_echo_streaming_message(self):
269+
def test_echo_stream_messages(self):
241270
self._connect()
242271

243-
handler = StreamHandler(self._on_handler_freakout)
244-
operation = self.echo_client.new_echo_stream_messages(handler)
245-
handler.operation = operation
272+
stream_handler = StreamHandler(self._on_handler_freakout)
273+
operation = self.echo_client.new_echo_stream_messages(stream_handler)
274+
stream_handler.operation = operation
246275

247276
# send initial request
248277
flush = operation.activate(model.EchoStreamingRequest())
@@ -254,22 +283,16 @@ def test_echo_streaming_message(self):
254283
flush.result(TIMEOUT)
255284

256285
# recv streaming response
257-
response_event = handler.events.get(timeout=TIMEOUT)
286+
response_event = stream_handler.events.get(timeout=TIMEOUT)
258287
self.assertEqual(request_event, response_event)
259288

260-
# must close connection
261-
close_future = self.connection.close()
262-
self.assertIsNone(close_future.exception(TIMEOUT))
263-
self.assertTrue(handler.closed.is_set())
264-
265-
# make sure nothing went wrong that we didn't expect to go wrong
266-
self.assertTrue(handler.errors.empty())
267-
self._assertNoHandlerFreakout()
289+
self._close_connection()
290+
self.assertTrue(stream_handler.closed.is_set())
291+
self.assertTrue(stream_handler.errors.empty())
268292

269293
def test_cause_service_error(self):
270294
# test the CauseServiceError operation,
271295
# which always responds with a ServiceError
272-
# and then terminates the connection
273296
self._connect()
274297

275298
operation = self.echo_client.new_cause_service_error()
@@ -281,7 +304,155 @@ def test_cause_service_error(self):
281304
response_exception = operation.get_response().exception(TIMEOUT)
282305
self.assertIsInstance(response_exception, model.ServiceError)
283306

284-
# close connection
285-
close_future = self.connection.close()
286-
self.assertIsNone(close_future.exception(TIMEOUT))
287-
self._assertNoHandlerFreakout()
307+
self._close_connection()
308+
309+
def test_cause_stream_service_to_error(self):
310+
# test CauseStreamServiceToError operation,
311+
# Responds to initial request normally then throws a ServiceError on stream response
312+
self._connect()
313+
314+
# set up operation
315+
stream_handler = StreamHandler(self._on_handler_freakout)
316+
stream_handler.error_callback_return_val = False
317+
op = self.echo_client.new_cause_stream_service_to_error(stream_handler)
318+
stream_handler.operation = op
319+
320+
# send initial request, normal response should come back
321+
request = model.EchoStreamingRequest()
322+
op.activate(request)
323+
op.get_response().result()
324+
325+
# send subsequent streaming message, streaming error should come back
326+
msg_to_send = model.EchoStreamingMessage(stream_message=model.MessageData())
327+
op.send_stream_event(msg_to_send)
328+
329+
stream_error = stream_handler.errors.get(timeout=TIMEOUT)
330+
self.assertIsInstance(stream_error, model.ServiceError)
331+
332+
self._close_connection()
333+
self.assertTrue(stream_handler.closed.is_set())
334+
self.assertTrue(stream_handler.errors.empty())
335+
336+
def test_connection_error(self):
337+
# test that everything acts as expected if server sends
338+
# connection-level error
339+
self._connect()
340+
341+
# reach deep into private inner workings of the connection to manually
342+
# send a bad message to the server.
343+
self.connection._synced.current_connection.send_protocol_message(
344+
headers=[Header.from_int32(':stream-id', -999)],
345+
message_type=MessageType.APPLICATION_MESSAGE,
346+
)
347+
348+
# should receive PROTOCOL_ERROR in response to bad message
349+
error = self.lifecycle_handler.errors.get(timeout=TIMEOUT)
350+
self.assertIsInstance(error, EventStreamError)
351+
352+
# server kills connection after PROTOCOL_ERROR
353+
self.assertTrue(self.lifecycle_handler.disconnect_event.wait(TIMEOUT))
354+
self.assertIsInstance(self.lifecycle_handler.disconnect_reason, EventStreamError)
355+
356+
def test_close_with_reason(self):
357+
# test that, if an error is passed to connection.close(err),
358+
# it carries through
359+
self._connect()
360+
361+
my_error = RuntimeError('my close reason')
362+
close_future = self.connection.close(my_error)
363+
close_reason = close_future.exception(TIMEOUT)
364+
365+
self.assertIs(my_error, close_reason)
366+
self.assertTrue(self.lifecycle_handler.disconnect_event.wait(TIMEOUT))
367+
self.assertIs(my_error, self.lifecycle_handler.disconnect_reason)
368+
369+
def test_reconnect(self):
370+
# test that a Connection can connect and disconnect multiple times
371+
self._connect()
372+
self._close_connection()
373+
374+
self._connect()
375+
self._close_connection()
376+
377+
def test_close_during_setup(self):
378+
# Test that it's safe to call close() while the connection is still setting up.
379+
380+
# There are multiple stages to the async connect() and we'd like
381+
# to stress close() being called in each of these phases.
382+
# Hacky strategy to achieve this to, in a loop:
383+
# - call async connect()
384+
# - after some delay, call async close()
385+
# - with each loop, the delay gets slightly longer
386+
# - break out of loop loop once the delay is long enough that the
387+
# connect() is completing before we ever call close()
388+
delay_increment_sec = 0.005
389+
stop_after_n_successful_connections = 2
390+
391+
delay_sec = 0.0
392+
successful_connections = 0
393+
while successful_connections < stop_after_n_successful_connections:
394+
# not using helper _connect() call because it blocks until async connect() completes
395+
self.lifecycle_handler = ConnectionLifecycleHandler(self._on_handler_freakout)
396+
connect_future = self.connection.connect(self.lifecycle_handler)
397+
398+
if delay_sec > 0.0:
399+
sleep(delay_sec)
400+
close_future = self.connection.close()
401+
402+
# wait for connect and close to complete
403+
connect_exception = connect_future.exception(TIMEOUT)
404+
close_exception = close_future.exception(TIMEOUT)
405+
406+
# close should have been clean
407+
self.assertIsNone(close_exception)
408+
409+
# connect might have succeeded, or might have failed,
410+
# depending on the timing of this thread's close() call
411+
if connect_exception:
412+
self.assertIsInstance(connect_exception, ConnectionClosedError)
413+
# lifecycle handlers should NOT fire if connect setup failed
414+
# wait a tiny bit to be 100% sure these never fire
415+
self.assertFalse(self.lifecycle_handler.connect_event.wait(0.1))
416+
self.assertFalse(self.lifecycle_handler.disconnect_event.wait(0.1))
417+
else:
418+
self.assertTrue(self.lifecycle_handler.connect_event.wait(TIMEOUT))
419+
self.assertTrue(self.lifecycle_handler.disconnect_event.wait(TIMEOUT))
420+
successful_connections += 1
421+
422+
delay_sec += delay_increment_sec
423+
self._assertNoHandlerFreakout()
424+
425+
def test_operation_response_completes_if_connection_closed(self):
426+
# test that response future completes if connection is closed
427+
# before actual response is received.
428+
429+
# this test is timing dependent, the response could theoretically
430+
# come on another thread before this thread can close the connection,
431+
# so run test in a loop till we get the timing we want
432+
closed_before_response = False
433+
434+
# give up after a reasonably high number of tries
435+
# (note: first try always passes on my 2019 macbook pro, with localhost server)
436+
tries = 0
437+
max_tries = 100
438+
439+
while not closed_before_response:
440+
self.assertLess(tries, max_tries, "Test couldn't get result it wanted after many tries")
441+
tries += 1
442+
443+
self._connect()
444+
445+
stream_handler = StreamHandler(self._on_handler_freakout)
446+
operation = self.echo_client.new_echo_stream_messages(stream_handler)
447+
stream_handler.operation = operation
448+
449+
operation.activate(model.EchoStreamingRequest())
450+
close_future = self.connection.close()
451+
452+
try:
453+
response = operation.get_response().result(TIMEOUT)
454+
except StreamClosedError:
455+
closed_before_response = True
456+
457+
# wait for close to complete before attempting reconnect
458+
close_future.result(TIMEOUT)

0 commit comments

Comments
 (0)