We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9518957 commit 65c956aCopy full SHA for 65c956a
src/lm_saes/trainer.py
@@ -213,7 +213,8 @@ def _initialize_trainer(
213
activation_stream: Iterable[dict[str, Tensor]],
214
wandb_logger: Run | None = None,
215
):
216
- bs = batch_size(next(iter(activation_stream)))
+ batch = next(iter(activation_stream))
217
+ bs = batch["tokens"].numel() if batch.get("mask") is None else int(item(batch["mask"].sum()))
218
self.total_training_steps = self.cfg.total_training_tokens // bs
219
220
def calculate_warmup_steps(warmup_steps: float | int) -> int:
0 commit comments