Description
We were analyzing memory usage/wastage in Vajra to maximize the KV cache size, and one thing that stood out was the flashinfer wrapper. Here is the issue: since using CUDA graphs with flashinfer requires fixing the batch size, we were creating a separate wrapper for each batch size. To minimize memory waste, we were already reusing the same workspace and metadata buffers across wrappers. However, each wrapper was still consuming 72 MB of memory, so running with a large max batch size leads to a significant memory overhead (see the sketch below).
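For context, a minimal sketch of the pattern described above: one decode wrapper per CUDA-graph-captured batch size, all sharing a single workspace buffer. The buffer size, the set of batch sizes, and the omission of the CUDA-graph-specific constructor arguments are illustrative assumptions, not Vajra's actual code.

```python
import torch
import flashinfer

# One shared workspace buffer, reused by every wrapper (as we already do).
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# One wrapper per batch size we capture a CUDA graph for; despite the shared
# workspace buffer, each wrapper still holds ~72 MB of its own state.
decode_wrappers = {
    batch_size: flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
    for batch_size in (1, 2, 4, 8, 16, 32, 64, 128)
}
```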
To work around this, I tried using a single wrapper that I warm up with the max batch size; at each step I then pad the indptr by repeating the last valid entry (sketched below). This resolves the memory issue, but significantly increases latency: on Qwen3-4B on an H100 NVL, the TBT latency goes from 5.5 ms (max batch size = 8, actual batch size = 6) to 6.5 ms (max batch size = 128, actual batch size = 6).
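A minimal sketch of the padding workaround, assuming a paged-KV indptr of shape [batch_size + 1]; the tensor values and variable names are illustrative only.

```python
import torch

max_batch_size = 128   # batch size the single wrapper was warmed up with
actual_batch_size = 6  # batch size of the current iteration

# indptr for the 6 real requests: request i owns pages
# indices[indptr[i]:indptr[i+1]] of the paged KV cache.
indptr = torch.tensor([0, 3, 7, 12, 14, 20, 25], dtype=torch.int32, device="cuda")

# Pad to the warmed-up size by repeating the last valid entry, so every
# padded request maps to an empty page range.
pad = indptr[-1].repeat(max_batch_size - actual_batch_size)
padded_indptr = torch.cat([indptr, pad])

assert padded_indptr.numel() == max_batch_size + 1
```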
I am exploring some more workarounds, but native support for dynamic batch size would be really helpful.
Flashinfer version: 0.5.1
CUDA version: cuda_12.9.r12.9
Torch version: 2.8
Device: H100-NVL