Skip to content

Commit 9bc5bd5

Browse files
bugfix: fix failed unittest test_green_ctx and test_jit_example on spark (sm_121) (#1951)
<!-- .github/pull_request_template.md --> ## 📌 Description There are three failed unittests on spark (sm_121): * tests/utils/test_green_ctx.py * tests/utils/test_jit_example.py * tests/utils/test_sampling.py First one is because spark has small number of SMs (48) and we don't have a guard on green context splitting. Second one is an unknown issue (logits don't match with reference) and probably related to barriers on sm_121, xfail now and will fix later. The last one will be fixed by another PR from @bkryu , this PR fixes the first two issues. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Tests now pre-check GPU resources and auto-skip with informative messages including available and requested SM counts to avoid spurious failures. * Added a conditional xfail for GPUs with compute capability 12.1 to avoid false negatives on that hardware. * Tightened a sampling test by adding a relative tolerance for more robust numerical validation. * **Bug Fixes** * Improved runtime error handling to surface clearer guidance when GPU SM resources are insufficient. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent e1c1e2a commit 9bc5bd5

File tree

3 files changed

+197
-79
lines changed

3 files changed

+197
-79
lines changed

flashinfer/green_ctx.py

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,27 @@ def split_device_green_ctx(
170170
RuntimeError: when requested SM allocation exceeds device capacity:
171171
``num_groups * rounded_min_count > total_device_sms``
172172
"""
173-
cu_dev = get_cudevice(dev)
174-
resource = get_device_resource(cu_dev)
175-
results, remaining = split_resource(resource, num_groups, min_count)
176-
resources = results + [remaining]
177-
streams = create_green_ctx_streams(cu_dev, resources)
178-
return streams, resources
173+
try:
174+
cu_dev = get_cudevice(dev)
175+
resource = get_device_resource(cu_dev)
176+
results, remaining = split_resource(resource, num_groups, min_count)
177+
resources = results + [remaining]
178+
streams = create_green_ctx_streams(cu_dev, resources)
179+
return streams, resources
180+
except RuntimeError as e:
181+
if (
182+
"CUDA error code=914" in str(e)
183+
or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e)
184+
or "CUDA error code=915" in str(e)
185+
or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e)
186+
):
187+
raise RuntimeError(
188+
f"{e}\n"
189+
f"Failed to split device into {num_groups} groups with min_count={min_count}. "
190+
f"This is likely due to insufficient number of SMs available on the device. "
191+
f"Please reduce the number of groups or the minimum SM count per group."
192+
) from e
193+
raise
179194

180195

181196
def split_device_green_ctx_by_sm_count(
@@ -241,21 +256,40 @@ def split_device_green_ctx_by_sm_count(
241256
See `CUDA Green Contexts <https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html>`_
242257
for more details.
243258
"""
244-
cu_dev = get_cudevice(dev)
245-
resource = get_device_resource(cu_dev)
259+
try:
260+
cu_dev = get_cudevice(dev)
261+
resource = get_device_resource(cu_dev)
262+
263+
# Round sm counts to meet the alignment and granularity requirements
264+
rounded_sm_counts = []
265+
for sm_count in sm_counts:
266+
min_sm_count, sm_alignment = get_sm_count_constraint(
267+
*get_compute_capability(dev)
268+
)
269+
if sm_count <= 0:
270+
raise ValueError(f"SM count must be positive, got {sm_count}")
271+
rounded_sm_counts.append(
272+
round_up(max(sm_count, min_sm_count), sm_alignment)
273+
)
246274

247-
# Round sm counts to meet the alignment and granularity requirements
248-
rounded_sm_counts = []
249-
for sm_count in sm_counts:
250-
min_sm_count, sm_alignment = get_sm_count_constraint(
251-
*get_compute_capability(dev)
275+
# Split the device into multiple green contexts
276+
results, remaining = split_resource_by_sm_count(
277+
cu_dev, resource, rounded_sm_counts
252278
)
253-
if sm_count <= 0:
254-
raise ValueError(f"SM count must be positive, got {sm_count}")
255-
rounded_sm_counts.append(round_up(max(sm_count, min_sm_count), sm_alignment))
256-
257-
# Split the device into multiple green contexts
258-
results, remaining = split_resource_by_sm_count(cu_dev, resource, rounded_sm_counts)
259-
resources = results + [remaining]
260-
streams = create_green_ctx_streams(cu_dev, resources)
261-
return streams, resources
279+
resources = results + [remaining]
280+
streams = create_green_ctx_streams(cu_dev, resources)
281+
return streams, resources
282+
except RuntimeError as e:
283+
if (
284+
"CUDA error code=914" in str(e)
285+
or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e)
286+
or "CUDA error code=915" in str(e)
287+
or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e)
288+
):
289+
raise RuntimeError(
290+
f"{e}\n"
291+
f"Failed to split device with SM counts {sm_counts} (rounded to {rounded_sm_counts}). "
292+
f"This is likely due to insufficient number of SMs available on the device. "
293+
f"Please reduce the requested SM counts or use fewer partitions."
294+
) from e
295+
raise

tests/utils/test_green_ctx.py

Lines changed: 136 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,30 @@ def test_green_ctx_creation(
1212
num_groups: int,
1313
min_count: int,
1414
):
15-
streams, resources = green_ctx.split_device_green_ctx(
16-
torch.device(device), num_groups, min_count
17-
)
15+
try:
16+
streams, resources = green_ctx.split_device_green_ctx(
17+
torch.device(device), num_groups, min_count
18+
)
1819

19-
assert len(resources) == num_groups + 1
20-
for resource in resources[:-1]:
21-
sm_count = resource.sm.smCount
22-
assert sm_count >= min_count
20+
assert len(resources) == num_groups + 1
21+
for resource in resources[:-1]:
22+
sm_count = resource.sm.smCount
23+
assert sm_count >= min_count
24+
except RuntimeError as e:
25+
if (
26+
"CUDA error code=914" in str(e)
27+
or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e)
28+
or "CUDA error code=915" in str(e)
29+
or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e)
30+
):
31+
# Get total SM count on the device
32+
cu_dev = green_ctx.get_cudevice(torch.device(device))
33+
device_resource = green_ctx.get_device_resource(cu_dev)
34+
total_sms = device_resource.sm.smCount
35+
pytest.skip(
36+
f"Insufficient SMs on device. Total SMs available: {total_sms}, requested: num_groups={num_groups}, min_count={min_count}"
37+
)
38+
raise
2339

