Skip to content

Commit b5de20a

Browse files
debugging
Signed-off-by: Daniel Korzekwa <[email protected]>
1 parent 5bdb08b commit b5de20a

File tree

1 file changed

+37
-8
lines changed

1 file changed

+37
-8
lines changed

tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Lines changed: 37 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -91,8 +91,37 @@ def _get_model(initialize_megatron=True):
9191
return model
9292

9393
model = _get_model()
94+
95+
# Set seeds for deterministic dummy input generation AFTER model initialization
96+
# (get_mcore_gpt_model calls initialize_for_megatron which sets seed=1234)
97+
torch.manual_seed(1234)
98+
torch.cuda.manual_seed_all(1234)
99+
94100
sd = model.state_dict()
95101

102+
# Debug: Print some model weights to verify deterministic initialization
103+
if rank == 0:
104+
weight_keys = list(sd.keys())[:10] # First 10 weight keys
105+
print("\n=== Model Weight Debug (first 10 keys) ===")
106+
for key in weight_keys:
107+
weight = sd[key]
108+
if isinstance(weight, torch.Tensor) and weight.numel() > 0:
109+
# Skip non-floating point tensors (e.g., Byte, Int)
110+
if weight.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.float64]:
111+
mean = weight.mean().item()
112+
std = weight.std().item()
113+
min_val = weight.min().item()
114+
max_val = weight.max().item()
115+
print(
116+
f"{key}: shape={weight.shape}, "
117+
f"mean={mean:.10f}, std={std:.10f}, min={min_val:.10f}, max={max_val:.10f}"
118+
)
119+
else:
120+
first_vals = weight.flatten()[:5].tolist()
121+
print(f"{key}: shape={weight.shape}, dtype={weight.dtype}")
122+
print(f" (non-float, first 5 values: {first_vals})")
123+
print("=" * 50 + "\n")
124+
96125
def forward_loop(m):
97126
for _ in range(5):
98127
run_mcore_inference_with_dummy_input(m, batch_size, hidden_size)
@@ -242,14 +271,14 @@ def forward_loop(m):
242271
[
243272
# MHA - pruned ffn/4
244273
(8, 8, "squared_relu", "LayerNorm", 4, 1, 1, 1, 1, False, "rope", False, False),
245-
# GQA - pruned attention/2
246-
(8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False, "rope", False, False),
247-
# GQA - pruned hidden_size/4
248-
(8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False, "rope", True, False),
249-
# MHA - pruned num_layers/2
250-
(8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False, "rope", False, False),
251-
# GQA - pruned all/2, uneven pp
252-
(8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True, "yarn", False, True),
274+
# # GQA - pruned attention/2
275+
# (8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False, "rope", False, False),
276+
# # GQA - pruned hidden_size/4
277+
# (8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False, "rope", True, False),
278+
# # MHA - pruned num_layers/2
279+
# (8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False, "rope", False, False),
280+
# # GQA - pruned all/2, uneven pp
281+
# (8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True, "yarn", False, True),
253282
],
254283
)
255284
def test_mcore_gpt_pruning(

0 commit comments

Comments (0)