
Commit 675dca4

debugging
Signed-off-by: Daniel Korzekwa <[email protected]>
1 parent 889eb4b commit 675dca4

File tree: 1 file changed

tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Lines changed: 39 additions & 19 deletions
@@ -103,26 +103,46 @@ def _get_model(initialize_megatron=True):
         sd = model.state_dict()
 
         # Debug: Print some model weights to verify deterministic initialization
+        # if rank == 0:
+        #     weight_keys = list(sd.keys())[:10]  # First 10 weight keys
+        #     print("\n=== Model Weight Debug (first 10 keys) ===")
+        #     for key in weight_keys:
+        #         weight = sd[key]
+        #         if isinstance(weight, torch.Tensor) and weight.numel() > 0:
+        #             # Skip non-floating point tensors (e.g., Byte, Int)
+        #             if weight.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.float64]:
+        #                 mean = weight.mean().item()
+        #                 std = weight.std().item()
+        #                 min_val = weight.min().item()
+        #                 max_val = weight.max().item()
+        #                 print(
+        #                     f"{key}: shape={weight.shape}, "
+        #                     f"mean={mean:.10f}, std={std:.10f}, min={min_val:.10f}, max={max_val:.10f}"
+        #                 )
+        #             else:
+        #                 first_vals = weight.flatten()[:5].tolist()
+        #                 print(f"{key}: shape={weight.shape}, dtype={weight.dtype}")
+        #                 print(f"  (non-float, first 5 values: {first_vals})")
+        #     print("=" * 50 + "\n")
+
+        # Debug: Check if reinitializing produces same weights
         if rank == 0:
-            weight_keys = list(sd.keys())[:10]  # First 10 weight keys
-            print("\n=== Model Weight Debug (first 10 keys) ===")
-            for key in weight_keys:
-                weight = sd[key]
-                if isinstance(weight, torch.Tensor) and weight.numel() > 0:
-                    # Skip non-floating point tensors (e.g., Byte, Int)
-                    if weight.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.float64]:
-                        mean = weight.mean().item()
-                        std = weight.std().item()
-                        min_val = weight.min().item()
-                        max_val = weight.max().item()
-                        print(
-                            f"{key}: shape={weight.shape}, "
-                            f"mean={mean:.10f}, std={std:.10f}, min={min_val:.10f}, max={max_val:.10f}"
-                        )
-                    else:
-                        first_vals = weight.flatten()[:5].tolist()
-                        print(f"{key}: shape={weight.shape}, dtype={weight.dtype}")
-                        print(f"  (non-float, first 5 values: {first_vals})")
+            print("\n=== Checking Weight Initialization Determinism ===")
+            # Save current linear_qkv weight
+            qkv_key = "decoder.layers.0.self_attention.linear_qkv.weight"
+            proj_key = "decoder.layers.0.self_attention.linear_proj.weight"
+
+            if qkv_key in sd and proj_key in sd:
+                qkv_weight = sd[qkv_key].clone()
+                proj_weight = sd[proj_key].clone()
+                print(f"{qkv_key}:")
+                print(f"  shape={qkv_weight.shape}, mean={qkv_weight.mean().item():.10f}")
+                print(f"  device={qkv_weight.device}, dtype={qkv_weight.dtype}")
+                print(f"  is_contiguous={qkv_weight.is_contiguous()}")
+                print(f"{proj_key}:")
+                print(f"  shape={proj_weight.shape}, mean={proj_weight.mean().item():.10f}")
+                print(f"  device={proj_weight.device}, dtype={proj_weight.dtype}")
+                print(f"  is_contiguous={proj_weight.is_contiguous()}")
             print("=" * 50 + "\n")
 
     def forward_loop(m):
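
Note: the new debug block prints statistics for layer 0's linear_qkv and linear_proj weights, but it does not yet perform the comparison its comment describes ("check if reinitializing produces same weights"). A minimal sketch of how that check could be completed is below; `check_reinit_determinism` is a hypothetical helper, not part of this commit, and it assumes the model factory (e.g. the test's `_get_model`, called with `initialize_megatron=False` after the first build) returns the constructed model.

import torch

def check_reinit_determinism(get_model):
    # get_model: zero-argument factory that builds and returns a fresh model,
    # e.g. lambda: _get_model(initialize_megatron=False) in this test harness.
    sd1 = get_model().state_dict()
    sd2 = get_model().state_dict()
    mismatches = 0
    for key, w1 in sd1.items():
        w2 = sd2[key]
        if isinstance(w1, torch.Tensor) and w1.is_floating_point():
            # Bitwise equality is the strictest determinism check; relax to
            # torch.allclose if nondeterministic kernels are expected.
            if not torch.equal(w1, w2):
                mismatches += 1
                diff = (w1 - w2).abs().max().item()
                print(f"MISMATCH {key}: max abs diff={diff:.10e}")
    print(f"{mismatches} mismatching tensors out of {len(sd1)}")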