2440

2541
@pytest.mark.parametrize("device", ["cuda:0"])
@@ -30,19 +46,35 @@ def test_green_ctx_kernel_execution(
3046
num_groups: int,
3147
min_count: int,
3248
):
33-
streams, resources = green_ctx.split_device_green_ctx(
34-
torch.device(device), num_groups, min_count
35-
)
36-
num_partitions = num_groups + 1
37-
assert len(streams) == num_partitions
38-
assert len(resources) == num_partitions
39-
40-
for stream in streams:
41-
with torch.cuda.stream(stream):
42-
x = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16)
43-
y = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16)
44-
z = x @ y
45-
print(z.shape)
49+
try:
50+
streams, resources = green_ctx.split_device_green_ctx(
51+
torch.device(device), num_groups, min_count
52+
)
53+
num_partitions = num_groups + 1
54+
assert len(streams) == num_partitions
55+
assert len(resources) == num_partitions
56+
57+
for stream in streams:
58+
with torch.cuda.stream(stream):
59+
x = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16)
60+
y = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16)
61+
z = x @ y
62+
print(z.shape)
63+
except RuntimeError as e:
64+
if (
65+
"CUDA error code=914" in str(e)
66+
or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e)
67+
or "CUDA error code=915" in str(e)
68+
or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e)
69+
):
70+
# Get total SM count on the device
71+
cu_dev = green_ctx.get_cudevice(torch.device(device))
72+
device_resource = green_ctx.get_device_resource(cu_dev)
73+
total_sms = device_resource.sm.smCount
74+
pytest.skip(
75+
f"Insufficient SMs on device. Total SMs available: {total_sms}, requested: num_groups={num_groups}, min_count={min_count}"
76+
)
77+
raise
4678

4779

4880
@pytest.mark.parametrize("device", ["cuda:0"])
@@ -59,17 +91,33 @@ def test_split_device_green_ctx_by_sm_count_creation(
5991
device: str,
6092
sm_counts: list,
6193
):
62-
streams, resources = green_ctx.split_device_green_ctx_by_sm_count(
63-
torch.device(device), sm_counts
64-
)
65-
num_partitions = len(sm_counts) + 1
66-
assert len(resources) == num_partitions
67-
assert len(streams) == num_partitions
68-
69-
# Check that each partition has the expected SM count
70-
for i, expected_sm_count in enumerate(sm_counts):
71-
actual_sm_count = resources[i].sm.smCount
72-
assert actual_sm_count >= expected_sm_count
94+
try:
95+
streams, resources = green_ctx.split_device_green_ctx_by_sm_count(
96+
torch.device(device), sm_counts
97+
)
98+
num_partitions = len(sm_counts) + 1
99+
assert len(resources) == num_partitions
100+
assert len(streams) == num_partitions
101+
102+
# Check that each partition has the expected SM count
103+
for i, expected_sm_count in enumerate(sm_counts):
104+
actual_sm_count = resources[i].sm.smCount
105+
assert actual_sm_count >= expected_sm_count
106+
except RuntimeError as e:
107+
if (
108+
"CUDA error code=914" in str(e)
109+
or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e)
110+
or "CUDA error code=915" in str(e)
111+
or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e)
112+
):
113+
# Get total SM count on the device
114+
cu_dev = green_ctx.get_cudevice(torch.device(device))
115+
device_resource = green_ctx.get_device_resource(cu_dev)
116+
total_sms = device_resource.sm.smCount
117+
pytest.skip(
118+
f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}"
119+
)
120+
raise
73121

