Commit acc78ae

[Bugfix] Fix interns1-vit qk norm code path (#27480)
Signed-off-by: Isotr0py <[email protected]>
1 parent 0f67d4d

1 file changed (+3, −4)

vllm/model_executor/models/interns1_vit.py

Lines changed: 3 additions & 4 deletions
@@ -217,16 +217,15 @@ def __init__(
         self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        B, N, C = x.shape
+        """x shape: (B, N, C)"""
 
         q = self.q_proj(x)
         k = self.k_proj(x)
         v = self.v_proj(x)
 
         if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
-            k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
+            q = self.q_norm(q)
+            k = self.k_norm(k)
 
         # Use unified MultiHeadAttention with automatic backend selection
         x = self.attn(q, k, v)
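
Why the old path failed: q_proj(x) returns a 3-D (B, N, C) tensor (head splitting happens later, inside the unified MultiHeadAttention, per the comment in the diff), so the removed line B_, N_, H_, D_ = q.shape would raise "ValueError: not enough values to unpack" whenever qk_normalization was enabled. The sketch below is a minimal stand-alone illustration of the corrected path, not vLLM's actual classes: nn.Linear and nn.LayerNorm stand in for the model's projection and qk-norm layers, and the shapes are illustrative.

import torch
from torch import nn

B, N, C = 2, 8, 64        # batch, sequence length, hidden size (illustrative)
x = torch.randn(B, N, C)

q_proj = nn.Linear(C, C)  # stand-in for self.q_proj
q_norm = nn.LayerNorm(C)  # stand-in for self.q_norm

q = q_proj(x)             # shape (B, N, C): still 3-D at this point

# Old path: unpacking four dimensions from a 3-D tensor fails.
try:
    B_, N_, H_, D_ = q.shape
except ValueError as e:
    print(f"old path fails: {e}")  # not enough values to unpack

# New path: normalize directly over the full hidden dimension,
# before the attention layer splits it into heads.
q = q_norm(q)
assert q.shape == (B, N, C)

The same reasoning applies to k and k_norm; since q is already flat in its last dimension, normalizing it directly achieves what the old flatten-then-view sequence was meant to do.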
