Commit 747b4e2
test: Fix test_sampling.py on Spark (#2042)
<!-- .github/pull_request_template.md -->
## 📌 Description
This PR fixes `test_sampling.py::test_softmax` on Spark by inserting
a `torch.cuda.synchronize()` before the call to the softmax function.
tl;dr why it works: PDL (Programmatic Dependent Launch) is enabled in
these tests. Investigation shows that when PDL is enabled, the
`logits.view(-1).index_fill_(0, inf_idx, float("-inf"))` call that
prepares the inputs can overlap with `probs =
flashinfer.sampling.softmax(logits, temperature=temperature_arr)`
itself. Hence, we need to ensure that input preparation is complete
before the softmax kernel runs in order to get correct output.
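A self-contained sketch of the hazard and the fix, assuming a CUDA device with flashinfer installed; the shapes mirror the failing parametrization and the fill logic is taken from the test:
```
import torch
import flashinfer

# Shapes from the failing parametrization: batch_size=989, vocab_size=128256.
logits = torch.randn(989, 128256, device="cuda")

# Input preparation from the test: set randomly chosen logits to -inf.
num_inf = torch.randint(0, logits.numel() - 1, (), device=logits.device).item()
inf_idx = torch.randperm(logits.numel(), device=logits.device)[:num_inf]
logits.view(-1).index_fill_(0, inf_idx, float("-inf"))

# The fix: fence input preparation before launching the PDL-enabled softmax.
torch.cuda.synchronize()

temperature_arr = torch.full((989,), 1.0, device="cuda")
probs = flashinfer.sampling.softmax(logits, temperature=temperature_arr)
```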
#### Observations
`test_sampling.py::test_softmax` fails on select cases on Spark. Example
output:
```
# pytest tests/utils/test_sampling.py::test_softmax
=================================================================================================================================================== test session starts ===================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 324 items
...
================================================================================================================================================= short test summary info =================================================================================================================================================
FAILED tests/utils/test_sampling.py::test_softmax[True-True-1.0-normal_distribution(std=1)-128256-989] - AssertionError: assert False
FAILED tests/utils/test_sampling.py::test_softmax[True-True-1.0-normal_distribution(std=5)-128256-989] - AssertionError: assert False
FAILED tests/utils/test_sampling.py::test_softmax[True-True-1.0-gumbel_distribution(beta=0.1)-128256-989] - AssertionError: assert False
======================================================================================================================================== 3 failed, 321 passed, 1 warning in 10.33s
```
Observations from debugging:
* When outputs are printed, rows containing all `nan`s appear in
the output of `probs = flashinfer.sampling.softmax(logits)`
* Surprisingly, the test passes with `CUDA_LAUNCH_BLOCKING=1 pytest
tests/utils/test_sampling.py::test_softmax`, which forces each kernel
launch to complete before the next one starts and already hints at a
launch-ordering race
* `compute-sanitizer` does not detect any illegal memory accesses (IMAs)
* Running only a failed test results in a pass:
```
$ pytest tests/utils/test_sampling.py::test_softmax[True-True-1.0-normal_distribution\(std=1\)-128256-989]
...
1 passed, 1 warning in 0.80s
```
#### Towards a fix
I empirically find that the test passes:
* when the reference `torch.softmax()` is called before
`flashinfer.sampling.softmax()` (currently the reference is called after)
* when PDL is disabled at [line
67](https://github.com/flashinfer-ai/flashinfer/blob/main/tests/utils/test_sampling.py#L67)
with `probs = flashinfer.sampling.softmax(logits,
temperature=temperature_arr, enable_pdl=False)` (a complete call is
sketched after the code block below)
* when `torch.cuda.synchronize()` is inserted at line 64, as in this
PR:
```
if neg_inf_input:
# assign random logits to -inf
num_inf = torch.randint(0, logits.numel() - 1, (), device=logits.device).item()
inf_idx = torch.randperm(logits.numel(), device=logits.device)[:num_inf]
logits.view(-1).index_fill_(0, inf_idx, float("-inf"))
torch.cuda.synchronize() ## This fixes the issue for some reason!
if temperature_arr:
temperature_arr = torch.full((batch_size,), temperature, device="cuda:0")
probs = flashinfer.sampling.softmax(logits, temperature=temperature_arr)
logits_scaled = logits / temperature_arr.unsqueeze(-1)
```
but the synchronization **does not fix the issue if placed any
earlier**.
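For completeness, a sketch of the PDL-disabling workaround from the list above (not what this PR ships; `enable_pdl` is assumed to be the flashinfer keyword, fixing the apparent `enable_pdf` typo):
```
# Alternative workaround (not adopted here): disable PDL so the softmax
# kernel cannot begin executing before the preceding index_fill_ finishes.
probs = flashinfer.sampling.softmax(
    logits, temperature=temperature_arr, enable_pdl=False
)
```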
An nsys profile surprisingly shows that
`logits.view(-1).index_fill_(0, inf_idx, float("-inf"))` and
`flashinfer.sampling.softmax(logits, temperature=temperature_arr)` can
overlap in execution when PDL is enabled:
<img width="1243" height="640" alt="Screenshot 2025-11-04 at 5 49 50 PM"
src="https://github.com/user-attachments/assets/950ab8ab-0843-49c8-8411-ff81c00c34a6"
/>
This means that when `neg_inf_input=True`, the softmax kernel launches
before input preparation has finished. Hence, placing a
`torch.cuda.synchronize()` after the fill, or disabling PDL, solves the
issue. With this PR, the nsys timeline changes to:
<img width="1240" height="643" alt="Screenshot 2025-11-04 at 5 51 32 PM"
src="https://github.com/user-attachments/assets/aae63a88-d7cd-4661-8476-6d8c581879b2"
/>
and the unit test passes.
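As an aside, `torch.cuda.synchronize()` is a device-wide fence. Assuming all of the test's work is issued on the default stream, a narrower stream-level fence would order the fill before the softmax launch just as well; a hypothetical alternative, not what this PR does:
```
import torch

# Hypothetical narrower fence: block the host only until work already queued
# on the current stream (including the index_fill_) has finished, rather than
# synchronizing every stream on the device.
torch.cuda.current_stream().synchronize()
```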
<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->
## 🔍 Related Issues
<!-- Link any related issues here -->
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).
## Reviewer Notes
<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
## Release Notes
* **Bug Fixes**
* Improved synchronization of concurrent operations to ensure proper
execution order and prevent potential timing-related issues.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Diff: 1 file changed, +1 / −0 (a single `torch.cuda.synchronize()` call inserted at line 64 of `tests/utils/test_sampling.py`).