In the mistral_attn_forward_* family of functions, the attribute kv_seq_len is dynamically added to self and used to accumulate the total key/value length across batches. Because the attribute is stored on the module instance (self), it persists across separate forward passes during inference.
This leads to a critical issue: in subsequent batches, the condition
if key_states.shape[-2] == kv_seq_len:
is no longer satisfied, because kv_seq_len still carries the accumulated value from previous batches. As a result, the core pruning call is never executed again after the first batch:
self.kv_cluster.update_kv(...)
This renders the pruning logic completely non-functional beyond the first batch, silently bypassing the compression that kv_cluster.update_kv is meant to perform.
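The failure mode can be reproduced in isolation. The toy module below is a simplified stand-in (ToyAttention, forward, and pruned_calls are illustrative names, not the library's API); it mirrors only the accumulation pattern and the shape check described above.

```python
# Minimal stand-in for the reported state leak. ToyAttention is a
# hypothetical simplification of the mistral_attn_forward_* pattern.
class ToyAttention:
    def __init__(self):
        self.pruned_calls = 0  # counts how often pruning would run

    def forward(self, key_len: int) -> bool:
        # Buggy pattern: kv_seq_len lives on self, so it survives
        # across forward passes and keeps accumulating.
        if not hasattr(self, "kv_seq_len"):
            self.kv_seq_len = 0
        self.kv_seq_len += key_len
        # Mirrors: if key_states.shape[-2] == kv_seq_len:
        if key_len == self.kv_seq_len:
            self.pruned_calls += 1  # stands in for self.kv_cluster.update_kv(...)
            return True
        return False

attn = ToyAttention()
first = attn.forward(key_len=8)   # True: 8 == 8, pruning runs
second = attn.forward(key_len=8)  # False: 8 != 16, pruning silently skipped
```

Running this shows pruning fires exactly once: every later batch fails the equality check because the leaked counter has grown past the current key length.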
The root cause is that kv_seq_len, which should be a per-batch variable, is instead retained as a persistent instance attribute, leaking state across forward passes and breaking the intended behavior.
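One possible fix, sketched under the same simplified model (this is an assumption about a remedy, not the maintainers' actual patch), is to recompute the key/value length per call from the cached past length instead of accumulating it on self:

```python
# Hypothetical fix sketch: kv_seq_len as a per-call local, derived from
# the past cache length, so no state leaks across forward passes.
def forward_fixed(key_len: int, past_len: int = 0) -> bool:
    kv_seq_len = past_len + key_len  # recomputed every call
    # Prefill (past_len == 0) satisfies the check; decode steps do not.
    return key_len == kv_seq_len

prefill = forward_fixed(key_len=8)              # True: pruning runs
decode = forward_fixed(key_len=1, past_len=8)   # False: pruning skipped as intended
```

With this structure, a new batch starting with an empty cache always re-enters the pruning branch, restoring the behavior the shape check was meant to encode.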