Commit 96e73b8

fix: fix test_trtllm_gen_attention when max_seq_len < page_size (#2076)
## Summary by CodeRabbit

* **Tests**
  * Adjusted the attention test's K/V cache sizing to compute per-sequence page allocation before scaling to the batch, aligning the test with the expected memory allocation.
  * This refines test expectations around cache sizing without changing validation logic, reducing false positives in memory-related test scenarios.

Signed-off-by: Jiying Dong <[email protected]>
1 parent eccbdde commit 96e73b8

File tree

1 file changed: +2 −2 lines changed

tests/attention/test_trtllm_gen_attention.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -102,8 +102,8 @@ def create_kv_cache(
 ):
     # Create separate K and V caches
     max_seq_len = torch.max(seq_lens).item()
-    num_tokens = max_seq_len * batch_size
-    num_pages = (num_tokens + page_size - 1) // page_size
+    num_pages_per_seq = (max_seq_len + page_size - 1) // page_size
+    num_pages = num_pages_per_seq * batch_size
     ref_kv_dtype_torch = DTYPE_MAP[ref_kv_dtype]
     if kv_dtype != "fp8":  # for fp8, create with high precision to generate scale.
         assert kv_dtype == ref_kv_dtype, (
```
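To see why the old formula under-allocates, here is a minimal standalone sketch (not part of the commit; the values are hypothetical, chosen to hit the `max_seq_len < page_size` case, and the variable names mirror the test). Pooling tokens across the whole batch before dividing by `page_size` can yield fewer pages than there are sequences, even though in a paged K/V cache each sequence needs at least one page of its own:

```python
# Hypothetical example values, not taken from the test itself.
batch_size = 4
page_size = 16
max_seq_len = 5  # shorter than one page

# Old calculation: pool all tokens across the batch, then ceil-divide.
num_tokens = max_seq_len * batch_size                       # 20
old_num_pages = (num_tokens + page_size - 1) // page_size   # ceil(20 / 16) = 2

# New calculation: pages per sequence first, then scale by the batch.
num_pages_per_seq = (max_seq_len + page_size - 1) // page_size  # ceil(5 / 16) = 1
new_num_pages = num_pages_per_seq * batch_size                  # 1 * 4 = 4

# With 4 sequences, 2 pages cannot give every sequence its own page.
assert old_num_pages < batch_size
assert new_num_pages >= batch_size
print(old_num_pages, new_num_pages)  # 2 4
```

Computing pages per sequence first and then multiplying by `batch_size` guarantees at least one page per sequence, which matches how a paged cache is indexed: pages are not shared across sequences, so partially filled pages cannot be pooled.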

0 commit comments