Commit 8e761ba

Update doc strings
Signed-off-by: Daniel Korzekwa <[email protected]>
Parent: ded5944

1 file changed: +13 −4 lines

modelopt/torch/nas/plugins/megatron_hooks.py

@@ -116,7 +116,14 @@ def __init__(self, max_size: int | None = None):
     def __call__(
         self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
     ) -> None:
-        """Accumulate activation statistics from the forward pass."""
+        """Accumulate activation statistics from the forward pass.
+
+        Args:
+            module: The module this hook is registered on.
+            args: Tuple of input tensors. args[0] expected shape: [seq_len, batch_size, hidden_size]
+                (Megatron sequence-first format).
+            output: Output tensor from the module's forward pass.
+        """
         # Gather input [seq_len, batch_size, hidden_size] over all TP regions
         # NOTE: This is not used at the moment since we restrict to TP=1
         input_tensor = gather_from_tensor_model_parallel_region(args[0]).detach()
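
For context, here is a minimal sketch (not code from this repository) of a forward hook with the documented signature, attached via PyTorch's standard register_forward_hook. The class name ActivationStatsHook and the sum-of-squares statistic are illustrative placeholders, and the sketch assumes a single GPU, so no tensor-parallel gather is performed:

import torch
import torch.nn as nn

class ActivationStatsHook:
    """Illustrative hook: accumulate per-channel sum-of-squares statistics."""

    def __init__(self):
        self.sum_sq = None

    def __call__(
        self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
    ) -> None:
        # args[0] assumed shape: [seq_len, batch_size, hidden_size] (sequence-first).
        x = args[0].detach()
        # Reduce over sequence and batch dims, leaving a [hidden_size] vector.
        stats = x.pow(2).sum(dim=(0, 1))
        self.sum_sq = stats if self.sum_sq is None else self.sum_sq + stats

layer = nn.Linear(8, 8)
hook = ActivationStatsHook()
handle = layer.register_forward_hook(hook)
layer(torch.randn(4, 2, 8))  # [seq_len, batch_size, hidden_size]
handle.remove()
print(hook.sum_sq.shape)  # torch.Size([8])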
@@ -241,14 +248,16 @@ def __init__(
         self.epsilon = 1e-8
 
     def __call__(
-        self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor | tuple
+        self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor | tuple
     ) -> None:
         """Compute channel contributions and prune channels according to schedule.
 
         Args:
             module: The module this hook is registered on.
-            args: Tuple with input tensor of shape (B, T, I).
-            output: Output tensor of shape (B, T, E), or tuple (output_tensor, bias) for parallel layers.
+            args: Tuple with single input tensor. args[0] expected shape: [batch_size, seq_len, input_channels]
+                (PyTorch batch-first format).
+            output: Output tensor of shape [batch_size, seq_len, output_channels], or tuple (output_tensor, bias)
+                for parallel layers.
         """
         # Handle case where output is a tuple (e.g., from ColumnParallelLinear/RowParallelLinear)
         # TODO: Consider better design to handle RowParallelLinear and nn.Linear
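
The tuple handling mentioned in that comment could look roughly like the sketch below: Megatron's ColumnParallelLinear/RowParallelLinear return (output_tensor, bias), while plain nn.Linear returns a bare tensor, so the hook normalizes both cases before scoring channels. The function name pruning_hook and the mean-absolute-activation contribution score are illustrative placeholders, not the repository's actual metric:

import torch
import torch.nn as nn

def pruning_hook(module: nn.Module, args: tuple[torch.Tensor, ...], output) -> None:
    # Parallel layers return (output_tensor, bias); nn.Linear returns a tensor.
    out = output[0] if isinstance(output, tuple) else output
    # out assumed batch-first: [batch_size, seq_len, output_channels].
    # Placeholder contribution score: mean |activation| per output channel.
    contributions = out.detach().abs().mean(dim=(0, 1))
    module._channel_contributions = contributions  # illustrative storage only

layer = nn.Linear(16, 32)
handle = layer.register_forward_hook(pruning_hook)
layer(torch.randn(2, 5, 16))  # [batch_size, seq_len, input_channels]
handle.remove()
print(layer._channel_contributions.shape)  # torch.Size([32])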
