Skip to content

Commit f1c4e65

Browse files
committed
Fix EPLB adaptor updates: pad updated expert/log2phy maps with -1 to the stored tensor length and copy in place; copy buffered weights into expert tensors in place instead of rebinding the loop variable
Signed-off-by: Che Ruan <[email protected]>
1 parent 9ebad19 commit f1c4e65

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,20 +194,34 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
194194
json.dump(record, f, indent=4)
195195

196196
def do_update_expert_map(self, layer_id, updated_expert_map):
    """Replace the cached expert maps for ``layer_id`` with a new mapping.

    The device-side tensor in ``self.expert_map_per_layer`` may be longer
    than ``updated_expert_map``, so the update is extended with ``-1``
    entries (no expert) before the in-place copy.  The CPU-side cache
    receives the unpadded mapping.
    """
    target_len = self.expert_map_per_layer[layer_id].shape[0]
    tail = target_len - updated_expert_map.shape[0]
    # Pad on the right with -1 so shapes match the stored device map.
    padded_map = torch.nn.functional.pad(updated_expert_map,
                                         (0, tail),
                                         mode='constant',
                                         value=-1)
    # copy_ keeps the existing tensor storage (and any views) valid.
    self.expert_map_per_layer[layer_id].copy_(padded_map)
    self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
199206

200207
def do_update_expert_weight(self, layer_id, local_expert_to_replace,
                            buffer_tensor_id):
    """Copy staged weights from a buffer slot into a local expert in place.

    Every parameter tensor of the expert being replaced is overwritten with
    the corresponding tensor from ``self.buffer_tensor_list[buffer_tensor_id]``.
    """
    param_tensors = self.expert_param_per_layer[layer_id][local_expert_to_replace]
    staged_tensors = self.buffer_tensor_list[buffer_tensor_id]
    for expert_tensor, buffer_tensor in zip(param_tensors, staged_tensors):
        # In-place copy preserves the parameter's storage so existing
        # references to it remain valid.
        expert_tensor.copy_(buffer_tensor)
        logger.debug(f"Expert tensor shape is :{expert_tensor.shape}")
207214

208215
def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
    """Refresh the logical-to-physical expert map for ``layer_id``.

    No-op when the layer has no log2phy map.  Otherwise the update is
    right-padded with ``-1`` to the stored tensor's length and copied
    in place.
    """
    current = self.log2phy_map_per_layer[layer_id]
    if current is None:
        # This layer keeps no log2phy map; nothing to refresh.
        return
    tail = current.shape[0] - updated_log2phy_map.shape[0]
    padded_map = torch.nn.functional.pad(updated_log2phy_map,
                                         (0, tail),
                                         mode='constant',
                                         value=-1)
    current.copy_(padded_map)
211225

212226
def global2local(self, placement: torch.Tensor,
213227
E_local: int) -> torch.Tensor:

0 commit comments

Comments
 (0)