Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 61 additions & 21 deletions tests/test_find_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,69 +22,109 @@
DEALINGS IN THE SOFTWARE.
"""

from __future__ import annotations

from collections.abc import Callable, Iterable, Iterator
from typing import Literal, TypeVar
from typing_extensions import TypeIs

import pytest
from discord.utils import find

T = TypeVar("T")


def is_even(x):
def is_even(x: int) -> bool:
return x % 2 == 0


def always_true(_: object) -> bool:
return True


def greater_than_3(x: int) -> bool:
return x > 3


def equals_1(x: int) -> TypeIs[Literal[1]]:
return x == 1


def equals_2(x: int) -> TypeIs[Literal[2]]:
return x == 2


def equals_b(c: str) -> TypeIs[Literal["b"]]:
return c == "b"


def equals_30(x: int) -> TypeIs[Literal[30]]:
return x == 30


def is_none_pred(x: object) -> bool:
return x is None


@pytest.mark.parametrize(
("seq", "predicate", "expected"),
[
([], lambda x: True, None),
([1, 2, 3], lambda x: x > 3, None),
([1, 2, 3], lambda x: x == 1, 1),
([1, 2, 3], lambda x: x == 2, 2),
("abc", lambda c: c == "b", "b"),
((10, 20, 30), lambda x: x == 30, 30),
([None, False, 0], lambda x: x is None, None),
([], always_true, None),
([1, 2, 3], greater_than_3, None),
([1, 2, 3], equals_1, 1),
([1, 2, 3], equals_2, 2),
("abc", equals_b, "b"),
((10, 20, 30), equals_30, 30),
([None, False, 0], is_none_pred, None),
([1, 2, 3, 4], is_even, 2),
],
)
def test_find_basic_parametrized(seq, predicate, expected):
def test_find_basic_parametrized(
seq: Iterable[T],
predicate: Callable[[T], object],
expected: T | None,
) -> None:
result = find(predicate, seq)
if expected is None:
assert result is None
else:
assert result == expected


def test_find_with_truthy_non_boolean_predicate():
seq = [2, 4, 5, 6]
def test_find_with_truthy_non_boolean_predicate() -> None:
seq: list[int] = [2, 4, 5, 6]
result = find(lambda x: x % 2, seq)
assert result == 5


def test_find_on_generator_and_stop_early():
def bad_gen():
def test_find_on_generator_and_stop_early() -> None:
def bad_gen() -> Iterator[str]:
yield "first"
raise RuntimeError("should not be reached")

assert find(lambda x: x == "first", bad_gen()) == "first"


def test_find_does_not_evaluate_rest():
calls = []
def test_find_does_not_evaluate_rest() -> None:
calls: list[str] = []

def predicate(x):
def predicate(x: str) -> bool:
calls.append(x)
return x == "stop"

seq = ["go", "stop", "later"]
seq: list[str] = ["go", "stop", "later"]
result = find(predicate, seq)
assert result == "stop"
assert calls == ["go", "stop"]


def test_find_with_set_returns_first_iterated_element():
data = {"a", "b", "c"}
def test_find_with_set_returns_first_iterated_element() -> None:
data: set[str] = {"a", "b", "c"}
result = find(lambda x: x in data, data)
assert result in data


def test_find_none_predicate():
seq = [42, 43, 44]
def test_find_none_predicate() -> None:
seq: list[int] = [42, 43, 44]
result = find(lambda x: True, seq)
assert result == 42
25 changes: 20 additions & 5 deletions tests/test_format_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,21 @@
import datetime
import random
import pytest
from discord.utils import format_dt
from discord.utils.public import format_dt, TimestampStyle

# Fix seed so that time tests are reproducible
random.seed(42)

ALL_STYLES = ["t", "T", "d", "D", "f", "F", "R", None]
ALL_STYLES = [
"t",
"T",
"d",
"D",
"f",
"F",
"R",
None,
]

DATETIME_CASES = [
(datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc), 0),
Expand All @@ -41,7 +50,7 @@
]


def random_time():
def random_time() -> datetime.time:
return datetime.time(
random.randint(0, 23),
random.randint(0, 59),
Expand All @@ -51,7 +60,11 @@ def random_time():

@pytest.mark.parametrize(("dt", "expected_ts"), DATETIME_CASES)
@pytest.mark.parametrize("style", ALL_STYLES)
def test_format_dt_formats_datetime(dt, expected_ts, style):
def test_format_dt_formats_datetime(
dt: datetime.datetime,
expected_ts: int,
style: TimestampStyle | None,
) -> None:
if style is None:
expected = f"<t:{expected_ts}>"
else:
Expand All @@ -61,7 +74,9 @@ def test_format_dt_formats_datetime(dt, expected_ts, style):


@pytest.mark.parametrize("style", ALL_STYLES)
def test_format_dt_formats_time_equivalence(style):
def test_format_dt_formats_time_equivalence(
style: TimestampStyle | None,
) -> None:
tm = random_time()
today = datetime.datetime.now().date()
result_time = format_dt(tm, style=style)
Expand Down
10 changes: 0 additions & 10 deletions tests/test_markdown_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,10 @@
"""

