Skip to content

Commit 747b4e2

Browse files
authored
test: Fix test_sampling.py on Spark (#2042)
<!-- .github/pull_request_template.md --> ## 📌 Description Current PR fixes `test_sampling.py::test_softmax` on Spark by inserting a `torch.cuda.synchronize()` before calling the softmax function. tl; dr why it works: PDL is enabled in these tests. Investigation shows that when PDL is enabled, `logits.view(-1).index_fill_(0, inf_idx, float("-inf"))` that prepares the inputs overlaps with the `probs = flashinfer.sampling.softmax(logits, temperature=temperature_arr)` function itself. Hence, we need to ensure that the input preparation is complete before running the softmax function to get the correct output. #### Observations `test_sampling.py::test_softmax` fails on select cases Spark. Example output ``` # pytest tests/utils/test_sampling.py::test_softmax =================================================================================================================================================== test session starts =================================================================================================================================================== platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0 rootdir: /flashinfer configfile: pytest.ini collected 324 items ... ================================================================================================================================================= short test summary info ================================================================================================================================================= FAILED tests/utils/test_sampling.py::test_softmax[True-True-1.0-normal_distribution(std=1)-128256-989] - AssertionError: assert False FAILED tests/utils/test_sampling.py::test_softmax[True-True-1.0-normal_distribution(std=5)-128256-989] - AssertionError: assert False FAILED tests/utils/test_sampling.py::test_softmax[True-True-1.0-gumbel_distribution(beta=0.1)-128256-989] - AssertionError: assert False ======================================================================================================================================== 3 failed, 321 passed, 1 warning in 10.33s ``` Observations from debugging: * When outputs are printed, rows containing all `nan`s are produced in the output of `probs = flashinfer.sampling.softmax(logits)` * Surprisingly, the test passes with `CUDA_LAUNCH_BLOCKING=1 pytest tests/utils/test_sampling.py::test_softmax` * `compute-sanitizer` does not detect any IMAs * Running only a failed test results in a pass: ``` $ pytest tests/utils/test_sampling.py::test_softmax[True-True-1.0-normal_distribution\(std=1\)-128256-989] ... 1 passed, 1 warning in 0.80s ``` Towards a fix: * I empirically find that the test passes: * when the reference `torch.softmax()` is called before `flashinfer.sampling.softmax()` (currently reference is called after) * when pdl is disabled in [line 67](https://github.com/flashinfer-ai/flashinfer/blob/main/tests/utils/test_sampling.py#L67) with `probs = flashinfer.sampling.softmax(logits, temperature=temperature_arr,enable_pdf=False)` * when `torch.cuda.synchronize()` is inserted in the line 64 as in this PR. ``` if neg_inf_input: # assign random logits to -inf num_inf = torch.randint(0, logits.numel() - 1, (), device=logits.device).item() inf_idx = torch.randperm(logits.numel(), device=logits.device)[:num_inf] logits.view(-1).index_fill_(0, inf_idx, float("-inf")) torch.cuda.synchronize() ## This fixes the issue for some reason! if temperature_arr: temperature_arr = torch.full((batch_size,), temperature, device="cuda:0") probs = flashinfer.sampling.softmax(logits, temperature=temperature_arr) logits_scaled = logits / temperature_arr.unsqueeze(-1) ``` but **does not fix the issue if I place the synchronization any earlier** An nsys profile shows that surprisingly the `logits.view(-1).index_fill_(0, inf_idx, float("-inf"))` and `flashinfer.sampling.softmax(logits, temperature=temperature_arr)` can overlap execution when pdl is enabled. <img width="1243" height="640" alt="Screenshot 2025-11-04 at 5 49 50 PM" src="https://github.com/user-attachments/assets/950ab8ab-0843-49c8-8411-ff81c00c34a6" /> This means that the softmax kernel is launching before inputs are done being prepared when `neg_inf_input=True`. Hence, placing a `torch.cuda.synchronize()` after the fill or disabling pdl can solve the issue. With the current PR, the nsys timeline changes to: <img width="1240" height="643" alt="Screenshot 2025-11-04 at 5 51 32 PM" src="https://github.com/user-attachments/assets/aae63a88-d7cd-4661-8476-6d8c581879b2" /> and the unit test passes. <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **Bug Fixes** * Improved synchronization of concurrent operations to ensure proper execution order and prevent potential timing-related issues. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 9721ff7 commit 747b4e2

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

tests/utils/test_sampling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def test_softmax(
6161
num_inf = torch.randint(0, logits.numel() - 1, (), device=logits.device).item()
6262
inf_idx = torch.randperm(logits.numel(), device=logits.device)[:num_inf]
6363
logits.view(-1).index_fill_(0, inf_idx, float("-inf"))
64+
torch.cuda.synchronize() # wait for the index_fill_ to finish because it can overlap with the softmax kernel
6465

6566
if temperature_arr:
6667
temperature_arr = torch.full((batch_size,), temperature, device="cuda:0")

0 commit comments

Comments
 (0)