@@ -116,7 +116,14 @@ def __init__(self, max_size: int | None = None):
     def __call__(
         self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
     ) -> None:
-        """Accumulate activation statistics from the forward pass."""
+        """Accumulate activation statistics from the forward pass.
+
+        Args:
+            module: The module this hook is registered on.
+            args: Tuple of input tensors. args[0] expected shape: [seq_len, batch_size, hidden_size]
+                (Megatron sequence-first format).
+            output: Output tensor from the module's forward pass.
+        """
         # Gather input [seq_len, batch_size, hidden_size] over all TP regions
         # NOTE: This is not used at the moment since we restrict to TP=1
         input_tensor = gather_from_tensor_model_parallel_region(args[0]).detach()
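For context, a minimal sketch of how such a statistics hook behaves, using plain PyTorch only. The class name `ActivationStatsHook` and the per-channel mean statistic are illustrative assumptions, not the code in this diff; the Megatron `gather_from_tensor_model_parallel_region` call is omitted here since the NOTE above says it is effectively unused while TP=1.

```python
import torch
import torch.nn as nn

class ActivationStatsHook:
    """Illustrative only: accumulate a per-channel statistic from hook inputs."""

    def __init__(self):
        self.sum = None
        self.count = 0

    def __call__(self, module, args, output):
        # args[0] is assumed to be in Megatron sequence-first layout:
        # [seq_len, batch_size, hidden_size]
        x = args[0].detach()
        stats = x.abs().mean(dim=(0, 1))  # one value per hidden channel
        self.sum = stats if self.sum is None else self.sum + stats
        self.count += 1

hook = ActivationStatsHook()
layer = nn.Linear(16, 16)
handle = layer.register_forward_hook(hook)
layer(torch.randn(4, 2, 16))  # dummy input: [seq_len=4, batch_size=2, hidden_size=16]
print(hook.sum / hook.count)  # running per-channel mean activation magnitude
handle.remove()
```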
@@ -241,14 +248,16 @@ def __init__(
         self.epsilon = 1e-8

     def __call__(
-        self, module: nn.Module, args: tuple[torch.Tensor], output: torch.Tensor | tuple
+        self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor | tuple
     ) -> None:
         """Compute channel contributions and prune channels according to schedule.

         Args:
             module: The module this hook is registered on.
-            args: Tuple with input tensor of shape (B, T, I).
-            output: Output tensor of shape (B, T, E), or tuple (output_tensor, bias) for parallel layers.
+            args: Tuple with single input tensor. args[0] expected shape: [batch_size, seq_len, input_channels]
+                (PyTorch batch-first format).
+            output: Output tensor of shape [batch_size, seq_len, output_channels], or tuple (output_tensor, bias)
+                for parallel layers.
         """
         # Handle case where output is a tuple (e.g., from ColumnParallelLinear/RowParallelLinear)
         # TODO: Consider better design to handle RowParallelLinear and nn.Linear
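A small sketch of the output-unpacking pattern this comment refers to, again in plain PyTorch. The names `_unpack_output` and `ChannelContributionHook`, and the mean-absolute-activation ranking, are assumptions for illustration; the diff's actual hook prunes channels on a schedule, which is not reproduced here.

```python
import torch
import torch.nn as nn

def _unpack_output(output):
    # Megatron-style ColumnParallelLinear/RowParallelLinear may return (output_tensor, bias);
    # a plain nn.Linear returns the tensor directly.
    return output[0] if isinstance(output, tuple) else output

class ChannelContributionHook:
    """Illustrative only: rank output channels by mean absolute activation."""

    def __call__(self, module, args, output):
        out = _unpack_output(output).detach()
        # Batch-first layout assumed: [batch_size, seq_len, output_channels]
        contribution = out.abs().mean(dim=(0, 1))
        self.ranking = torch.argsort(contribution, descending=True)

hook = ChannelContributionHook()
layer = nn.Linear(8, 32)
handle = layer.register_forward_hook(hook)
layer(torch.randn(2, 5, 8))  # [batch_size=2, seq_len=5, input_channels=8]
print(hook.ranking[:4])      # indices of the most active output channels
handle.remove()
```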