@@ -103,26 +103,46 @@ def _get_model(initialize_megatron=True):
     sd = model.state_dict()

     # Debug: Print some model weights to verify deterministic initialization
+    # if rank == 0:
+    #     weight_keys = list(sd.keys())[:10]  # First 10 weight keys
+    #     print("\n=== Model Weight Debug (first 10 keys) ===")
+    #     for key in weight_keys:
+    #         weight = sd[key]
+    #         if isinstance(weight, torch.Tensor) and weight.numel() > 0:
+    #             # Skip non-floating point tensors (e.g., Byte, Int)
+    #             if weight.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.float64]:
+    #                 mean = weight.mean().item()
+    #                 std = weight.std().item()
+    #                 min_val = weight.min().item()
+    #                 max_val = weight.max().item()
+    #                 print(
+    #                     f"{key}: shape={weight.shape}, "
+    #                     f"mean={mean:.10f}, std={std:.10f}, min={min_val:.10f}, max={max_val:.10f}"
+    #                 )
+    #             else:
+    #                 first_vals = weight.flatten()[:5].tolist()
+    #                 print(f"{key}: shape={weight.shape}, dtype={weight.dtype}")
+    #                 print(f"  (non-float, first 5 values: {first_vals})")
+    #     print("=" * 50 + "\n")
+
+    # Debug: Check if reinitializing produces same weights
     if rank == 0:
-        weight_keys = list(sd.keys())[:10]  # First 10 weight keys
-        print("\n=== Model Weight Debug (first 10 keys) ===")
-        for key in weight_keys:
-            weight = sd[key]
-            if isinstance(weight, torch.Tensor) and weight.numel() > 0:
-                # Skip non-floating point tensors (e.g., Byte, Int)
-                if weight.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.float64]:
-                    mean = weight.mean().item()
-                    std = weight.std().item()
-                    min_val = weight.min().item()
-                    max_val = weight.max().item()
-                    print(
-                        f"{key}: shape={weight.shape}, "
-                        f"mean={mean:.10f}, std={std:.10f}, min={min_val:.10f}, max={max_val:.10f}"
-                    )
-                else:
-                    first_vals = weight.flatten()[:5].tolist()
-                    print(f"{key}: shape={weight.shape}, dtype={weight.dtype}")
-                    print(f"  (non-float, first 5 values: {first_vals})")
+        print("\n=== Checking Weight Initialization Determinism ===")
+        # Save current linear_qkv weight
+        qkv_key = "decoder.layers.0.self_attention.linear_qkv.weight"
+        proj_key = "decoder.layers.0.self_attention.linear_proj.weight"
+
+        if qkv_key in sd and proj_key in sd:
+            qkv_weight = sd[qkv_key].clone()
+            proj_weight = sd[proj_key].clone()
+            print(f"{qkv_key}:")
+            print(f"  shape={qkv_weight.shape}, mean={qkv_weight.mean().item():.10f}")
+            print(f"  device={qkv_weight.device}, dtype={qkv_weight.dtype}")
+            print(f"  is_contiguous={qkv_weight.is_contiguous()}")
+            print(f"{proj_key}:")
+            print(f"  shape={proj_weight.shape}, mean={proj_weight.mean().item():.10f}")
+            print(f"  device={proj_weight.device}, dtype={proj_weight.dtype}")
+            print(f"  is_contiguous={proj_weight.is_contiguous()}")
         print("=" * 50 + "\n")

     def forward_loop(m):
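Note that the new debug block only prints statistics for a single initialization; to actually confirm that "reinitializing produces same weights", the model has to be built twice and the two state dicts diffed. Below is a minimal standalone sketch of that check. `check_init_determinism` and `build_model` are hypothetical stand-ins for the seeded setup done by `_get_model()` here, and plain `torch.manual_seed` is a simplification: under Megatron, the tensor-parallel RNG state would also need reseeding between builds (e.g. via `model_parallel_cuda_manual_seed`).

import torch


def check_init_determinism(build_model, seed: int = 1234) -> list[str]:
    """Build the model twice with the same seed and report weight mismatches.

    `build_model` is a hypothetical zero-argument factory that performs the
    full seeded initialization. In a Megatron setup, model-parallel RNG state
    would also need resetting here, not just the global torch seed.
    """
    torch.manual_seed(seed)
    sd_a = {k: v.detach().clone() for k, v in build_model().state_dict().items()}

    torch.manual_seed(seed)
    sd_b = build_model().state_dict()

    mismatched = []
    for key, a in sd_a.items():
        b = sd_b[key]
        if a.is_floating_point():
            # atol=rtol=0.0 demands a bitwise match, not approximate equality
            same = torch.allclose(a, b, rtol=0.0, atol=0.0)
        else:
            same = torch.equal(a, b)
        if not same:
            mismatched.append(key)
            if a.is_floating_point():
                print(f"MISMATCH {key}: max |a - b| = {(a - b).abs().max().item():.3e}")
            else:
                print(f"MISMATCH {key} (dtype={a.dtype})")
    print(f"{len(mismatched)} / {len(sd_a)} tensors differ across re-initialization")
    return mismatched

The exact (atol=0) comparison is deliberate: a deterministic initializer should reproduce weights bit-for-bit, so any nonzero difference points at unseeded RNG state or a non-deterministic kernel rather than harmless numerical noise.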