66from pydantic import model_validator
77from pydantic_settings import BaseSettings , SettingsConfigDict
88from torch .distributed .device_mesh import init_device_mesh
9+ from transformers import AutoTokenizer
910
1011from lm_saes .activation .factory import ActivationFactory
1112from lm_saes .activation .writer import ActivationWriter
@@ -256,6 +257,9 @@ class TrainSAESettings(BaseSettings):
256257 mongo : Optional [MongoDBConfig ] = None
257258 """Configuration for MongoDB"""
258259
260+ model_name : Optional [str ] = None
261+ """Name of the tokenizer to load. Mixcoder requires a tokenizer to get the modality indices."""
262+
259263
260264def train_sae (settings : TrainSAESettings ) -> None :
261265 """Train a SAE model.
@@ -276,9 +280,24 @@ def train_sae(settings: TrainSAESettings) -> None:
276280 activation_factory = ActivationFactory (settings .activation_factory )
277281 activations_stream = activation_factory .process ()
278282 initializer = Initializer (settings .initializer )
279- sae = initializer .initialize_sae_from_config (
280- settings .sae , activation_stream = activations_stream , device_mesh = device_mesh
281- )
283+
284+ if settings .sae .sae_type == "mixcoder" :
285+ assert settings .model_name is not None , "Model name is required for mixcoder SAE"
286+ tokenizer = AutoTokenizer .from_pretrained (settings .model_name , trust_remote_code = True )
287+ mixcoder_settings = {
288+ "model_name" : settings .model_name ,
289+ "tokenizer" : tokenizer ,
290+ }
291+ sae = initializer .initialize_sae_from_config (
292+ settings .sae ,
293+ activation_stream = activations_stream ,
294+ device_mesh = device_mesh ,
295+ mixcoder_settings = mixcoder_settings ,
296+ )
297+ else :
298+ sae = initializer .initialize_sae_from_config (
299+ settings .sae , activation_stream = activations_stream , device_mesh = device_mesh
300+ )
282301
283302 wandb_logger = (
284303 wandb .init (
@@ -289,7 +308,8 @@ def train_sae(settings: TrainSAESettings) -> None:
289308 settings = wandb .Settings (x_disable_stats = True ),
290309 mode = os .getenv ("WANDB_MODE" , "online" ),
291310 )
292- if settings .wandb is not None and (device_mesh is None or device_mesh .get_rank () == 0 ) else None
311+ if settings .wandb is not None and (device_mesh is None or device_mesh .get_rank () == 0 )
312+ else None
293313 )
294314 if wandb_logger is not None :
295315 wandb_logger .watch (sae , log = "all" )
0 commit comments