
Commit daeb92b

feat(runner): support mixcoder training (#78)
* fix(runner): support mixcoder training
* fix(runner): support mixcoder training

1 parent 6af18e3 · commit daeb92b

File tree: 2 files changed (+26, -4 lines)

src/lm_saes/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -11,6 +11,7 @@
     FeatureAnalyzerConfig,
     InitializerConfig,
     LanguageModelConfig,
+    MixCoderConfig,
     MongoDBConfig,
     SAEConfig,
     TrainerConfig,
@@ -54,4 +55,5 @@
     "FeatureAnalyzerConfig",
     "MongoDBConfig",
     "MongoClient",
+    "MixCoderConfig",
 ]
```
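With the export in place, the new config class can be pulled from the package root just like the existing ones. A minimal sketch, assuming only what the diff shows (that `MixCoderConfig`, `SAEConfig`, and `TrainerConfig` are all exposed at top level):

```python
# All three config classes are re-exported from lm_saes/__init__.py.
from lm_saes import MixCoderConfig, SAEConfig, TrainerConfig
```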

src/lm_saes/runner.py

Lines changed: 24 additions & 4 deletions
```diff
@@ -6,6 +6,7 @@
 from pydantic import model_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from torch.distributed.device_mesh import init_device_mesh
+from transformers import AutoTokenizer
 
 from lm_saes.activation.factory import ActivationFactory
 from lm_saes.activation.writer import ActivationWriter
@@ -256,6 +257,9 @@ class TrainSAESettings(BaseSettings):
     mongo: Optional[MongoDBConfig] = None
     """Configuration for MongoDB"""
 
+    model_name: Optional[str] = None
+    """Name of the tokenizer to load. Mixcoder requires a tokenizer to get the modality indices."""
+
 
 def train_sae(settings: TrainSAESettings) -> None:
     """Train a SAE model.
@@ -276,9 +280,24 @@ def train_sae(settings: TrainSAESettings) -> None:
     activation_factory = ActivationFactory(settings.activation_factory)
     activations_stream = activation_factory.process()
     initializer = Initializer(settings.initializer)
-    sae = initializer.initialize_sae_from_config(
-        settings.sae, activation_stream=activations_stream, device_mesh=device_mesh
-    )
+
+    if settings.sae.sae_type == "mixcoder":
+        assert settings.model_name is not None, "Model name is required for mixcoder SAE"
+        tokenizer = AutoTokenizer.from_pretrained(settings.model_name, trust_remote_code=True)
+        mixcoder_settings = {
+            "model_name": settings.model_name,
+            "tokenizer": tokenizer,
+        }
+        sae = initializer.initialize_sae_from_config(
+            settings.sae,
+            activation_stream=activations_stream,
+            device_mesh=device_mesh,
+            mixcoder_settings=mixcoder_settings,
+        )
+    else:
+        sae = initializer.initialize_sae_from_config(
+            settings.sae, activation_stream=activations_stream, device_mesh=device_mesh
+        )
 
     wandb_logger = (
         wandb.init(
@@ -289,7 +308,8 @@ def train_sae(settings: TrainSAESettings) -> None:
             settings=wandb.Settings(x_disable_stats=True),
             mode=os.getenv("WANDB_MODE", "online"),
         )
-        if settings.wandb is not None and (device_mesh is None or device_mesh.get_rank() == 0) else None
+        if settings.wandb is not None and (device_mesh is None or device_mesh.get_rank() == 0)
+        else None
     )
     if wandb_logger is not None:
         wandb_logger.watch(sae, log="all")
```
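The substance of the change is the new branch in `train_sae`: when `settings.sae.sae_type == "mixcoder"`, the runner requires `model_name`, loads the matching tokenizer, and forwards both to `initializer.initialize_sae_from_config` via `mixcoder_settings`, since MixCoder needs the tokenizer to derive modality indices. A standalone sketch of that step, with a placeholder model id (the commit does not prescribe any particular checkpoint):

```python
from transformers import AutoTokenizer

# Mirrors the new mixcoder branch in train_sae: load the tokenizer named by
# settings.model_name and bundle it for the initializer. The repo id below is
# hypothetical; any Hugging Face checkpoint with a tokenizer would do.
model_name = "gpt2"  # placeholder, not taken from the commit
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

mixcoder_settings = {
    "model_name": model_name,
    "tokenizer": tokenizer,
}
# mixcoder_settings is then passed as a keyword argument:
# initializer.initialize_sae_from_config(..., mixcoder_settings=mixcoder_settings)
```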
