Commit 322af3b

feat: support resuming wandb run from training checkpoint
- Add wandb_run_id and wandb_resume config options
- Save wandb run id when saving checkpoint
- Load trainer from checkpoint when from_pretrained_path is set
1 parent a0e7838 · commit 322af3b

4 files changed: +35 −14

pyproject.toml

Lines changed: 2 additions & 0 deletions

@@ -89,6 +89,8 @@ dev = [
     "qwen-vl-utils>=0.0.10",
     "tabulate>=0.9.0",
     "gradio>=5.34.0",
+    "sqlalchemy>=2.0.44",
+    "apscheduler>=3.11.1",
 ]
 docs = [
     "mkdocs-gen-files>=0.5.0",

src/lm_saes/config.py

Lines changed: 2 additions & 0 deletions

@@ -782,6 +782,8 @@ class WandbConfig(BaseConfig):
     wandb_project: str = "gpt2-sae-training"
     exp_name: str | None = None
     wandb_entity: str | None = None
+    wandb_run_id: str | None = None
+    wandb_resume: Literal["allow", "must", "never", "auto"] = "never"
 
 
 class MongoDBConfig(BaseConfig):
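
The two new fields map directly onto wandb.init's resume and id arguments (wired up in the runner below). A minimal sketch of a resume-oriented config, assuming WandbConfig is a pydantic-style model that accepts its fields as keyword arguments:

from lm_saes.config import WandbConfig

# "must" makes wandb.init fail if the given run id does not already exist,
# guarding against silently starting a fresh run; the id itself is hypothetical.
wandb_cfg = WandbConfig(
    wandb_project="gpt2-sae-training",
    wandb_run_id="abc123xy",
    wandb_resume="must",
)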

src/lm_saes/runners/train.py

Lines changed: 26 additions & 10 deletions

@@ -168,14 +168,23 @@ def train_sae(settings: TrainSAESettings) -> None:
             entity=settings.wandb.wandb_entity,
             settings=wandb.Settings(x_disable_stats=True),
             mode=os.getenv("WANDB_MODE", "online"),  # type: ignore
+            resume=settings.wandb.wandb_resume,
+            id=settings.wandb.wandb_run_id,
         )
         if settings.wandb is not None and (device_mesh is None or mesh_rank(device_mesh) == 0)
         else None
     )
-
     sae = initializer.initialize_sae_from_config(
         settings.sae, activation_stream=activations_stream, device_mesh=device_mesh, wandb_logger=wandb_logger
     )
+    if settings.trainer.from_pretrained_path is not None:
+        trainer = Trainer.from_checkpoint(
+            sae,
+            settings.trainer.from_pretrained_path,
+        )
+        trainer.wandb_logger = wandb_logger
+    else:
+        trainer = Trainer(settings.trainer)
 
     logger.info(f"SAE initialized: {type(sae).__name__}")
@@ -186,17 +195,24 @@ def train_sae(settings: TrainSAESettings) -> None:
     eval_fn = (lambda x: None) if settings.eval else None
 
     logger.info("Starting training")
-    trainer = Trainer(settings.trainer)
-    sae.cfg.save_hyperparameters(settings.trainer.exp_result_path)
-    trainer.fit(sae=sae, activation_stream=activations_stream, eval_fn=eval_fn, wandb_logger=wandb_logger)
 
-    logger.info("Training completed, saving model")
-    sae.save_pretrained(
-        save_path=settings.trainer.exp_result_path,
-        sae_name=settings.sae_name,
-        sae_series=settings.sae_series,
-        mongo_client=mongo_client,
+    sae.cfg.save_hyperparameters(settings.trainer.exp_result_path)
+    end_of_stream = trainer.fit(
+        sae=sae, activation_stream=activations_stream, eval_fn=eval_fn, wandb_logger=wandb_logger
     )
+    logger.info("Training completed, saving model")
+    if end_of_stream:
+        trainer.save_checkpoint(
+            sae=sae,
+            checkpoint_path=settings.trainer.exp_result_path,
+        )
+    else:
+        sae.save_pretrained(
+            save_path=settings.trainer.exp_result_path,
+            sae_name=settings.sae_name,
+            sae_series=settings.sae_series,
+            mongo_client=mongo_client,
+        )
 
     if wandb_logger is not None:
         wandb_logger.finish()
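
The runner change has two moving parts: the trainer is now constructed before training starts, restored via Trainer.from_checkpoint when from_pretrained_path is set, and trainer.fit now returns an end_of_stream flag that decides between writing a resumable checkpoint and exporting the final model. A sketch of a resumed run; every name below except from_pretrained_path, exp_result_path, wandb_run_id, and wandb_resume is an assumption about the surrounding settings classes:

# Hypothetical resume invocation; TrainSAESettings / TrainerConfig constructors
# and the paths are assumed, not taken from this commit.
settings = TrainSAESettings(
    trainer=TrainerConfig(
        from_pretrained_path="results/sae-l6/ckpt",  # hypothetical checkpoint dir
        exp_result_path="results/sae-l6",
    ),
    wandb=WandbConfig(wandb_run_id="abc123xy", wandb_resume="must"),
)
train_sae(settings)  # restores trainer state, then wandb.init resumes run abc123xy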

src/lm_saes/trainer.py

Lines changed: 5 additions & 4 deletions

@@ -1,3 +1,4 @@
+import json
 import math
 import os
 from pathlib import Path
@@ -88,11 +89,12 @@ def save_checkpoint(self, sae: AbstractSparseAutoEncoder, checkpoint_path: Path
             "checkpoint_thresholds": self.checkpoint_thresholds,
             "cfg": self.cfg,
         }
-
         # Save trainer state
         trainer_path = checkpoint_dir / "trainer.pt"
         torch.save(trainer_state, trainer_path)
-
+        if self.wandb_logger is not None:
+            with open(checkpoint_dir / "wandb_run_id.json", "w") as f:
+                json.dump({"wandb_run_id": self.wandb_logger.id}, f)
         # Save optimizer state - handle distributed tensors
         if self.optimizer is not None:
             if sae.device_mesh is None:
@@ -479,14 +481,13 @@ def fit(
                     with timer.time("evaluation"):
                         eval_fn(sae)
 
-                self._maybe_save_sae_checkpoint(sae)
                 with timer.time("scheduler_step"):
                     self.scheduler.step()
-
                 self.cur_step += 1
                 self.cur_tokens += (
                     batch["tokens"].numel() if batch.get("mask") is None else int(item(batch["mask"].sum()))
                 )
+                self._maybe_save_sae_checkpoint(sae)
                 if self.cur_tokens >= self.cfg.total_training_tokens:
                     break
             except StopIteration:
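
save_checkpoint now writes wandb_run_id.json alongside trainer.pt, but nothing in this commit reads the file back; feeding the saved id into WandbConfig is left to whoever launches the resumed run. One way to close that loop, as a sketch (load_wandb_run_id and the checkpoint path are hypothetical, not part of the codebase):

import json
from pathlib import Path

def load_wandb_run_id(checkpoint_dir: Path) -> str | None:
    # Older checkpoints predate this feature, so tolerate a missing file.
    path = checkpoint_dir / "wandb_run_id.json"
    if not path.exists():
        return None
    with open(path) as f:
        return json.load(f)["wandb_run_id"]

run_id = load_wandb_run_id(Path("results/sae-l6/ckpt"))  # hypothetical path
# e.g. settings.wandb.wandb_run_id = run_id; settings.wandb.wandb_resume = "allow"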
