Commit da01b1b
test: Enable xfailed trtllm decode long seqlen tests and update microbenchmark (#2018)
## 📌 Description

[tests/attention/test_trtllm_gen_attention.py](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/tests/attention/test_trtllm_gen_attention.py#L1021-L1076) was failing and was therefore marked `xfail`. PR #2002 fixed the underlying root cause, so this PR removes the `xfail` marker so that these long-seqlen cases are exercised going forward.

Additionally, PR #2002 revealed a bug in the microbenchmark script: [trtllm_batch_decode_with_kv_cache](https://github.com/flashinfer-ai/flashinfer/blob/v0.5.0rc2/flashinfer/decode.py#L2082-L2083) explicitly requires the workspace to be zeroed before first use:

```
workspace_buffer : torch.Tensor
    Must be initialized to 0 for its first use.
```

while the microbenchmark code did not zero it out, causing undefined behavior such as IMAs (illegal memory accesses) whose occurrence depends on the order in which backends are tested. This PR fixes the issue by explicitly calling `workspace_buffer.zero_()` between testing different backends.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Bug Fixes**
  * Improved stability of performance benchmarks by properly resetting the workspace buffer between backend invocations.
* **Tests**
  * Enabled a previously skipped test for long sequence length handling.
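To illustrate the failure mode the PR fixes, here is a minimal, self-contained sketch (not the real FlashInfer API) of why a benchmark must zero a shared workspace between backends. The `run_backend` helper, the `benchmark` driver, and the backend names are hypothetical stand-ins; a `bytearray` stands in for the torch `workspace_buffer`.

```python
# Hypothetical sketch: a backend that, like trtllm_batch_decode_with_kv_cache,
# requires its workspace to be zeroed before first use.

def run_backend(name: str, workspace: bytearray) -> str:
    """Simulated backend; misbehaves if a previous backend left stale data."""
    if any(workspace):
        raise RuntimeError(f"{name}: workspace not zeroed before first use")
    # Simulate the backend scribbling internal state into the scratch space.
    for i in range(len(workspace)):
        workspace[i] = 0xFF
    return f"{name}: ok"

def benchmark(backends, workspace_size: int = 16) -> list:
    workspace = bytearray(workspace_size)  # stands in for the torch workspace_buffer
    results = []
    for backend in backends:
        # The fix from this PR: clear the workspace before each backend,
        # mirroring workspace_buffer.zero_() in benchmarks/routines/attention.py.
        workspace[:] = bytes(workspace_size)
        results.append(run_backend(backend, workspace))
    return results
```

Without the `workspace[:] = bytes(...)` line, the second backend in the loop would see the first backend's leftover scratch data, which is the order-dependent behavior described above.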
1 parent 5854494 commit da01b1b

File tree

2 files changed: +8 −1 lines changed

benchmarks/routines/attention.py — 8 additions & 0 deletions

```diff
@@ -508,6 +508,8 @@ def run_backend_wrapper(backend):
     has_reference_output = False
     # Iterate over each backend:
     for cur_backend in backends:
+        # Clear workspace buffer to prevent unexpected interactions between backends.
+        workspace_buffer.zero_()
         if run_refcheck:
             outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
         if cur_backend == "fa2":
@@ -975,6 +977,8 @@ def run_backend_wrapper(backend):
     has_reference_output = False
     # Iterate over each backend:
     for cur_backend in backends:
+        # Clear workspace buffer to prevent unexpected interactions between backends.
+        workspace_buffer.zero_()
         if run_refcheck:
             outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
         if cur_backend == "fa2":
@@ -1427,6 +1431,8 @@ def run_backend_wrapper(backend):
     has_reference_output = False
     # Iterate over each backend:
     for cur_backend in backends:
+        # Clear workspace buffer to prevent unexpected interactions between backends.
+        workspace_buffer.zero_()
         if run_refcheck:
             outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
         if cur_backend == "fa2":
@@ -1822,6 +1828,8 @@ def run_backend_wrapper(backend):
     has_reference_output = False
     # Iterate over each backend:
     for cur_backend in backends:
+        # Clear workspace buffer to prevent unexpected interactions between backends.
+        workspace_buffer.zero_()
         if run_refcheck:
             outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
         if cur_backend == "fa2":
```

tests/attention/test_trtllm_gen_attention.py — 0 additions & 1 deletion

```diff
@@ -1133,7 +1133,6 @@ def test_trtllm_batch_decode_long_sequence_length(
     head_dim,
 ):
     # Small number of test cases for long sequence length
-    pytest.xfail("trtllm-gen decode gets incorrect output with Long sequence length")
     _test_trtllm_batch_decode(
         "trtllm-gen",
         kv_layout,
```
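For context on why deleting this single line re-enables the test: calling `pytest.xfail(...)` imperatively inside a test body raises pytest's internal `XFailed` exception, which aborts the test immediately and reports it as expected-to-fail, so nothing after the call ever executes. A small sketch (assuming `pytest` is installed; the test function below is illustrative, not from the repo):

```python
import pytest

def sample_test():
    # Imperative xfail: raises pytest's XFailed exception, so the real
    # test body below it never runs. Removing this line (as this commit
    # does) lets the assertions execute again.
    pytest.xfail("known bug in trtllm-gen long-seqlen decode")
    assert False, "never reached"
```

Note that, unlike the `@pytest.mark.xfail` decorator, the imperative form unconditionally stops the test at the call site, which is why it fully masked the long-seqlen failures until PR #2002 fixed them.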
