Commit 1664357

Fix
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 45acc6b commit 1664357

File tree

1 file changed (+13, -13 lines)

vllm/lora/layers/fused_moe.py

Lines changed: 13 additions & 13 deletions
@@ -155,11 +155,11 @@ def wrapper(*args, **kwargs):
             CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
             num_tokens = hidden_states.size(0)
             M = min(num_tokens, CHUNK_SIZE)
-
+            max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
             shrink_config, expand_config = self._get_lora_moe_configs(
                 op_prefix="w13",
                 num_loras=self.max_loras,
-                rank=self.max_lora_rank,
+                rank=max_lora_rank,
                 num_slices=self.w13_slices,
                 M=M,
                 layer=layer,
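
For context, a minimal standalone sketch (not part of the commit, made-up sizes) of why the rank is now read off the stacked LoRA-A tensor instead of the config: the buffers allocated in create_lora_weights below keep the rank at shape[-2], and with fully_sharded_loras that dimension is already divided by the tensor-parallel size, so the tensor shape, rather than lora_config.max_lora_rank, reflects what was actually allocated.

import torch

# Example sizes; only the shape layout mirrors the allocation in create_lora_weights.
max_loras, num_experts, hidden_size = 4, 8, 1024
max_lora_rank, tp_size, fully_sharded = 16, 2, True

# w13_lora_a_stacked[i] is allocated as
# (max_loras, local_num_experts, rank or rank // tp_size, hidden_size).
alloc_rank = max_lora_rank // tp_size if fully_sharded else max_lora_rank
w13_lora_a = torch.zeros(max_loras, num_experts, alloc_rank, hidden_size)

print(w13_lora_a.shape[-2])  # 8  -> the per-shard rank actually allocated
print(max_lora_rank)         # 16 -> what self.max_lora_rank used to report
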
@@ -190,6 +190,7 @@ def wrapper(*args, **kwargs):
 
             expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
             sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
+            #
 
             self.punica_wrapper.add_lora_fused_moe(
                 input.view(-1, top_k, input.shape[-1]),
@@ -200,7 +201,7 @@ def wrapper(*args, **kwargs):
                 sorted_token_ids_lora,
                 expert_ids_lora,
                 num_tokens_post_padded_lora,
-                self.max_lora_rank,
+                max_lora_rank,
                 top_k,
                 shrink_config, ## pass the shrink config
                 expand_config, ## pass the expand config
@@ -229,11 +230,11 @@ def wrapper(*args, **kwargs):
             CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
             num_tokens = hidden_states.size(0)
             M = min(num_tokens, CHUNK_SIZE)
-
+            max_lora_rank = self.w2_lora_a_stacked.shape[-2]
             shrink_config, expand_config = self._get_lora_moe_configs(
                 op_prefix="w2",
                 num_loras=self.max_loras,
-                rank=self.max_lora_rank,
+                rank=max_lora_rank,
                 num_slices=1,
                 M=M,
                 layer=layer,
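
The w2 path mirrors the w13 path above; the difference is that w13_lora_a_stacked holds one stacked tensor per slice (create_lora_weights below asserts w13_slices == 2, i.e. the w1 and w3 projections), while w2_lora_a_stacked is a single tensor, hence the [0] indexing and num_slices=self.w13_slices versus num_slices=1. A small sketch of the two access patterns, with made-up shapes:

import torch

max_loras, num_experts, rank = 4, 8, 16
hidden_size, intermediate_size = 1024, 2048

# w13: one stacked LoRA-A tensor per slice (w1 and w3).
w13_lora_a_stacked = tuple(
    torch.zeros(max_loras, num_experts, rank, hidden_size) for _ in range(2)
)
# w2: a single stacked LoRA-A tensor.
w2_lora_a_stacked = torch.zeros(max_loras, num_experts, rank, intermediate_size)

print(w13_lora_a_stacked[0].shape[-2])  # 16 -- rank read from the first slice
print(w2_lora_a_stacked.shape[-2])      # 16 -- rank read directly
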
@@ -263,7 +264,7 @@ def wrapper(*args, **kwargs):
                 sorted_token_ids_lora,
                 expert_ids_lora,
                 num_tokens_post_padded_lora,
-                self.max_lora_rank,
+                max_lora_rank,
                 top_k,
                 shrink_config, ## pass the shrink config
                 expand_config, ## pass the expand config
@@ -300,7 +301,6 @@ def create_lora_weights(
         """Initializes lora matrices."""
         assert self.w13_slices == 2
         self.max_loras = lora_config.max_loras
-        self.max_lora_rank = lora_config.max_lora_rank
         self.fully_sharded = lora_config.fully_sharded_loras
 
         self.adapter_enabled = torch.tensor(
@@ -312,9 +312,9 @@ def create_lora_weights(
                 (
                     max_loras,
                     self.base_layer.local_num_experts,
-                    self.max_lora_rank
+                    lora_config.max_lora_rank
                     if not self.fully_sharded
-                    else divide(self.max_lora_rank, self.tp_size),
+                    else divide(lora_config.max_lora_rank, self.tp_size),
                     self.base_layer.hidden_size,
                 ),
                 dtype=lora_config.lora_dtype,
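
A hedged sketch of the conditional shape in this hunk; divide below is a local stand-in for the helper used above (assumed to perform exact integer division), so the allocated rank dimension shrinks to the per-tensor-parallel-rank share when fully_sharded_loras is on:

def divide(numerator: int, denominator: int) -> int:
    # Stand-in for the helper used above; assumed to enforce exact division.
    assert numerator % denominator == 0
    return numerator // denominator

max_lora_rank, tp_size = 16, 4
for fully_sharded in (False, True):
    rank_dim = max_lora_rank if not fully_sharded else divide(max_lora_rank, tp_size)
    print(fully_sharded, rank_dim)  # False 16, then True 4
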
@@ -329,7 +329,7 @@ def create_lora_weights(
                     max_loras,
                     self.base_layer.local_num_experts,
                     self.base_layer.intermediate_size_per_partition,
-                    self.max_lora_rank,
+                    lora_config.max_lora_rank,
                 ),
                 dtype=lora_config.lora_dtype,
                 device=self.device,
@@ -341,7 +341,7 @@ def create_lora_weights(
             (
                 max_loras,
                 self.base_layer.local_num_experts,
-                self.max_lora_rank,
+                lora_config.max_lora_rank,
                 self.base_layer.intermediate_size_per_partition,
             ),
             dtype=lora_config.lora_dtype,
@@ -354,7 +354,7 @@ def create_lora_weights(
                 self.base_layer.hidden_size
                 if not self.fully_sharded
                 else divide(self.base_layer.hidden_size, self.tp_size),
-                self.max_lora_rank,
+                lora_config.max_lora_rank,
             ),
             dtype=lora_config.lora_dtype,
             device=self.device,
@@ -426,7 +426,7 @@ def set_lora(
         if self.fully_sharded:
             # Based on S-LoRA, we slice W1 and W3 A along the rank dim,
             # and W2 B along the hidden_size dim.
-            w13_shard_size = self.w1_lora_a_stacked[index, eid].shape[0]
+            w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
             w13_start_idx = self.tp_rank * w13_shard_size
             w13_end_idx = (self.tp_rank + 1) * w13_shard_size
             w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
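
To make the fully-sharded branch concrete, a small sketch (made-up sizes, not the layer's code) of the S-LoRA-style split the comment describes: each tensor-parallel rank keeps a contiguous block of LoRA-A rows along the rank dimension, and the shard size is taken from the pre-allocated buffer, which this fix now reads from w13_lora_a_stacked[0]:

import torch

rank, hidden_size, tp_size = 16, 1024, 4
w1_lora_a = torch.randn(rank, hidden_size)  # full LoRA-A for w1, as loaded

# Shard size as it would be read from a (rank // tp_size, hidden_size) buffer.
w13_shard_size = rank // tp_size
for tp_rank in range(tp_size):
    w13_start_idx = tp_rank * w13_shard_size
    w13_end_idx = (tp_rank + 1) * w13_shard_size
    shard = w1_lora_a[w13_start_idx:w13_end_idx, :]  # rows owned by this rank
    print(tp_rank, tuple(shard.shape))  # 0 (4, 1024) ... 3 (4, 1024)
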
