Commit 45acc6b

Done

Signed-off-by: Jee Jee Li <[email protected]>

1 parent 9875be6 commit 45acc6b

3 files changed: +95 −100 lines

vllm/lora/layers/fused_moe.py

Lines changed: 91 additions & 96 deletions
@@ -42,6 +42,7 @@ def __init__(self, base_layer: FusedMoE) -> None:
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
         self.device = base_layer.w2_weight.device
+        self.w13_slices = 2
         self._inject_lora_into_fused_moe()
 
     def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
@@ -60,32 +61,34 @@ def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
     def _get_lora_moe_configs(
         self,
         op_prefix: str,
-        lora_a_stacked: torch.Tensor,
-        lora_b_stacked: torch.Tensor,
+        num_loras: int,
+        rank: int,
         num_slices: int,
         M: int,
         layer: FusedMoE,
         top_k: int,
         config_dtype: str,
     ):
         if envs.VLLM_TUNED_CONFIG_FOLDER:
+            hidden_size = layer.hidden_size
+            intermediate_size = layer.intermediate_size_per_partition
             shrink_config = get_lora_op_configs(
                 op_type=f"fused_moe_lora_{op_prefix}_shrink",
-                max_loras=lora_a_stacked.shape[0],
+                max_loras=num_loras,
                 batch=M,
-                hidden_size=lora_a_stacked.shape[-1],
-                rank=lora_a_stacked.shape[-2],
+                hidden_size=hidden_size,
+                rank=rank,
                 num_slices=num_slices,
-                moe_intermediate_size=lora_b_stacked.shape[-2],
+                moe_intermediate_size=intermediate_size,
             )
             expand_config = get_lora_op_configs(
                 op_type=f"fused_moe_lora_{op_prefix}_expand",
-                max_loras=lora_a_stacked.shape[0],
+                max_loras=num_loras,
                 batch=M,
-                hidden_size=lora_a_stacked.shape[-1],
-                rank=lora_a_stacked.shape[-2],
+                hidden_size=hidden_size,  # lora_a_stacked.shape[-1],
+                rank=rank,
                 num_slices=num_slices,
-                moe_intermediate_size=lora_b_stacked.shape[-2],
+                moe_intermediate_size=intermediate_size,  # lora_b_stacked.shape[-2],
             )
         else:  # fall back to the default config
             get_config_func = functools.partial(
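The hunk above replaces the stacked-tensor arguments with explicit num_loras/rank scalars and reads hidden_size and intermediate_size from the layer. A minimal sketch with toy sizes and illustrative names (plain torch, not the vLLM API) of why the explicit arguments match the values the old code derived from tensor shapes:

    import torch

    # Toy sizes only; the buffer layouts mirror the shapes used in this commit.
    max_loras, num_experts, rank, hidden_size, intermediate_size = 4, 8, 16, 512, 1024

    lora_a_stacked = torch.zeros(max_loras, num_experts, rank, hidden_size)
    lora_b_stacked = torch.zeros(max_loras, num_experts, intermediate_size, rank)

    assert lora_a_stacked.shape[0] == max_loras           # old: max_loras=lora_a_stacked.shape[0]
    assert lora_a_stacked.shape[-2] == rank               # old: rank=lora_a_stacked.shape[-2]
    assert lora_a_stacked.shape[-1] == hidden_size        # old: hidden_size=lora_a_stacked.shape[-1]
    assert lora_b_stacked.shape[-2] == intermediate_size  # old: moe_intermediate_size=lora_b_stacked.shape[-2]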
@@ -155,17 +158,16 @@ def wrapper(*args, **kwargs):
 
                shrink_config, expand_config = self._get_lora_moe_configs(
                    op_prefix="w13",
-                    lora_a_stacked=self.w1_lora_a_stacked,
-                    lora_b_stacked=self.w1_lora_b_stacked,
-                    num_slices=2,
+                    num_loras=self.max_loras,
+                    rank=self.max_lora_rank,
+                    num_slices=self.w13_slices,
                    M=M,
                    layer=layer,
                    top_k=top_k,
                    config_dtype=config_dtype,
                )
 
                # get the block size of m from customized config or default config
-                max_loras = self.w1_lora_a_stacked.shape[0]
                (
                    sorted_token_ids_lora,
                    expert_ids_lora,
@@ -175,7 +177,7 @@ def wrapper(*args, **kwargs):
                    num_tokens,
                    shrink_config["BLOCK_SIZE_M"],
                    self.base_layer.local_num_experts,
-                    max_loras,
+                    self.max_loras,
                    self.adapter_enabled,
                    expert_map,
                )
@@ -186,22 +188,19 @@ def wrapper(*args, **kwargs):
                    num_tokens_post_padded_lora
                )
 
-                w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
-                w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
-                max_lora_rank = self.w1_lora_a_stacked.shape[-2]
-                expert_ids_lora = expert_ids_lora.view(max_loras, -1)
-                sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+                expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+                sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
 
                self.punica_wrapper.add_lora_fused_moe(
                    input.view(-1, top_k, input.shape[-1]),
                    hidden_states,
-                    w13_lora_a_stacked,
-                    w13_lora_b_stacked,
+                    self.w13_lora_a_stacked,
+                    self.w13_lora_b_stacked,
                    topk_weights,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
-                    max_lora_rank,
+                    self.max_lora_rank,
                    top_k,
                    shrink_config,  ## pass the shrink config
                    expand_config,  ## pass the expand config
@@ -233,8 +232,8 @@ def wrapper(*args, **kwargs):
 
                shrink_config, expand_config = self._get_lora_moe_configs(
                    op_prefix="w2",
-                    lora_a_stacked=self.w2_lora_a_stacked,
-                    lora_b_stacked=self.w2_lora_b_stacked,
+                    num_loras=self.max_loras,
+                    rank=self.max_lora_rank,
                    num_slices=1,
                    M=M,
                    layer=layer,
@@ -247,25 +246,24 @@ def wrapper(*args, **kwargs):
                num_tokens_post_padded_lora = moe_state_dict[
                    "num_tokens_post_padded_lora"
                ]
-                max_loras = self.w1_lora_a_stacked.shape[0]
-                expert_ids_lora = expert_ids_lora.view(max_loras, -1)
-                sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+
+                expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+                sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
                intermediate_cache2 = moe_state_dict["intermediate_cache2"]
                intermediate_cache3 = args[0]
-                max_lora_rank = self.w2_lora_a_stacked.shape[-2]
 
                shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
 
                self.punica_wrapper.add_lora_fused_moe(
                    intermediate_cache3,
                    intermediate_cache2,
-                    [self.w2_lora_a_stacked],
-                    [self.w2_lora_b_stacked],
+                    (self.w2_lora_a_stacked,),
+                    (self.w2_lora_b_stacked,),
                    topk_weights,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
-                    max_lora_rank,
+                    self.max_lora_rank,
                    top_k,
                    shrink_config,  ## pass the shrink config
                    expand_config,  ## pass the expand config
@@ -289,7 +287,6 @@ def wrapper(*args, **kwargs):
            fused_experts.moe_sum = moe_sum_decorator(
                self.base_layer, fused_experts.moe_sum
            )
-
        self.base_layer.quant_method = FusedMoEModularMethod(
            self.base_layer.quant_method, m_fused_moe_fn
        )
@@ -301,40 +298,50 @@ def create_lora_weights(
         model_config: PretrainedConfig | None = None,
     ) -> None:
         """Initializes lora matrices."""
+        assert self.w13_slices == 2
+        self.max_loras = lora_config.max_loras
+        self.max_lora_rank = lora_config.max_lora_rank
         self.fully_sharded = lora_config.fully_sharded_loras
 
         self.adapter_enabled = torch.tensor(
             [0] * (max_loras + 1), dtype=torch.int, device=self.device
         )
 
-        self.w1_lora_a_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                lora_config.max_lora_rank
-                if not self.fully_sharded
-                else divide(lora_config.max_lora_rank, self.tp_size),
-                self.base_layer.hidden_size,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
+        self.w13_lora_a_stacked = tuple(
+            torch.zeros(
+                (
+                    max_loras,
+                    self.base_layer.local_num_experts,
+                    self.max_lora_rank
+                    if not self.fully_sharded
+                    else divide(self.max_lora_rank, self.tp_size),
+                    self.base_layer.hidden_size,
+                ),
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+            for _ in range(self.w13_slices)
         )
-        self.w1_lora_b_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                self.base_layer.intermediate_size_per_partition,
-                lora_config.max_lora_rank,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
+
+        self.w13_lora_b_stacked = tuple(
+            torch.zeros(
+                (
+                    max_loras,
+                    self.base_layer.local_num_experts,
+                    self.base_layer.intermediate_size_per_partition,
+                    self.max_lora_rank,
+                ),
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+            for _ in range(self.w13_slices)
         )
 
         self.w2_lora_a_stacked = torch.zeros(
             (
                 max_loras,
                 self.base_layer.local_num_experts,
-                lora_config.max_lora_rank,
+                self.max_lora_rank,
                 self.base_layer.intermediate_size_per_partition,
             ),
             dtype=lora_config.lora_dtype,
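create_lora_weights now builds one lora_a/lora_b buffer per w13 slice (slice 0 for w1/gate_proj, slice 1 for w3/up_proj) instead of separate w1_* and w3_* attributes. A minimal sketch with toy sizes, assuming the same shape layout as the diff above (illustrative, not the vLLM buffers):

    import torch

    max_loras, num_experts, rank, hidden, intermediate = 2, 4, 8, 64, 128
    w13_slices = 2  # slice 0 -> w1 (gate_proj), slice 1 -> w3 (up_proj)

    w13_lora_a_stacked = tuple(
        torch.zeros(max_loras, num_experts, rank, hidden) for _ in range(w13_slices)
    )
    w13_lora_b_stacked = tuple(
        torch.zeros(max_loras, num_experts, intermediate, rank) for _ in range(w13_slices)
    )

    assert len(w13_lora_a_stacked) == w13_slices
    assert w13_lora_a_stacked[0].shape == (max_loras, num_experts, rank, hidden)
    assert w13_lora_b_stacked[1].shape == (max_loras, num_experts, intermediate, rank)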
@@ -347,30 +354,7 @@ def create_lora_weights(
                 self.base_layer.hidden_size
                 if not self.fully_sharded
                 else divide(self.base_layer.hidden_size, self.tp_size),
-                lora_config.max_lora_rank,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-
-        self.w3_lora_a_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                lora_config.max_lora_rank
-                if not self.fully_sharded
-                else divide(lora_config.max_lora_rank, self.tp_size),
-                self.base_layer.hidden_size,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-        self.w3_lora_b_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                self.base_layer.intermediate_size_per_partition,
-                lora_config.max_lora_rank,
+                self.max_lora_rank,
             ),
             dtype=lora_config.lora_dtype,
             device=self.device,
@@ -383,20 +367,28 @@ def create_lora_weights(
         for lora_id in range(max_loras):
             for experts_id in range(self.base_layer.local_num_experts):
                 # gate_proj,down_proj,up_proj
-                self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
+                self.lora_a_stacked.append(
+                    self.w13_lora_a_stacked[0][lora_id][experts_id]
+                )
                 self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
-                self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])
+                self.lora_a_stacked.append(
+                    self.w13_lora_a_stacked[1][lora_id][experts_id]
+                )
 
-                self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
+                self.lora_b_stacked.append(
+                    self.w13_lora_b_stacked[0][lora_id][experts_id]
+                )
                 self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
-                self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])
+                self.lora_b_stacked.append(
+                    self.w13_lora_b_stacked[1][lora_id][experts_id]
+                )
 
     def reset_lora(self, index: int):
         """Resets the lora weights at index back to 0."""
-        self.w1_lora_a_stacked[index] = 0
-        self.w1_lora_b_stacked[index] = 0
-        self.w3_lora_a_stacked[index] = 0
-        self.w3_lora_b_stacked[index] = 0
+        for pos in range(self.w13_slices):
+            self.w13_lora_a_stacked[pos][index] = 0
+            self.w13_lora_b_stacked[pos][index] = 0
+
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
         self.adapter_enabled[index] = 0
@@ -444,29 +436,32 @@ def set_lora(
                w2_start_idx = self.tp_rank * w2_shard_size
                w2_end_idx = (self.tp_rank + 1) * w2_shard_size
                w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
-
-            self.w1_lora_a_stacked[
+            # w1 lora_a
+            self.w13_lora_a_stacked[0][
                index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
            ].copy_(w1_lora_a, non_blocking=True)
-
-            self.w3_lora_a_stacked[
+            # w3 lora_a
+            self.w13_lora_a_stacked[1][
                index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
            ].copy_(w3_lora_a, non_blocking=True)
 
-            self.w2_lora_b_stacked[
-                index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
-            ].copy_(w2_lora_b, non_blocking=True)
-
-            self.w1_lora_b_stacked[
+            # w1 lora_b
+            self.w13_lora_b_stacked[0][
                index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
            ].copy_(w1_lora_b, non_blocking=True)
-            self.w3_lora_b_stacked[
+            # w3 lora_b
+            self.w13_lora_b_stacked[1][
                index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
            ].copy_(w3_lora_b, non_blocking=True)
+
            self.w2_lora_a_stacked[
                index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
            ].copy_(w2_lora_a, non_blocking=True)
 
+            self.w2_lora_b_stacked[
+                index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
+            ].copy_(w2_lora_b, non_blocking=True)
+
    @classmethod
    def can_replace_layer(
        cls,
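set_lora and reset_lora now index the consolidated tuples by slice rather than touching separate w1/w3 attributes. A standalone sketch of the copy-and-reset pattern with toy tensors (illustrative sizes and indices, not the vLLM buffers):

    import torch

    max_loras, num_experts, rank, hidden = 2, 4, 8, 64
    # Slice 0 holds the w1 (gate_proj) LoRA A weights, slice 1 holds w3 (up_proj).
    w13_lora_a_stacked = tuple(
        torch.zeros(max_loras, num_experts, rank, hidden) for _ in range(2)
    )

    index, eid = 0, 3                  # adapter slot and local expert id (made up)
    w1_lora_a = torch.randn(rank, hidden)

    # Copy one adapter's w1 LoRA A into slice 0 of the stacked buffer.
    w13_lora_a_stacked[0][
        index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
    ].copy_(w1_lora_a, non_blocking=True)

    # Resetting an adapter slot zeroes every slice, mirroring reset_lora above.
    for pos in range(2):
        w13_lora_a_stacked[pos][index] = 0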

vllm/lora/punica_wrapper/punica_base.py

Lines changed: 2 additions & 2 deletions
@@ -470,8 +470,8 @@ def add_lora_fused_moe(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
-        lora_a_stacked: list[torch.Tensor],
-        lora_b_stacked: list[torch.Tensor],
+        lora_a_stacked: tuple[torch.Tensor, ...],
+        lora_b_stacked: tuple[torch.Tensor, ...],
         topk_weights: torch.Tensor,
         sorted_token_ids: torch.Tensor,
         expert_ids: torch.Tensor,
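With the annotation changed to tuple[torch.Tensor, ...], the two-slice w13 call and the single-slice w2 call share one type. A hedged sketch (toy function and tensors, not the punica API) of the one-element-tuple convention used by the w2 path:

    import torch

    def takes_stacked(
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
    ) -> int:
        # The callee can iterate slices uniformly regardless of arity.
        assert len(lora_a_stacked) == len(lora_b_stacked)
        return len(lora_a_stacked)

    w2_lora_a = torch.zeros(2, 4, 8, 128)
    w2_lora_b = torch.zeros(2, 4, 64, 8)
    assert takes_stacked((w2_lora_a,), (w2_lora_b,)) == 1  # w2: one-element tuples, not lists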

vllm/lora/punica_wrapper/punica_gpu.py

Lines changed: 2 additions & 2 deletions
@@ -360,8 +360,8 @@ def add_lora_fused_moe(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
-        lora_a_stacked: list[torch.Tensor],
-        lora_b_stacked: list[torch.Tensor],
+        lora_a_stacked: tuple[torch.Tensor, ...],
+        lora_b_stacked: tuple[torch.Tensor, ...],
         topk_weights: torch.Tensor,
         sorted_token_ids: torch.Tensor,
         expert_ids: torch.Tensor,
