@@ -166,18 +166,18 @@ def forward_loop(m):
         print("\n=== TEST CASE 1 ===")
         print(f"layer_scores[1] = {pruning_scores['layer_scores'][1]:.16f}")
         print(f"layer_scores[2] = {pruning_scores['layer_scores'][2]:.16f}")
-        assert pruning_scores["layer_scores"][1] == pytest.approx(2.0868452191352844, abs=1e-5)
-        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7638601660728455, abs=1e-5)
+        assert pruning_scores["layer_scores"][1] == pytest.approx(2.0868452191352844, abs=1e-3)
+        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7638601660728455, abs=1e-3)

         # Validate decoder.layers.0.mlp activations
         mlp_0_acts = rank_0_activations["decoder.layers.0.mlp"]
         if rank == 0:
             print(f"mlp_0_acts.min() = {mlp_0_acts.min().item():.16f}")
             print(f"mlp_0_acts.max() = {mlp_0_acts.max().item():.16f}")
             print(f"mlp_0_acts.mean() = {mlp_0_acts.mean().item():.16f}")
-        assert mlp_0_acts.min().item() == pytest.approx(0.0015609927941114, abs=1e-5)
-        assert mlp_0_acts.max().item() == pytest.approx(0.3844809532165527, abs=1e-5)
-        assert mlp_0_acts.mean().item() == pytest.approx(0.0629318505525589, abs=1e-5)
+        assert mlp_0_acts.min().item() == pytest.approx(0.0015609927941114, abs=1e-3)
+        assert mlp_0_acts.max().item() == pytest.approx(0.3844809532165527, abs=1e-3)
+        assert mlp_0_acts.mean().item() == pytest.approx(0.0629318505525589, abs=1e-3)

         # Validate decoder.layers.1.mlp activations
         mlp_1_acts = rank_0_activations["decoder.layers.1.mlp"]
@@ -186,31 +186,31 @@ def forward_loop(m):
             print(f"mlp_1_acts.max() = {mlp_1_acts.max().item():.16f}")
             print(f"mlp_1_acts.mean() = {mlp_1_acts.mean().item():.16f}")
             print("=" * 50 + "\n")
-        assert mlp_1_acts.min().item() == pytest.approx(0.0001484956446802, abs=1e-5)
-        assert mlp_1_acts.max().item() == pytest.approx(0.7835369110107422, abs=1e-5)
-        assert mlp_1_acts.mean().item() == pytest.approx(0.0926810950040817, abs=1e-5)
+        assert mlp_1_acts.min().item() == pytest.approx(0.0001484956446802, abs=1e-3)
+        assert mlp_1_acts.max().item() == pytest.approx(0.7835369110107422, abs=1e-3)
+        assert mlp_1_acts.mean().item() == pytest.approx(0.0926810950040817, abs=1e-3)

     # Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2)
     elif pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1:
         # Layer scores
-        assert pruning_scores["layer_scores"][1] == pytest.approx(2.1415508985519409, abs=1e-5)
-        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7198008894920349, abs=1e-5)
+        assert pruning_scores["layer_scores"][1] == pytest.approx(2.1415508985519409, abs=1e-3)
+        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7198008894920349, abs=1e-3)

         # Validate decoder.layers.0.self_attention activations
         assert "decoder.layers.0.self_attention" in rank_0_activations
         attn_0_acts = rank_0_activations["decoder.layers.0.self_attention"]
         assert attn_0_acts.shape == torch.Size([256])
-        assert attn_0_acts.min().item() == pytest.approx(0.0409194342792034, abs=1e-5)
-        assert attn_0_acts.max().item() == pytest.approx(0.5261313319206238, abs=1e-5)
-        assert attn_0_acts.mean().item() == pytest.approx(0.1613342612981796, abs=1e-5)
+        assert attn_0_acts.min().item() == pytest.approx(0.0409194342792034, abs=1e-3)
+        assert attn_0_acts.max().item() == pytest.approx(0.5261313319206238, abs=1e-3)
+        assert attn_0_acts.mean().item() == pytest.approx(0.1613342612981796, abs=1e-3)

         # Validate decoder.layers.1.self_attention activations
         assert "decoder.layers.1.self_attention" in rank_0_activations
         attn_1_acts = rank_0_activations["decoder.layers.1.self_attention"]
         assert attn_1_acts.shape == torch.Size([256])
-        assert attn_1_acts.min().item() == pytest.approx(0.1189328655600548, abs=1e-5)
-        assert attn_1_acts.max().item() == pytest.approx(1.3832759857177734, abs=1e-5)
-        assert attn_1_acts.mean().item() == pytest.approx(0.4782669544219971, abs=1e-5)
+        assert attn_1_acts.min().item() == pytest.approx(0.1189328655600548, abs=1e-3)
+        assert attn_1_acts.max().item() == pytest.approx(1.3832759857177734, abs=1e-3)
+        assert attn_1_acts.mean().item() == pytest.approx(0.4782669544219971, abs=1e-3)

     # Assert weights are pruned correctly
     for layer in model.decoder.layers:
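
For reference, a minimal sketch (not part of the diff) of the pytest.approx semantics the loosened tolerances rely on: with abs=1e-3, a comparison passes whenever the actual value lies within 1e-3 of the expected value, so small run-to-run or device-to-device numeric drift below that margin no longer fails these assertions. The drift values below are illustrative only.

# Illustrative sketch: absolute tolerance behavior of pytest.approx.
import pytest

expected = 2.0868452191352844  # one of the expected layer scores above
assert expected + 5e-4 == pytest.approx(expected, abs=1e-3)         # drift of 5e-4: within tolerance
assert not (expected + 5e-3 == pytest.approx(expected, abs=1e-3))   # drift of 5e-3: outside tolerance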