BUG report: Persistent kv_seq_len breaks core pruning logic: kv_cluster.update_kv skipped after first batch #59

@vlev02

Description

In the mistral_attn_forward_* family of functions, the attribute kv_seq_len is dynamically added to self and used to accumulate the total key/value length. However, because this attribute is stored on the module instance (self), it persists across forward passes for different batches during inference instead of starting fresh for each new sequence.

This leads to a critical issue: in subsequent batches, the condition

if key_states.shape[-2] == kv_seq_len:

is no longer satisfied because kv_seq_len includes the accumulated value from previous batches. As a result, the following core pruning function is never executed again after the first batch:

self.kv_cluster.update_kv(...)

This causes the pruning logic to become completely non-functional beyond the first batch, silently bypassing the compression mechanism that kv_cluster.update_kv is meant to handle.

The root cause is that self.kv_seq_len—which should be a per-batch variable—is instead retained as a persistent instance attribute, leaking state across forward passes and breaking the intended behavior.
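The failure mode can be reproduced with a small toy model. This is a minimal sketch, not the repository's code: ToyPatchedAttention, reset_state and run_batches are hypothetical names, q_len stands in for key_states.shape[-2], and prune_calls counts how often the branch that would call self.kv_cluster.update_kv(...) is taken. The reset_state() call is one possible fix, assuming the caller can tell when a new sequence (prefill) begins.

```python
class ToyPatchedAttention:
    """Hypothetical stand-in for a patched Mistral attention layer.

    Only the kv_seq_len bookkeeping described in the issue is modelled.
    """

    def __init__(self):
        self.prune_calls = 0

    def reset_state(self):
        # Hypothetical fix: clear the per-sequence counter before a new prefill.
        self.kv_seq_len = 0

    def forward(self, q_len):
        kv_seq_len = q_len
        if getattr(self, "kv_seq_len", 0) != 0:
            # Adds the length accumulated so far, including lengths from
            # earlier batches if nothing has reset the attribute.
            kv_seq_len += self.kv_seq_len
        if q_len == kv_seq_len:
            # Prefill path: this is where update_kv(...) would compress the cache.
            self.kv_seq_len = kv_seq_len
            self.prune_calls += 1
        else:
            # Decode path: only the running length is advanced.
            self.kv_seq_len += q_len


def run_batches(prompt_lens, decode_steps, reset_between_batches):
    """Run one prefill plus a few decode steps per batch; return prune count."""
    attn = ToyPatchedAttention()
    for prompt_len in prompt_lens:
        if reset_between_batches:
            attn.reset_state()
        attn.forward(prompt_len)  # prefill: the whole prompt at once
        for _ in range(decode_steps):
            attn.forward(1)  # decode: one new token per step
    return attn.prune_calls


print(run_batches([10, 8, 12], decode_steps=2, reset_between_batches=False))  # 1
print(run_batches([10, 8, 12], decode_steps=2, reset_between_batches=True))   # 3
```

Without a reset, the second prefill computes kv_seq_len = 8 + 12 = 20, so the equality check fails and pruning is skipped for every batch after the first. Resetting the counter per sequence keeps the running length (which is still needed because the compressed cache is shorter than the true sequence) while restoring pruning on every prefill.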
