@@ -93,8 +93,25 @@ def _get_model(initialize_megatron=True):
 
     model = _get_model()
 
+    # Set seeds for deterministic dummy input generation AFTER model initialization
+    # (get_mcore_gpt_model calls initialize_for_megatron which sets seed=1234)
+    torch.manual_seed(1234)
+    torch.cuda.manual_seed_all(1234)
+    # Enable deterministic behavior for cuDNN
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
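+    # Note: the cuDNN flags above only affect cuDNN-backed kernels. A stricter (and
+    # slower) alternative would be torch.use_deterministic_algorithms(True); it is
+    # not enabled here since only the dummy-input RNG needs to be pinned.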
+
     sd = model.state_dict()
 
+    # Debug: Check weight initialization
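+    # (the mean of the first layer's QKV weight serves as a cheap fingerprint of the
+    # post-init RNG state; printing at 16 decimals lets runs be compared exactly)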
+    if rank == 0:
+        print("\n=== Weight Initialization Check ===")
+        qkv_key = "decoder.layers.0.self_attention.linear_qkv.weight"
+        if qkv_key in sd:
+            qkv_weight = sd[qkv_key]
+            print(f"{qkv_key}: mean={qkv_weight.mean().item():.16f}")
+        print("=" * 50 + "\n")
+
     def forward_loop(m):
         for _ in range(5):
             run_mcore_inference_with_dummy_input(m, batch_size, hidden_size)
@@ -145,17 +162,30 @@ def forward_loop(m):
     # Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4)
     if pruned_ffn_div == 4:
         # Layer scores
+        if rank == 0:
+            print("\n=== TEST CASE 1 ===")
+            print(f"layer_scores[1] = {pruning_scores['layer_scores'][1]:.16f}")
+            print(f"layer_scores[2] = {pruning_scores['layer_scores'][2]:.16f}")
         assert pruning_scores["layer_scores"][1] == pytest.approx(2.0868452191352844, abs=1e-5)
         assert pruning_scores["layer_scores"][2] == pytest.approx(1.7638601660728455, abs=1e-5)
 
         # Validate decoder.layers.0.mlp activations
         mlp_0_acts = rank_0_activations["decoder.layers.0.mlp"]
+        if rank == 0:
+            print(f"mlp_0_acts.min() = {mlp_0_acts.min().item():.16f}")
+            print(f"mlp_0_acts.max() = {mlp_0_acts.max().item():.16f}")
+            print(f"mlp_0_acts.mean() = {mlp_0_acts.mean().item():.16f}")
         assert mlp_0_acts.min().item() == pytest.approx(0.0015609927941114, abs=1e-5)
         assert mlp_0_acts.max().item() == pytest.approx(0.3844809532165527, abs=1e-5)
         assert mlp_0_acts.mean().item() == pytest.approx(0.0629318505525589, abs=1e-5)
 
         # Validate decoder.layers.1.mlp activations
         mlp_1_acts = rank_0_activations["decoder.layers.1.mlp"]
+        if rank == 0:
+            print(f"mlp_1_acts.min() = {mlp_1_acts.min().item():.16f}")
+            print(f"mlp_1_acts.max() = {mlp_1_acts.max().item():.16f}")
+            print(f"mlp_1_acts.mean() = {mlp_1_acts.mean().item():.16f}")
+            print("=" * 50 + "\n")
         assert mlp_1_acts.min().item() == pytest.approx(0.0001484956446802, abs=1e-5)
         assert mlp_1_acts.max().item() == pytest.approx(0.7835369110107422, abs=1e-5)
         assert mlp_1_acts.mean().item() == pytest.approx(0.0926810950040817, abs=1e-5)
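+        # The golden values above are tied to seed 1234 and the dummy-input shapes;
+        # abs=1e-5 presumably leaves headroom for cross-device floating-point noise.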
@@ -244,14 +274,14 @@ def forward_loop(m):
     [
         # MHA - pruned ffn/4
         (8, 8, "squared_relu", "LayerNorm", 4, 1, 1, 1, 1, False, "rope", False, False),
-        # GQA - pruned attention/2
-        (8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False, "rope", False, False),
-        # GQA - pruned hidden_size/4
-        (8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False, "rope", True, False),
-        # MHA - pruned num_layers/2
-        (8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False, "rope", False, False),
-        # GQA - pruned all/2, uneven pp
-        (8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True, "yarn", False, True),
+        # # GQA - pruned attention/2
+        # (8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False, "rope", False, False),
+        # # GQA - pruned hidden_size/4
+        # (8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False, "rope", True, False),
+        # # MHA - pruned num_layers/2
+        # (8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False, "rope", False, False),
+        # # GQA - pruned all/2, uneven pp
+        # (8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True, "yarn", False, True),
     ],
 )
 def test_mcore_gpt_pruning(