74122

75123
@pytest.mark.parametrize("device", ["cuda:0"])
@@ -85,19 +133,35 @@ def test_split_device_green_ctx_by_sm_count_kernel_execution(
85133
device: str,
86134
sm_counts: list,
87135
):
88-
streams, resources = green_ctx.split_device_green_ctx_by_sm_count(
89-
torch.device(device), sm_counts
90-
)
91-
num_partitions = len(sm_counts) + 1
92-
assert len(streams) == num_partitions
93-
assert len(resources) == num_partitions
94-
95-
for i, stream in enumerate(streams):
96-
with torch.cuda.stream(stream):
97-
x = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16)
98-
y = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16)
99-
z = x @ y
100-
print(f"Partition {i}: {z.shape}")
136+
try:
137+
streams, resources = green_ctx.split_device_green_ctx_by_sm_count(
138+
torch.device(device), sm_counts
139+
)
140+
num_partitions = len(sm_counts) + 1
141+
assert len(streams) == num_partitions
142+
assert len(resources) == num_partitions
143+
144+
for i, stream in enumerate(streams):
145+
with torch.cuda.stream(stream):
146+
x = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16)
147+
y = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16)
148+
z = x @ y
149+
print(f"Partition {i}: {z.shape}")
150+
except RuntimeError as e:
151+
if (
152+
"CUDA error code=914" in str(e)
153+
or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e)
154+
or "CUDA error code=915" in str(e)
155+
or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e)
156+
):
157+
# Get total SM count on the device
158+
cu_dev = green_ctx.get_cudevice(torch.device(device))
159+
device_resource = green_ctx.get_device_resource(cu_dev)
160+
total_sms = device_resource.sm.smCount
161+
pytest.skip(
162+
f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}"
163+
)
164+
raise
101165

102166

103167
@pytest.mark.parametrize("device", ["cuda:0"])
@@ -113,16 +177,32 @@ def test_split_device_green_ctx_by_sm_count_alignment(
113177
device: str,
114178
sm_counts: list,
115179
):
116-
_, resources = green_ctx.split_device_green_ctx_by_sm_count(
117-
torch.device(device), sm_counts
118-
)
119-
120-
for resource in resources[:-1]: # Exclude remaining SMs
121-
sm_count = resource.sm.smCount
122-
assert sm_count > 0
123-
124-
min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint(
125-
*green_ctx.get_compute_capability(torch.device(device))
180+
try:
181+
_, resources = green_ctx.split_device_green_ctx_by_sm_count(
182+
torch.device(device), sm_counts
126183
)
127-
assert sm_count >= min_sm_count
128-
assert sm_count % sm_alignment == 0
184+
185+
for resource in resources[:-1]: # Exclude remaining SMs
186+
sm_count = resource.sm.smCount
187+
assert sm_count > 0
188+
189+
min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint(
190+
*green_ctx.get_compute_capability(torch.device(device))
191+
)
192+
assert sm_count >= min_sm_count
193+
assert sm_count % sm_alignment == 0
194+
except RuntimeError as e:
195+
if (
196+
"CUDA error code=914" in str(e)
197+
or "CUDA_ERROR_INVALID_RESOURCE_TYPE" in str(e)
198+
or "CUDA error code=915" in str(e)
199+
or "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in str(e)
200+
):
201+
# Get total SM count on the device
202+
cu_dev = green_ctx.get_cudevice(torch.device(device))
203+
device_resource = green_ctx.get_device_resource(cu_dev)
204+
total_sms = device_resource.sm.smCount
205+
pytest.skip(
206+
f"Insufficient SMs on device. Total SMs available: {total_sms}, requested SM counts: {sm_counts}"
207+
)
208+
raise

tests/utils/test_jit_example.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
gen_customize_single_prefill_module,
1212
)
1313
from flashinfer.prefill import single_prefill_with_kv_cache_with_jit_module
14-
from flashinfer.utils import MaskMode, is_sm90a_supported
14+
from flashinfer.utils import MaskMode, is_sm90a_supported, get_compute_capability
1515

1616

1717
def test_single_decode_mask():
@@ -166,6 +166,10 @@ def test_flash_sigmoid():
166166
torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2)
167167

168168

169+
@pytest.mark.xfail(
170+
get_compute_capability(torch.device("cuda:0")) == (12, 1),
171+
reason="Numerical accuracy issue on SM 121 (Spark)",
172+
)
169173
def test_dump_logits():
170174
torch.manual_seed(42)
171175
variant_decl = r"""

0 commit comments

Comments
 (0)