
Commit d6704dd

rogeryoungh and xuebi authored
Fix MiniMax-M2 rmsnorm precision and remove useless code (#27627)
Signed-off-by: xuebi <[email protected]>
Co-authored-by: xuebi <[email protected]>
1 parent ecca3fe · commit d6704dd

2 files changed: +1 −19 lines

vllm/model_executor/layers/mamba/linear_attn.py

Lines changed: 1 addition & 1 deletion

@@ -77,7 +77,7 @@ def _forward(
         if self.tp_world > 1:
             variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = x.to(orig_dtype) * self.weight
+        x = (x * self.weight).to(orig_dtype)
         return x

     def forward(
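
Why the reordering improves precision: x has been normalized in float32, and the old line cast it down to the model dtype before applying the learned gain, so the result was rounded twice (once at the cast, again after the low-precision multiply). The new line applies the gain while still in float32 and rounds only once, at the final cast. A minimal sketch of the effect, with hypothetical shapes and dtypes standing in for the real tensors:

import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the tensors in the RMSNorm above
# (real shapes and dtypes come from the MiniMax-M2 model config).
orig_dtype = torch.bfloat16
x = torch.randn(4, 1024)                    # normalized activations, float32
weight = torch.randn(1024).to(orig_dtype)   # learned RMSNorm gain, model dtype

ref = x * weight.float()                    # full-precision reference

# Old ordering: cast first, then multiply in bf16 -> two rounding steps.
old = x.to(orig_dtype) * weight

# New ordering: multiply in float32 (weight is promoted), cast once.
new = (x * weight).to(orig_dtype)

print((old.float() - ref).abs().max().item())   # typically the larger error
print((new.float() - ref).abs().max().item())   # at most one bf16 rounding

The tensor-parallel all-reduce of the variance is untouched; only the order of the cast and the gain multiply changes.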

vllm/model_executor/models/minimax_m2.py

Lines changed: 0 additions & 18 deletions

@@ -263,23 +263,6 @@ def __init__(
         # with the layer's index.
         layer_idx = int(prefix.split(sep=".")[-1])
 
-        # TODO: support MTP
-        attn_window_size = getattr(config, "attn_window_size", None)
-        if attn_window_size is not None:
-            if isinstance(attn_window_size, list):
-                attn_window_size = attn_window_size[layer_idx]
-            elif isinstance(attn_window_size, int):
-                attn_window_size = attn_window_size
-            else:
-                raise ValueError(f"Invalid attn_window_size: {attn_window_size}")
-            attn_window_size = None if attn_window_size <= 0 else attn_window_size
-
-        # different rope theta for full layer and swa layer
-        swa_rope_theta = getattr(config, "swa_rope_theta", -1)
-        # default to full rope theta
-        swa_rope_theta = rope_theta if swa_rope_theta <= 0 else swa_rope_theta
-        rope_theta = swa_rope_theta if attn_window_size is not None else rope_theta
-
         self.layer_idx = layer_idx
         self.self_attn = MiniMaxM2Attention(
             hidden_size=self.hidden_size,
@@ -288,7 +271,6 @@ def __init__(
             rotary_dim=config.rotary_dim,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
-            attn_window_size=attn_window_size,
             max_position_embeddings=max_position_embeddings,
             rms_norm_eps=config.rms_norm_eps,
             qkv_bias=getattr(config, "attention_bias", False),
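
For context, the deleted block resolved an optional per-layer sliding-window size and a separate rope theta for sliding-window-attention layers; presumably no MiniMax-M2 config sets these fields, which is why the commit removes the path as useless code. A standalone restatement of the removed logic, packaged as a helper for readability (the helper function itself is hypothetical; its body mirrors the deleted lines):

def resolve_swa_params(config, layer_idx: int, rope_theta: float):
    attn_window_size = getattr(config, "attn_window_size", None)
    if attn_window_size is not None:
        if isinstance(attn_window_size, list):
            # Per-layer window sizes, indexed by the layer's position.
            attn_window_size = attn_window_size[layer_idx]
        elif not isinstance(attn_window_size, int):
            raise ValueError(f"Invalid attn_window_size: {attn_window_size}")
        # Non-positive values mean "no sliding window" for this layer.
        attn_window_size = None if attn_window_size <= 0 else attn_window_size

    # SWA layers could carry their own rope theta; non-positive values
    # fall back to the full-attention rope theta.
    swa_rope_theta = getattr(config, "swa_rope_theta", -1)
    swa_rope_theta = rope_theta if swa_rope_theta <= 0 else swa_rope_theta
    rope_theta = swa_rope_theta if attn_window_size is not None else rope_theta
    return attn_window_size, rope_theta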
