
[cuda graphs] Support variable batch size with cuda graphs #2046

@AgrawalAmey

Description


We were analyzing memory usage/waste in Vajra to maximize KV cache size, and one thing that stood out was the flashinfer wrapper. Here is the issue: since using CUDA graphs with flashinfer requires fixing the batch size, we were creating a separate wrapper for each batch size. To minimize memory waste, we were already reusing the same workspace and metadata buffers across wrappers. However, we noticed that each wrapper was still consuming 72 MB of memory, so running with a large max batch size leads to a significant increase in memory overhead.
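For reference, a minimal sketch of that setup, assuming `BatchDecodeWithPagedKVCacheWrapper`'s CUDA-graph constructor arguments; the buffer sizes and batch-size list here are illustrative, not what Vajra actually uses:

```python
import torch
import flashinfer

max_bs, max_pages = 128, 65536
# Shared workspace and metadata buffers, reused by every wrapper.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
indptr_buf = torch.empty(max_bs + 1, dtype=torch.int32, device="cuda")
indices_buf = torch.empty(max_pages, dtype=torch.int32, device="cuda")
last_page_len_buf = torch.empty(max_bs, dtype=torch.int32, device="cuda")

# One wrapper per captured batch size, all viewing the same buffers,
# yet each wrapper still holds ~72 MB of its own state.
wrappers = {
    bs: flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace,
        "NHD",
        use_cuda_graph=True,
        paged_kv_indptr_buffer=indptr_buf[: bs + 1],
        paged_kv_indices_buffer=indices_buf,
        paged_kv_last_page_len_buffer=last_page_len_buf[:bs],
    )
    for bs in (1, 2, 4, 8, 16, 32, 64, 128)
}
```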

To work around this problem, I tried to use a single wrapper that I warm up with the max batch size; at run time I just pad the indptr by repeating the last valid entry (sketched below). This approach resolves the memory issue, but causes a significant increase in latency. On Qwen3-4B on an H100-NVL, the TBT latency goes up from 5.5 ms (with max batch size = 8, actual batch size = 6) to 6.5 ms (with max batch size = 128, actual batch size = 6).
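Concretely, the padding looks roughly like this (`pad_indptr` is a hypothetical helper for illustration, not part of flashinfer):

```python
import torch

def pad_indptr(indptr: torch.Tensor, max_batch_size: int) -> torch.Tensor:
    # Pad a CSR indptr of length (actual_bs + 1) up to (max_bs + 1) by
    # repeating the last valid entry, so the padded slots become
    # zero-length requests.
    actual_bs = indptr.numel() - 1
    pad = indptr[-1:].expand(max_batch_size - actual_bs)
    return torch.cat([indptr, pad])

# e.g. actual batch size 3, max batch size 8:
# [0, 3, 7, 12] -> [0, 3, 7, 12, 12, 12, 12, 12, 12]
```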

I am exploring some more workarounds, but native support for dynamic batch size would be really helpful.

Flashinfer version: 0.5.1
CUDA version: cuda_12.9.r12.9
Torch version: 2.8
Device: H100-NVL
