1+ import os
12import random
23
34import pytest
45import torch
56
67
8+ def pytest_configure (config ):
9+ # register an additional marker (see pytest_collection_modifyitems)
10+ config .addinivalue_line (
11+ "markers" , "needs_cuda: mark for tests that rely on a CUDA device"
12+ )
13+
14+
15+ def pytest_collection_modifyitems (items ):
16+ # This hook is called by pytest after it has collected the tests (google its
17+ # name to check out its doc!). We can ignore some tests as we see fit here,
18+ # or add marks, such as a skip mark.
19+
20+ out_items = []
21+ for item in items :
22+ # The needs_cuda mark will exist if the test was explicitly decorated
23+ # with the @needs_cuda decorator. It will also exist if it was
24+ # parametrized with a parameter that has the mark: for example if a test
25+ # is parametrized with
26+ # @pytest.mark.parametrize('device', cpu_and_cuda())
27+ # the "instances" of the tests where device == 'cuda' will have the
28+ # 'needs_cuda' mark, and the ones with device == 'cpu' won't have the
29+ # mark.
30+ needs_cuda = item .get_closest_marker ("needs_cuda" ) is not None
31+
32+ if (
33+ needs_cuda
34+ and not torch .cuda .is_available ()
35+ and os .environ .get ("FAIL_WITHOUT_CUDA" ) is None
36+ ):
37+ # We skip CUDA tests on non-CUDA machines, but only if the
38+ # FAIL_WITHOUT_CUDA env var wasn't set. If it's set, the test will
39+ # typically fail with a "Unsupported device: cuda" error. This is
40+ # normal and desirable: this env var is set on CI jobs that are
41+ # supposed to run the CUDA tests, so if CUDA isn't available on
42+ # those for whatever reason, we need to know.
43+ item .add_marker (pytest .mark .skip (reason = "CUDA not available." ))
44+
45+ out_items .append (item )
46+
47+ items [:] = out_items
48+
49+
750@pytest .fixture (autouse = True )
851def prevent_leaking_rng ():
952 # Prevent each test from leaking the rng to all other test when they call
@@ -20,10 +63,3 @@ def prevent_leaking_rng():
2063 random .setstate (builtin_rng_state )
2164 if torch .cuda .is_available ():
2265 torch .cuda .set_rng_state (cuda_rng_state )
23-
24-
25- def pytest_configure (config ):
26- # register an additional marker (see pytest_collection_modifyitems)
27- config .addinivalue_line (
28- "markers" , "needs_cuda: mark for tests that rely on a CUDA device"
29- )
0 commit comments