
Commit dedc036

Fix broken unit tests - initialize model weights on CPU.
Signed-off-by: Daniel Korzekwa <[email protected]>
1 parent 9526a0d commit dedc036
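
The fix initializes model weights on the CPU so that the initial state does not depend on GPU-side RNG behavior; the test diff below then only removes the ad-hoc seeding and debug code and refreshes the expected reference values. As a minimal sketch of the underlying idea, assuming plain PyTorch (the module, sizes, and seed are illustrative and not taken from this repository), weights created under the CPU RNG and only afterwards moved to the GPU come out identical on every run:

    import torch
    import torch.nn as nn

    def build_model_with_cpu_init(seed: int = 1234) -> nn.Module:
        # Seed only the CPU generator; the weights are materialized on the CPU.
        torch.manual_seed(seed)
        model = nn.Linear(256, 256)
        # Move the already-initialized weights to the GPU when one is available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        return model.to(device)

    # Two independent builds yield bit-identical initial weights.
    model_a = build_model_with_cpu_init()
    model_b = build_model_with_cpu_init()
    assert torch.equal(model_a.weight.cpu(), model_b.weight.cpu())

In the test below, the CPU-side initialization presumably happens in the model-building helpers rather than in this file, which is why only expected values change here.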

File tree

1 file changed: +24 -75 lines changed


tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Lines changed: 24 additions & 75 deletions
@@ -93,59 +93,8 @@ def _get_model(initialize_megatron=True):
 
     model = _get_model()
 
-    # Set seeds for deterministic dummy input generation AFTER model initialization
-    # (get_mcore_gpt_model calls initialize_for_megatron which sets seed=1234)
-    torch.manual_seed(1234)
-    torch.cuda.manual_seed_all(1234)
-    # Enable deterministic behavior for cuDNN
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-
     sd = model.state_dict()
 
-    # Debug: Print some model weights to verify deterministic initialization
-    # if rank == 0:
-    #     weight_keys = list(sd.keys())[:10]  # First 10 weight keys
-    #     print("\n=== Model Weight Debug (first 10 keys) ===")
-    #     for key in weight_keys:
-    #         weight = sd[key]
-    #         if isinstance(weight, torch.Tensor) and weight.numel() > 0:
-    #             # Skip non-floating point tensors (e.g., Byte, Int)
-    #             if weight.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.float64]:
-    #                 mean = weight.mean().item()
-    #                 std = weight.std().item()
-    #                 min_val = weight.min().item()
-    #                 max_val = weight.max().item()
-    #                 print(
-    #                     f"{key}: shape={weight.shape}, "
-    #                     f"mean={mean:.10f}, std={std:.10f}, min={min_val:.10f}, max={max_val:.10f}"
-    #                 )
-    #             else:
-    #                 first_vals = weight.flatten()[:5].tolist()
-    #                 print(f"{key}: shape={weight.shape}, dtype={weight.dtype}")
-    #                 print(f" (non-float, first 5 values: {first_vals})")
-    #     print("=" * 50 + "\n")
-
-    # Debug: Check if reinitializing produces same weights
-    if rank == 0:
-        print("\n=== Checking Weight Initialization Determinism ===")
-        # Save current linear_qkv weight
-        qkv_key = "decoder.layers.0.self_attention.linear_qkv.weight"
-        proj_key = "decoder.layers.0.self_attention.linear_proj.weight"
-
-        if qkv_key in sd and proj_key in sd:
-            qkv_weight = sd[qkv_key].clone()
-            proj_weight = sd[proj_key].clone()
-            print(f"{qkv_key}:")
-            print(f" shape={qkv_weight.shape}, mean={qkv_weight.mean().item():.10f}")
-            print(f" device={qkv_weight.device}, dtype={qkv_weight.dtype}")
-            print(f" is_contiguous={qkv_weight.is_contiguous()}")
-            print(f"{proj_key}:")
-            print(f" shape={proj_weight.shape}, mean={proj_weight.mean().item():.10f}")
-            print(f" device={proj_weight.device}, dtype={proj_weight.dtype}")
-            print(f" is_contiguous={proj_weight.is_contiguous()}")
-        print("=" * 50 + "\n")
-
     def forward_loop(m):
         for _ in range(5):
             run_mcore_inference_with_dummy_input(m, batch_size, hidden_size)
@@ -196,42 +145,42 @@ def forward_loop(m):
     # Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4)
     if pruned_ffn_div == 4:
         # Layer scores
-        assert pruning_scores["layer_scores"][1] == pytest.approx(2.1437832713127136, abs=1e-5)
-        assert pruning_scores["layer_scores"][2] == pytest.approx(1.792158305644989, abs=1e-5)
+        assert pruning_scores["layer_scores"][1] == pytest.approx(2.0868452191352844, abs=1e-5)
+        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7638601660728455, abs=1e-5)
 
         # Validate decoder.layers.0.mlp activations
         mlp_0_acts = rank_0_activations["decoder.layers.0.mlp"]
-        assert mlp_0_acts.min().item() == pytest.approx(0.0011843212, abs=1e-5)
-        assert mlp_0_acts.max().item() == pytest.approx(1.0846971273, abs=1e-5)
-        assert mlp_0_acts.mean().item() == pytest.approx(0.0535472594, abs=1e-5)
+        assert mlp_0_acts.min().item() == pytest.approx(0.0015609927941114, abs=1e-5)
+        assert mlp_0_acts.max().item() == pytest.approx(0.3844809532165527, abs=1e-5)
+        assert mlp_0_acts.mean().item() == pytest.approx(0.0629318505525589, abs=1e-5)
 
         # Validate decoder.layers.1.mlp activations
         mlp_1_acts = rank_0_activations["decoder.layers.1.mlp"]
-        assert mlp_1_acts.min().item() == pytest.approx(0.0002450741, abs=1e-5)
-        assert mlp_1_acts.max().item() == pytest.approx(1.1014972925, abs=1e-5)
-        assert mlp_1_acts.mean().item() == pytest.approx(0.0904172808, abs=1e-5)
+        assert mlp_1_acts.min().item() == pytest.approx(0.0001484956446802, abs=1e-5)
+        assert mlp_1_acts.max().item() == pytest.approx(0.7835369110107422, abs=1e-5)
+        assert mlp_1_acts.mean().item() == pytest.approx(0.0926810950040817, abs=1e-5)
 
     # Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2)
     elif pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1:
         # Layer scores
-        assert pruning_scores["layer_scores"][1] == pytest.approx(2.1119985580444336, abs=1e-5)
-        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7729830741882324, abs=1e-5)
+        assert pruning_scores["layer_scores"][1] == pytest.approx(2.1415508985519409, abs=1e-5)
+        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7198008894920349, abs=1e-5)
 
         # Validate decoder.layers.0.self_attention activations
         assert "decoder.layers.0.self_attention" in rank_0_activations
         attn_0_acts = rank_0_activations["decoder.layers.0.self_attention"]
         assert attn_0_acts.shape == torch.Size([256])
-        assert attn_0_acts.min().item() == pytest.approx(0.03729403391480446, abs=1e-5)
-        assert attn_0_acts.max().item() == pytest.approx(0.3653244972229004, abs=1e-5)
-        assert attn_0_acts.mean().item() == pytest.approx(0.15008458495140076, abs=1e-5)
+        assert attn_0_acts.min().item() == pytest.approx(0.0409194342792034, abs=1e-5)
+        assert attn_0_acts.max().item() == pytest.approx(0.5261313319206238, abs=1e-5)
+        assert attn_0_acts.mean().item() == pytest.approx(0.1613342612981796, abs=1e-5)
 
         # Validate decoder.layers.1.self_attention activations
         assert "decoder.layers.1.self_attention" in rank_0_activations
         attn_1_acts = rank_0_activations["decoder.layers.1.self_attention"]
         assert attn_1_acts.shape == torch.Size([256])
-        assert attn_1_acts.min().item() == pytest.approx(0.140824556350708, abs=1e-5)
-        assert attn_1_acts.max().item() == pytest.approx(1.0845409631729126, abs=1e-5)
-        assert attn_1_acts.mean().item() == pytest.approx(0.4730667173862457, abs=1e-5)
+        assert attn_1_acts.min().item() == pytest.approx(0.1189328655600548, abs=1e-5)
+        assert attn_1_acts.max().item() == pytest.approx(1.3832759857177734, abs=1e-5)
+        assert attn_1_acts.mean().item() == pytest.approx(0.4782669544219971, abs=1e-5)
 
     # Assert weights are pruned correctly
     for layer in model.decoder.layers:
@@ -295,14 +244,14 @@ def forward_loop(m):
     [
         # MHA - pruned ffn/4
         (8, 8, "squared_relu", "LayerNorm", 4, 1, 1, 1, 1, False, "rope", False, False),
-        # # GQA - pruned attention/2
-        # (8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False, "rope", False, False),
-        # # GQA - pruned hidden_size/4
-        # (8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False, "rope", True, False),
-        # # MHA - pruned num_layers/2
-        # (8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False, "rope", False, False),
-        # # GQA - pruned all/2, uneven pp
-        # (8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True, "yarn", False, True),
+        # GQA - pruned attention/2
+        (8, 4, "squared_relu", "RMSNorm", 1, 2, 2, 1, 1, False, "rope", False, False),
+        # GQA - pruned hidden_size/4
+        (8, 4, "swiglu", "RMSNorm", 1, 1, 1, 4, 1, False, "rope", True, False),
+        # MHA - pruned num_layers/2
+        (8, 8, "swiglu", "LayerNorm", 1, 1, 1, 1, 2, False, "rope", False, False),
+        # GQA - pruned all/2, uneven pp
+        (8, 4, "swiglu", "RMSNorm", 2, 2, 2, 2, 2, True, "yarn", False, True),
     ],
 )
 def test_mcore_gpt_pruning(