from discord.utils import (
oauth_url,
snowflake_time,
find,
get_or_fetch,
utcnow,
remove_markdown,
escape_markdown,
escape_mentions,
raw_mentions,
raw_channel_mentions,
raw_role_mentions,
format_dt,
generate_snowflake,
basic_autocomplete,
)


Expand Down
20 changes: 12 additions & 8 deletions tests/test_snowflake_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
import datetime
import pytest

from discord.utils import generate_snowflake, snowflake_time, DISCORD_EPOCH
from discord.utils import (
DISCORD_EPOCH,
generate_snowflake,
snowflake_time,
)

UTC = datetime.timezone.utc

Expand All @@ -39,40 +43,40 @@


@pytest.mark.parametrize(("dt", "expected_ms"), DATETIME_CASES)
def test_generate_snowflake_realistic(dt, expected_ms):
def test_generate_snowflake_realistic(dt: datetime.datetime, expected_ms: int) -> None:
sf = generate_snowflake(dt, mode="realistic")
assert (sf >> 22) == expected_ms
assert (sf & ((1 << 22) - 1)) == 0x3FFFFF


@pytest.mark.parametrize(("dt", "expected_ms"), DATETIME_CASES)
def test_generate_snowflake_boundary_low(dt, expected_ms):
def test_generate_snowflake_boundary_low(dt: datetime.datetime, expected_ms: int) -> None:
sf = generate_snowflake(dt, mode="boundary", high=False)
assert (sf >> 22) == expected_ms
assert (sf & ((1 << 22) - 1)) == 0


@pytest.mark.parametrize(("dt", "expected_ms"), DATETIME_CASES)
def test_generate_snowflake_boundary_high(dt, expected_ms):
def test_generate_snowflake_boundary_high(dt: datetime.datetime, expected_ms: int) -> None:
sf = generate_snowflake(dt, mode="boundary", high=True)
assert (sf >> 22) == expected_ms
assert (sf & ((1 << 22) - 1)) == (2**22 - 1)


@pytest.mark.parametrize(("dt", "expected_ms"), DATETIME_CASES)
def test_snowflake_time_roundtrip_boundary(dt, expected_ms):
def test_snowflake_time_roundtrip_boundary(dt: datetime.datetime, _expected_ms: int) -> None:
sf_low = generate_snowflake(dt, mode="boundary", high=False)
sf_high = generate_snowflake(dt, mode="boundary", high=True)
assert snowflake_time(sf_low) == dt
assert snowflake_time(sf_high) == dt


@pytest.mark.parametrize(("dt", "expected_ms"), DATETIME_CASES)
def test_snowflake_time_roundtrip_realistic(dt, expected_ms):
def test_snowflake_time_roundtrip_realistic(dt: datetime.datetime, _expected_ms: int) -> None:
sf = generate_snowflake(dt, mode="realistic")
assert snowflake_time(sf) == dt


def test_generate_snowflake_invalid_mode():
def test_generate_snowflake_invalid_mode() -> None:
with pytest.raises(ValueError, match="Invalid mode 'nope'. Must be 'realistic' or 'boundary'"):
generate_snowflake(datetime.datetime.now(tz=UTC), mode="nope")
generate_snowflake(datetime.datetime.now(tz=UTC), mode="nope") # pyright: ignore[reportArgumentType]
Loading