We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f5baed2 · commit 447a15d — Copy full SHA for 447a15d
modelopt/torch/nas/plugins/megatron.py
@@ -594,7 +594,6 @@ def _setup(self):
594
max_size = num_heads_per_group_max * num_query_groups_max * self.config.kv_channels
595
activation_hook = MegatronL2NormHook(max_size=max_size)
596
self._register_temp_attribute("_activation_hook", activation_hook)
597
- # TODO: confusion: why hook_handle is removed manually in export() and not using _register_temp_attribute?
598
self.hook_handle = self.linear_proj.register_forward_hook(activation_hook)
599
# NOTE: num_heads_per_group's slice_order will be of length num_attention_heads to be able to sort heads,
600
# otherwise we would only have aggregated importance of heads per group.
0 commit comments