@@ -91,8 +91,37 @@ def _get_model(initialize_megatron=True):
9191 return model
9292
9393 model = _get_model ()
94+
95+ # Set seeds for deterministic dummy input generation AFTER model initialization
96+ # (get_mcore_gpt_model calls initialize_for_megatron which sets seed=1234)
97+ torch .manual_seed (1234 )
98+ torch .cuda .manual_seed_all (1234 )
99+
94100 sd = model .state_dict ()
95101
102+ # Debug: Print some model weights to verify deterministic initialization
103+ if rank == 0 :
104+ weight_keys = list (sd .keys ())[:10 ] # First 10 weight keys
105+ print ("\n === Model Weight Debug (first 10 keys) ===" )
106+ for key in weight_keys :
107+ weight = sd [key ]
108+ if isinstance (weight , torch .Tensor ) and weight .numel () > 0 :
109+ # Skip non-floating point tensors (e.g., Byte, Int)
110+ if weight .dtype in [torch .float32 , torch .float16 , torch .bfloat16 , torch .float64 ]:
111+ mean = weight .mean ().item ()
112+ std = weight .std ().item ()
113+ min_val = weight .min ().item ()
114+ max_val = weight .max ().item ()
115+ print (
116+ f"{ key } : shape={ weight .shape } , "
117+ f"mean={ mean :.10f} , std={ std :.10f} , min={ min_val :.10f} , max={ max_val :.10f} "
118+ )
119+ else :
120+ first_vals = weight .flatten ()[:5 ].tolist ()
121+ print (f"{ key } : shape={ weight .shape } , dtype={ weight .dtype } " )
122+ print (f" (non-float, first 5 values: { first_vals } )" )
123+ print ("=" * 50 + "\n " )
124+
96125 def forward_loop (m ):
97126 for _ in range (5 ):
98127 run_mcore_inference_with_dummy_input (m , batch_size , hidden_size )
@@ -242,14 +271,14 @@ def forward_loop(m):
242271 [
243272 # MHA - pruned ffn/4
244273 (8 , 8 , "squared_relu" , "LayerNorm" , 4 , 1 , 1 , 1 , 1 , False , "rope" , False , False ),
245- # GQA - pruned attention/2
246- (8 , 4 , "squared_relu" , "RMSNorm" , 1 , 2 , 2 , 1 , 1 , False , "rope" , False , False ),
247- # GQA - pruned hidden_size/4
248- (8 , 4 , "swiglu" , "RMSNorm" , 1 , 1 , 1 , 4 , 1 , False , "rope" , True , False ),
249- # MHA - pruned num_layers/2
250- (8 , 8 , "swiglu" , "LayerNorm" , 1 , 1 , 1 , 1 , 2 , False , "rope" , False , False ),
251- # GQA - pruned all/2, uneven pp
252- (8 , 4 , "swiglu" , "RMSNorm" , 2 , 2 , 2 , 2 , 2 , True , "yarn" , False , True ),
274+ # # GQA - pruned attention/2
275+ # (8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False, "rope", False, False),
276+ # # GQA - pruned hidden_size/4
277+ # (8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False, "rope", True, False),
278+ # # MHA - pruned num_layers/2
279+ # (8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False, "rope", False, False),
280+ # # GQA - pruned all/2, uneven pp
281+ # (8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True, "yarn", False, True),
253282 ],
254283)
255284def test_mcore_gpt_pruning (
0 commit comments