@@ -194,20 +194,34 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
194194 json .dump (record , f , indent = 4 )
195195
def do_update_expert_map(self, layer_id, updated_expert_map):
    """Overwrite the stored expert map of ``layer_id`` in place.

    The device-resident map keeps its fixed length: the incoming map is
    right-padded with -1 (no local expert) before being copied in.  The
    CPU-side map receives the unpadded values.
    """
    stored = self.expert_map_per_layer[layer_id]
    # Pad on the right so the new map matches the stored tensor's length.
    # NOTE(review): assumes the update is never longer than the stored
    # map — a negative pad length would silently truncate it; confirm
    # that invariant holds upstream.
    fill = stored.shape[0] - updated_expert_map.shape[0]
    padded = torch.nn.functional.pad(updated_expert_map,
                                     pad=(0, fill),
                                     mode='constant',
                                     value=-1)
    stored.copy_(padded)
    # The CPU copy keeps the original (unpadded) length.
    self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)
def do_update_expert_weight(self, layer_id, local_expert_to_replace,
                            buffer_tensor_id):
    """Copy staged weights from a buffer slot over a local expert's tensors.

    Each parameter tensor of the expert being replaced is overwritten
    in place with the matching tensor from ``buffer_tensor_list``.
    """
    expert_params = self.expert_param_per_layer[layer_id][
        local_expert_to_replace]
    staged_params = self.buffer_tensor_list[buffer_tensor_id]
    for dst, src in zip(expert_params, staged_params):
        # In-place copy so the existing parameter storage (and any views
        # of it held by the model) keeps pointing at the updated weights.
        dst.copy_(src)
        logger.debug(f"Expert tensor shape is :{dst.shape}")
def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
    """Refresh the log2phy map of ``layer_id``; no-op when the layer has none.

    The stored tensor keeps its fixed length, so the incoming map is
    right-padded with -1 before the in-place copy.
    """
    current = self.log2phy_map_per_layer[layer_id]
    if current is None:
        return
    # NOTE(review): assumes the update is never longer than the stored
    # map — a negative pad length would silently truncate it; confirm
    # that invariant holds upstream.
    tail = current.shape[0] - updated_log2phy_map.shape[0]
    current.copy_(
        torch.nn.functional.pad(updated_log2phy_map,
                                pad=(0, tail),
                                mode='constant',
                                value=-1))
212226 def global2local (self , placement : torch .Tensor ,
213227 E_local : int ) -> torch .Tensor :
0 commit comments