Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions janus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from queue import Empty as SyncQueueEmpty
from queue import Full as SyncQueueFull
from time import monotonic
from typing import Callable, Generic, Optional, Protocol, TypeVar
from typing import Any, Callable, Generic, Optional, Protocol, TypeAlias, TypeVar

if sys.version_info >= (3, 13):
from asyncio import QueueShutDown as AsyncQueueShutDown
Expand Down Expand Up @@ -41,6 +41,19 @@ class ShutDown(Exception):
"BaseQueue",
)

_contra_T = TypeVar("_contra_T", contravariant=True)

class SupportsLT(Protocol[_contra_T]):
def __lt__(self, other: _contra_T): ...


class SupportsGT(Protocol[_contra_T]):
def __gt__(self, other: _contra_T): ...


RichComparable: TypeAlias = SupportsGT[Any] | SupportsLT[Any]

RichComparableT = TypeVar("RichComparableT", bound=RichComparable)

T = TypeVar("T")
OptFloat = Optional[float]
Expand Down Expand Up @@ -696,7 +709,7 @@ def shutdown(self, immediate: bool = False) -> None:
self._parent.shutdown(immediate)


class PriorityQueue(Queue[T]):
class PriorityQueue(Queue[RichComparableT]):
"""Variant of Queue that retrieves open entries in priority order
(lowest first).

Expand All @@ -705,15 +718,15 @@ class PriorityQueue(Queue[T]):
"""

def _init(self, maxsize: int) -> None:
self._heap_queue: list[T] = []
self._heap_queue: list[RichComparableT] = []

def _qsize(self) -> int:
return len(self._heap_queue)

def _put(self, item: T) -> None:
def _put(self, item: RichComparableT) -> None:
heappush(self._heap_queue, item)

def _get(self) -> T:
def _get(self) -> RichComparableT:
return heappop(self._heap_queue)


Expand Down