Commit f5b85bf

debugging

Signed-off-by: Daniel Korzekwa <[email protected]>

Parent: dedc036


tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Lines changed: 38 additions & 8 deletions
@@ -93,8 +93,25 @@ def _get_model(initialize_megatron=True):
 
     model = _get_model()
 
+    # Set seeds for deterministic dummy input generation AFTER model initialization
+    # (get_mcore_gpt_model calls initialize_for_megatron which sets seed=1234)
+    torch.manual_seed(1234)
+    torch.cuda.manual_seed_all(1234)
+    # Enable deterministic behavior for cuDNN
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
     sd = model.state_dict()
 
+    # Debug: Check weight initialization
+    if rank == 0:
+        print("\n=== Weight Initialization Check ===")
+        qkv_key = "decoder.layers.0.self_attention.linear_qkv.weight"
+        if qkv_key in sd:
+            qkv_weight = sd[qkv_key]
+            print(f"{qkv_key}: mean={qkv_weight.mean().item():.16f}")
+        print("=" * 50 + "\n")
+
     def forward_loop(m):
         for _ in range(5):
             run_mcore_inference_with_dummy_input(m, batch_size, hidden_size)
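
Note on the hunk above: the seeding block follows the standard PyTorch determinism recipe. A minimal standalone sketch, assuming plain PyTorch; the helper name seed_everything is hypothetical and not part of the test file:

    import torch

    def seed_everything(seed: int = 1234) -> None:
        # Seed the CPU RNG and the RNGs of all visible CUDA devices so that
        # dummy-input generation is reproducible across runs.
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning for reproducible kernel selection.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

Re-seeding after model construction matters here because initialize_for_megatron already consumed RNG state with seed=1234; resetting puts the dummy-input stream back at a known point.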
@@ -145,17 +162,30 @@ def forward_loop(m):
     # Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4)
     if pruned_ffn_div == 4:
         # Layer scores
+        if rank == 0:
+            print("\n=== TEST CASE 1 ===")
+            print(f"layer_scores[1] = {pruning_scores['layer_scores'][1]:.16f}")
+            print(f"layer_scores[2] = {pruning_scores['layer_scores'][2]:.16f}")
         assert pruning_scores["layer_scores"][1] == pytest.approx(2.0868452191352844, abs=1e-5)
         assert pruning_scores["layer_scores"][2] == pytest.approx(1.7638601660728455, abs=1e-5)
 
         # Validate decoder.layers.0.mlp activations
         mlp_0_acts = rank_0_activations["decoder.layers.0.mlp"]
+        if rank == 0:
+            print(f"mlp_0_acts.min() = {mlp_0_acts.min().item():.16f}")
+            print(f"mlp_0_acts.max() = {mlp_0_acts.max().item():.16f}")
+            print(f"mlp_0_acts.mean() = {mlp_0_acts.mean().item():.16f}")
         assert mlp_0_acts.min().item() == pytest.approx(0.0015609927941114, abs=1e-5)
         assert mlp_0_acts.max().item() == pytest.approx(0.3844809532165527, abs=1e-5)
         assert mlp_0_acts.mean().item() == pytest.approx(0.0629318505525589, abs=1e-5)
 
         # Validate decoder.layers.1.mlp activations
         mlp_1_acts = rank_0_activations["decoder.layers.1.mlp"]
+        if rank == 0:
+            print(f"mlp_1_acts.min() = {mlp_1_acts.min().item():.16f}")
+            print(f"mlp_1_acts.max() = {mlp_1_acts.max().item():.16f}")
+            print(f"mlp_1_acts.mean() = {mlp_1_acts.mean().item():.16f}")
+            print("=" * 50 + "\n")
         assert mlp_1_acts.min().item() == pytest.approx(0.0001484956446802, abs=1e-5)
         assert mlp_1_acts.max().item() == pytest.approx(0.7835369110107422, abs=1e-5)
         assert mlp_1_acts.mean().item() == pytest.approx(0.0926810950040817, abs=1e-5)
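
Note on the hunk above: the assertions compare floats with an absolute tolerance. A hedged illustration of the pytest.approx pattern used here, with made-up values rather than ones taken from the test:

    import pytest

    def test_abs_tolerance_example():
        measured = 0.0629318
        # abs=1e-5 accepts any value within +/-1e-5 of the expected constant,
        # regardless of magnitude (unlike pytest.approx's default relative
        # tolerance, which scales with the expected value).
        assert measured == pytest.approx(0.0629319, abs=1e-5)

The added prints use :.16f so that, when a tolerance check fails, the full double-precision value can be copied back into the assertion.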
@@ -244,14 +274,14 @@ def forward_loop(m):
     [
         # MHA - pruned ffn/4
         (8, 8, "squared_relu", "LayerNorm", 4, 1, 1, 1, 1, False, "rope", False, False),
-        # GQA - pruned attention/2
-        (8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False, "rope", False, False),
-        # GQA - pruned hidden_size/4
-        (8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False, "rope", True, False),
-        # MHA - pruned num_layers/2
-        (8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False, "rope", False, False),
-        # GQA - pruned all/2, uneven pp
-        (8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True, "yarn", False, True),
+        # # GQA - pruned attention/2
+        # (8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False, "rope", False, False),
+        # # GQA - pruned hidden_size/4
+        # (8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False, "rope", True, False),
+        # # MHA - pruned num_layers/2
+        # (8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False, "rope", False, False),
+        # # GQA - pruned all/2, uneven pp
+        # (8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True, "yarn", False, True),
     ],
 )
 def test_mcore_gpt_pruning(
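
Note on the hunk above: it temporarily disables four of the five parametrized cases, leaving only the MHA ffn/4 case active. A sketch of the pattern with abbreviated, hypothetical parameter names; re-enabling a case is just uncommenting its tuple:

    import pytest

    @pytest.mark.parametrize(
        ("num_heads", "num_query_groups", "ffn_div"),
        [
            (8, 8, 4),  # MHA - pruned ffn/4: the only case left enabled
            # (8, 4, 1),  # GQA case, temporarily disabled while debugging
        ],
    )
    def test_pruning_case(num_heads, num_query_groups, ffn_div):
        # Each enabled tuple becomes one collected test invocation.
        assert num_heads % num_query_groups == 0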
