@@ -64,7 +64,6 @@ def __init__(self, model, mtp_instance, num_mtp_layers, **args):
             ]
         else:
             self.mtp_expert_weight_names = ["w13_weight", "w2_weight"]
-
 
         self.expert_map_per_layer = dict(
         )  # reference to expert map on device for expert map update
@@ -79,7 +78,7 @@ def __init__(self, model, mtp_instance, num_mtp_layers, **args):
             for mtp_layer_idx in range(self.num_mtp_layers):
                 self.expert_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
                     self.mtp_instance.model.get_expert_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)
-
+
         # TODO: here we set the number of buffer tensors equal to the number of experts in each layer, which can be improved
         num_buffer_tensor = torch.where(
             self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
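Note on the num_buffer_tensor context lines above: each per-layer expert map appears to store, for every global expert id, its local slot on this rank, or -1 when the expert is hosted elsewhere. A minimal sketch of the counting idiom, with a made-up map:

import torch

# Hypothetical expert map: 8 global experts, 4 of them hosted on this rank.
expert_map = torch.tensor([-1, 0, -1, 1, 2, -1, -1, 3])

# torch.where(expert_map != -1)[0] lists the global ids present locally;
# .numel() counts them, which is used to size the per-layer buffer tensors.
num_buffer_tensor = torch.where(expert_map != -1)[0].numel()
print(num_buffer_tensor)  # 4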
@@ -95,7 +94,7 @@ def __init__(self, model, mtp_instance, num_mtp_layers, **args):
         for layer_idx in range(self.num_moe_layers):
             self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
                 self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
-
+
         if self.mtp_instance is not None:
             for mtp_layer_idx in range(self.num_mtp_layers):
                 self.log2phy_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
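The log2phy maps gathered above translate logical expert ids to physical replica ids, so a layer with redundant expert copies can be addressed uniformly. A minimal sketch, with made-up values:

import torch

# Hypothetical map for 4 logical experts; logical expert 2 is routed to a
# redundant physical copy with id 5.
log2phy_map = torch.tensor([0, 1, 5, 3])

physical_expert_id = log2phy_map[2].item()
print(physical_expert_id)  # 5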
@@ -127,7 +126,7 @@ def init_expert_param_per_layer(self):
                     name].data[local_expert_id]
                 for name in self.expert_weight_names
             ])
-
+
         if self.mtp_instance is not None:
             mtp_param_dict = dict(self.mtp_instance.named_parameters())
             self.expert_param_per_layer[self.num_dense_layers +
@@ -153,8 +152,7 @@ def get_rank_expert_workload(self) -> torch.Tensor:
                 self.moe_load,
                 self.mtp_instance.model.get_all_moe_loads().to(
                     device=self.moe_load.device)
-            ],
-                dim=0)
+            ], dim=0)
         return self.moe_load
 
     def get_init_expert_map(self, num_moe_layers):
@@ -164,8 +162,7 @@ def get_init_expert_map(self, num_moe_layers):
                 expert_map,
                 self.mtp_instance.model.get_all_expert_map().to(
                     device=expert_map.device)
-            ],
-                dim=0)
+            ], dim=0)
         if dist.is_initialized():
             world_size = dist.get_world_size()
 
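Both simplified torch.cat calls in this commit follow the same pattern: the MTP layers' per-layer tensors are appended below the main model's along dim 0, after moving them to the same device. A minimal sketch with made-up shapes (3 MoE layers plus 1 MTP layer, 8 experts per layer):

import torch

moe_load = torch.zeros(3, 8)  # main-model per-layer expert loads
mtp_load = torch.ones(1, 8).to(device=moe_load.device)  # MTP-layer loads
combined = torch.cat([moe_load, mtp_load], dim=0)
print(combined.shape)  # torch.Size([4, 8])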