Skip to content

Commit 65c956a

Browse files
committed
fix(trainer): correct token count calculation for 2D activation in LORSA training
1 parent 9518957 commit 65c956a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/lm_saes/trainer.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -213,7 +213,8 @@ def _initialize_trainer(
213 | 213 |   | activation_stream: Iterable[dict[str, Tensor]],
214 | 214 |   | wandb_logger: Run | None = None,
215 | 215 |   | ):
216 |     | - | bs = batch_size(next(iter(activation_stream)))
    | 216 | + | batch = next(iter(activation_stream))
    | 217 | + | bs = batch["tokens"].numel() if batch.get("mask") is None else int(item(batch["mask"].sum()))
217 | 218 |   | self.total_training_steps = self.cfg.total_training_tokens // bs
218 | 219 |   |
219 | 220 |   | def calculate_warmup_steps(warmup_steps: float | int) -> int:

0 commit comments

Comments
 (0)