@@ -274,9 +274,6 @@ def _setup(self):
274274
275275 def _estimate_importance (self ) -> TracedHp .Importance :
276276 """Return the activation magnitude-based importance of the ffn_hidden_size."""
277- assert self ._activation_hook ._activations is not None , (
278- "No activations collected for importance estimation."
279- )
280277 return self ._activation_hook .accumulate ()
281278
282279 def set_hidden_size_hp (self , hidden_size : TracedHp ) -> None :
@@ -607,9 +604,6 @@ def _setup(self):
607604
608605 def _estimate_all_head_importance (self ) -> TracedHp .Importance :
609606 """Return the importance for num_attention_heads (num_heads_per_group * num_query_groups)."""
610- assert self ._activation_hook ._activations is not None , (
611- "No activations collected for importance estimation."
612- )
613607 # Convert squared sum to L2 norm
614608 scores = self ._activation_hook .accumulate ()
615609 attn_head_importance = torch .linalg .vector_norm (
@@ -625,9 +619,6 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance:
625619
626620 def _estimate_query_group_importance (self ) -> TracedHp .Importance :
627621 """Return the importance of the ``num_query_groups`` hparam."""
628- assert self ._activation_hook ._activations is not None , (
629- "No activations collected for importance estimation."
630- )
631622 # Convert squared sum to L2 norm
632623 scores = self ._activation_hook .accumulate ()
633624 group_importance = torch .linalg .vector_norm (
0 commit comments