@@ -29,7 +29,7 @@ class Broadcast:
2929 def __init__ (self , url : str | None = None , * , backend : BroadcastBackend | None = None ) -> None :
3030 assert url or backend , "Either `url` or `backend` must be provided."
3131 self ._backend = backend or self ._create_backend (cast (str , url ))
32- self ._subscribers : dict [str , set [asyncio .Queue [Event | None ]]] = {}
32+ self ._subscribers : dict [str , set [asyncio .Queue [Event | BaseException | None ]]] = {}
3333
3434 def _create_backend (self , url : str ) -> BroadcastBackend :
3535 parsed_url = urlparse (url )
@@ -69,10 +69,19 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
6969 async def connect (self ) -> None :
7070 await self ._backend .connect ()
7171 self ._listener_task = asyncio .create_task (self ._listener ())
72+ self ._listener_task .add_done_callback (self .drop )
73+
74+ def drop (self , task : asyncio .Task [None ]) -> None :
75+ exc = task .exception ()
76+ for queues in self ._subscribers .values ():
77+ for queue in queues :
78+ queue .put_nowait (exc )
7279
7380 async def disconnect (self ) -> None :
7481 if self ._listener_task .done ():
75- self ._listener_task .result ()
82+ exc = self ._listener_task .exception ()
83+ if exc is None :
84+ self ._listener_task .result ()
7685 else :
7786 self ._listener_task .cancel ()
7887 await self ._backend .disconnect ()
@@ -88,7 +97,7 @@ async def publish(self, channel: str, message: Any) -> None:
8897
8998 @asynccontextmanager
9099 async def subscribe (self , channel : str ) -> AsyncIterator [Subscriber ]:
91- queue : asyncio .Queue [Event | None ] = asyncio .Queue ()
100+ queue : asyncio .Queue [Event | BaseException | None ] = asyncio .Queue ()
92101
93102 try :
94103 if not self ._subscribers .get (channel ):
@@ -107,7 +116,7 @@ async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:
107116
108117
109118class Subscriber :
110- def __init__ (self , queue : asyncio .Queue [Event | None ]) -> None :
119+ def __init__ (self , queue : asyncio .Queue [Event | BaseException | None ]) -> None :
111120 self ._queue = queue
112121
113122 async def __aiter__ (self ) -> AsyncGenerator [Event | None , None ]:
@@ -119,6 +128,8 @@ async def __aiter__(self) -> AsyncGenerator[Event | None, None]:
119128
120129 async def get (self ) -> Event :
121130 item = await self ._queue .get ()
131+ if isinstance (item , BaseException ):
132+ raise item
122133 if item is None :
123134 raise Unsubscribed ()
124135 return item
0 commit comments