@@ -93,59 +93,8 @@ def _get_model(initialize_megatron=True):
 
     model = _get_model()
 
-    # Set seeds for deterministic dummy input generation AFTER model initialization
-    # (get_mcore_gpt_model calls initialize_for_megatron which sets seed=1234)
-    torch.manual_seed(1234)
-    torch.cuda.manual_seed_all(1234)
-    # Enable deterministic behavior for cuDNN
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-
     sd = model.state_dict()
 
-    # Debug: Print some model weights to verify deterministic initialization
-    # if rank == 0:
-    #     weight_keys = list(sd.keys())[:10]  # First 10 weight keys
-    #     print("\n=== Model Weight Debug (first 10 keys) ===")
-    #     for key in weight_keys:
-    #         weight = sd[key]
-    #         if isinstance(weight, torch.Tensor) and weight.numel() > 0:
-    #             # Skip non-floating point tensors (e.g., Byte, Int)
-    #             if weight.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.float64]:
-    #                 mean = weight.mean().item()
-    #                 std = weight.std().item()
-    #                 min_val = weight.min().item()
-    #                 max_val = weight.max().item()
-    #                 print(
-    #                     f"{key}: shape={weight.shape}, "
-    #                     f"mean={mean:.10f}, std={std:.10f}, min={min_val:.10f}, max={max_val:.10f}"
-    #                 )
-    #             else:
-    #                 first_vals = weight.flatten()[:5].tolist()
-    #                 print(f"{key}: shape={weight.shape}, dtype={weight.dtype}")
-    #                 print(f"    (non-float, first 5 values: {first_vals})")
-    #     print("=" * 50 + "\n")
-
-    # Debug: Check if reinitializing produces same weights
-    if rank == 0:
-        print("\n=== Checking Weight Initialization Determinism ===")
-        # Save current linear_qkv weight
-        qkv_key = "decoder.layers.0.self_attention.linear_qkv.weight"
-        proj_key = "decoder.layers.0.self_attention.linear_proj.weight"
-
-        if qkv_key in sd and proj_key in sd:
-            qkv_weight = sd[qkv_key].clone()
-            proj_weight = sd[proj_key].clone()
-            print(f"{qkv_key}:")
-            print(f"  shape={qkv_weight.shape}, mean={qkv_weight.mean().item():.10f}")
-            print(f"  device={qkv_weight.device}, dtype={qkv_weight.dtype}")
-            print(f"  is_contiguous={qkv_weight.is_contiguous()}")
-            print(f"{proj_key}:")
-            print(f"  shape={proj_weight.shape}, mean={proj_weight.mean().item():.10f}")
-            print(f"  device={proj_weight.device}, dtype={proj_weight.dtype}")
-            print(f"  is_contiguous={proj_weight.is_contiguous()}")
-            print("=" * 50 + "\n")
-
     def forward_loop(m):
         for _ in range(5):
             run_mcore_inference_with_dummy_input(m, batch_size, hidden_size)
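
The lines removed above are the standard PyTorch determinism recipe; per the deleted comment, get_mcore_gpt_model already seeds via initialize_for_megatron with seed=1234, so the extra seeding was presumably redundant. For reference, a minimal standalone sketch of that recipe (the make_deterministic helper name is ours, not from this repo):

import torch

def make_deterministic(seed: int = 1234) -> None:
    # Seed the CPU generator and every visible CUDA device so that
    # randomly generated dummy inputs are reproducible across runs.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Pin cuDNN to deterministic kernels and disable autotuning,
    # which can otherwise pick different algorithms run-to-run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
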
@@ -196,42 +145,42 @@ def forward_loop(m):
     # Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4)
     if pruned_ffn_div == 4:
         # Layer scores
-        assert pruning_scores["layer_scores"][1] == pytest.approx(2.1437832713127136, abs=1e-5)
-        assert pruning_scores["layer_scores"][2] == pytest.approx(1.792158305644989, abs=1e-5)
+        assert pruning_scores["layer_scores"][1] == pytest.approx(2.0868452191352844, abs=1e-5)
+        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7638601660728455, abs=1e-5)
 
         # Validate decoder.layers.0.mlp activations
         mlp_0_acts = rank_0_activations["decoder.layers.0.mlp"]
-        assert mlp_0_acts.min().item() == pytest.approx(0.0011843212, abs=1e-5)
-        assert mlp_0_acts.max().item() == pytest.approx(1.0846971273, abs=1e-5)
-        assert mlp_0_acts.mean().item() == pytest.approx(0.0535472594, abs=1e-5)
+        assert mlp_0_acts.min().item() == pytest.approx(0.0015609927941114, abs=1e-5)
+        assert mlp_0_acts.max().item() == pytest.approx(0.3844809532165527, abs=1e-5)
+        assert mlp_0_acts.mean().item() == pytest.approx(0.0629318505525589, abs=1e-5)
 
         # Validate decoder.layers.1.mlp activations
         mlp_1_acts = rank_0_activations["decoder.layers.1.mlp"]
-        assert mlp_1_acts.min().item() == pytest.approx(0.0002450741, abs=1e-5)
-        assert mlp_1_acts.max().item() == pytest.approx(1.1014972925, abs=1e-5)
-        assert mlp_1_acts.mean().item() == pytest.approx(0.0904172808, abs=1e-5)
+        assert mlp_1_acts.min().item() == pytest.approx(0.0001484956446802, abs=1e-5)
+        assert mlp_1_acts.max().item() == pytest.approx(0.7835369110107422, abs=1e-5)
+        assert mlp_1_acts.mean().item() == pytest.approx(0.0926810950040817, abs=1e-5)
 
     # Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2)
     elif pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1:
         # Layer scores
-        assert pruning_scores["layer_scores"][1] == pytest.approx(2.1119985580444336, abs=1e-5)
-        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7729830741882324, abs=1e-5)
+        assert pruning_scores["layer_scores"][1] == pytest.approx(2.1415508985519409, abs=1e-5)
+        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7198008894920349, abs=1e-5)
 
         # Validate decoder.layers.0.self_attention activations
         assert "decoder.layers.0.self_attention" in rank_0_activations
         attn_0_acts = rank_0_activations["decoder.layers.0.self_attention"]
         assert attn_0_acts.shape == torch.Size([256])
-        assert attn_0_acts.min().item() == pytest.approx(0.03729403391480446, abs=1e-5)
-        assert attn_0_acts.max().item() == pytest.approx(0.3653244972229004, abs=1e-5)
-        assert attn_0_acts.mean().item() == pytest.approx(0.15008458495140076, abs=1e-5)
+        assert attn_0_acts.min().item() == pytest.approx(0.0409194342792034, abs=1e-5)
+        assert attn_0_acts.max().item() == pytest.approx(0.5261313319206238, abs=1e-5)
+        assert attn_0_acts.mean().item() == pytest.approx(0.1613342612981796, abs=1e-5)
 
         # Validate decoder.layers.1.self_attention activations
         assert "decoder.layers.1.self_attention" in rank_0_activations
         attn_1_acts = rank_0_activations["decoder.layers.1.self_attention"]
         assert attn_1_acts.shape == torch.Size([256])
-        assert attn_1_acts.min().item() == pytest.approx(0.140824556350708, abs=1e-5)
-        assert attn_1_acts.max().item() == pytest.approx(1.0845409631729126, abs=1e-5)
-        assert attn_1_acts.mean().item() == pytest.approx(0.4730667173862457, abs=1e-5)
+        assert attn_1_acts.min().item() == pytest.approx(0.1189328655600548, abs=1e-5)
+        assert attn_1_acts.max().item() == pytest.approx(1.3832759857177734, abs=1e-5)
+        assert attn_1_acts.mean().item() == pytest.approx(0.4782669544219971, abs=1e-5)
 
     # Assert weights are pruned correctly
     for layer in model.decoder.layers:
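
The expected values above are per-channel activation summaries recorded on rank 0 for each module (shape [256], plausibly the model's hidden size). As an illustrative sketch only: such per-module statistics are typically collected with forward hooks, as below. The hook body and the reduction (mean of absolute activations over all non-channel dims) are our assumptions, not necessarily what the pruning calibration actually computes:

import torch

rank_0_activations: dict[str, torch.Tensor] = {}

def make_hook(name: str):
    # Forward hook that records mean |activation| per output channel.
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        non_channel_dims = tuple(range(out.dim() - 1))
        rank_0_activations[name] = out.detach().abs().mean(dim=non_channel_dims)
    return hook

# `model` is the Megatron-Core GPT model built earlier in this test.
for name, module in model.named_modules():
    if name.endswith((".mlp", ".self_attention")):
        module.register_forward_hook(make_hook(name))
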
@@ -295,14 +244,14 @@ def forward_loop(m):
     [
         # MHA - pruned ffn/4
         (8, 8, "squared_relu", "LayerNorm", 4, 1, 1, 1, 1, False, "rope", False, False),
-        # # GQA - pruned attention/2
-        # (8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False, "rope", False, False),
-        # # GQA - pruned hidden_size/4
-        # (8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False, "rope", True, False),
-        # # MHA - pruned num_layers/2
-        # (8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False, "rope", False, False),
-        # # GQA - pruned all/2, uneven pp
-        # (8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True, "yarn", False, True),
+        # GQA - pruned attention/2
+        (8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False, "rope", False, False),
+        # GQA - pruned hidden_size/4
+        (8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False, "rope", True, False),
+        # MHA - pruned num_layers/2
+        (8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False, "rope", False, False),
+        # GQA - pruned all/2, uneven pp
+        (8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True, "yarn", False, True),
     ],
 )
 def test_mcore_gpt_pruning(
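
In each re-enabled tuple, the divisor fields (ffn, attention heads, query groups, hidden size, num layers, per the comments) shrink the corresponding dimension by integer division, so a divisor of 1 leaves it unpruned. A toy check with a hypothetical width of 512 (ours, not the test's):

def pruned_width(original: int, div: int) -> int:
    # A divisor of 1 means the dimension is left intact.
    return original // div

assert pruned_width(512, 4) == 128  # e.g. ffn_div=4 keeps a quarter of the FFN channels
assert pruned_width(512, 1) == 512
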