-from typing import Dict, Iterable, List
+import warnings
+from typing import Any, Dict, Iterable, List
 
 import torch
 from torch import Tensor
     parallelize_module,
 )
 
-from lm_saes.config import BaseSAEConfig, InitializerConfig, SAEConfig
+from lm_saes.config import BaseSAEConfig, InitializerConfig
+from lm_saes.mixcoder import MixCoder
 from lm_saes.sae import SparseAutoEncoder
-from lm_saes.utils.misc import calculate_activation_norm
+from lm_saes.utils.misc import calculate_activation_norm, get_modality_indices
 
 
 class Initializer:
     def __init__(self, cfg: InitializerConfig):
         self.cfg = cfg
 
     @torch.no_grad()
-    def initialize_parameters(self, sae: SparseAutoEncoder):
+    def initialize_parameters(self, sae: SparseAutoEncoder, mixcoder_settings: dict[str, Any] | None = None):
         """Initialize the parameters of the SAE.
         Only used when the state is "training" to initialize sae.
         """
-        torch.nn.init.kaiming_uniform_(sae.encoder.weight)
-        torch.nn.init.kaiming_uniform_(sae.decoder.weight)
-        torch.nn.init.zeros_(sae.encoder.bias)
-        if sae.cfg.use_decoder_bias:
-            torch.nn.init.zeros_(sae.decoder.bias)
-        if sae.cfg.use_glu_encoder:
-            torch.nn.init.kaiming_uniform_(sae.encoder_glu.weight)
-            torch.nn.init.zeros_(sae.encoder_glu.bias)
+
+        if sae.cfg.sae_type == "mixcoder":
+            assert mixcoder_settings is not None
+            assert "model_name" in mixcoder_settings and "tokenizer" in mixcoder_settings
+            modality_indices = get_modality_indices(mixcoder_settings["tokenizer"], mixcoder_settings["model_name"])
+            sae.init_parameters(modality_indices=modality_indices)
+
+        else:
+            sae.init_parameters()
 
         if self.cfg.init_decoder_norm:
             sae.set_decoder_to_fixed_norm(self.cfg.init_decoder_norm, force_exact=True)
@@ -48,14 +51,18 @@ def initialize_parameters(self, sae: SparseAutoEncoder):
     def initialize_tensor_parallel(self, sae: SparseAutoEncoder, device_mesh: DeviceMesh | None = None):
         if not device_mesh or device_mesh["model"].size(0) == 1:
             return sae
-        sae.device_mesh = device_mesh
-        plan = {
-            "encoder": ColwiseParallel(output_layouts=Replicate()),
-            "decoder": RowwiseParallel(input_layouts=Replicate()),
-        }
-        if sae.cfg.use_glu_encoder:
-            plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
-        sae = parallelize_module(sae, device_mesh=device_mesh["model"], parallelize_plan=plan)  # type: ignore
+        if sae.cfg.sae_type == "sae":
+            sae.device_mesh = device_mesh
+            plan = {
+                "encoder": ColwiseParallel(output_layouts=Replicate()),
+                "decoder": RowwiseParallel(input_layouts=Replicate()),
+            }
+            if sae.cfg.use_glu_encoder:
+                plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
+            sae = parallelize_module(sae, device_mesh=device_mesh["model"], parallelize_plan=plan)  # type: ignore
+
+        elif sae.cfg.sae_type == "mixcoder":
+            warnings.warn("MixCoder is not supported for tensor parallel initialization.")
         return sae
 
     @torch.no_grad()
@@ -67,6 +74,7 @@ def initialization_search(self, sae: SparseAutoEncoder, activation_batch: Dict[s
             activation_batch[sae.cfg.hook_point_in],
             activation_batch[sae.cfg.hook_point_out],
         )
+        tokens = activation_batch["tokens"]
         if self.cfg.init_decoder_norm is None:
             assert sae.cfg.sparsity_include_decoder_norm, "Decoder norm must be included in sparsity loss"
             if not self.cfg.init_encoder_with_decoder_transpose or sae.cfg.hook_point_in != sae.cfg.hook_point_out:
@@ -80,7 +88,7 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
                 sae.init_encoder_with_decoder_transpose()
                 if sae.cfg.sae_type == "crosscoder":
                     sae.initialize_with_same_weight_across_layers()
-                mse = sae.compute_loss(activation_batch)[1][0]["l_rec"].mean().item()
+                mse = sae.compute_loss(activation_batch, tokens=tokens)[1][0]["l_rec"].mean().item()
                 losses[norm] = mse
             best_norm = min(losses, key=losses.get)  # type: ignore
             return best_norm
@@ -97,7 +105,8 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
 
         sae.set_decoder_to_fixed_norm(best_norm_fine_grained, force_exact=True)
 
-        if self.cfg.bias_init_method == "geometric_median":
+        if self.cfg.bias_init_method == "geometric_median" and sae.cfg.sae_type != "mixcoder":
+            # TODO: add support for MixCoder
             sae.decoder.bias.data = (
                 sae.compute_norm_factor(activation_out, hook_point=sae.cfg.hook_point_out) * activation_out
             ).mean(0)
@@ -116,9 +125,15 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
 
     @torch.no_grad()
     def initialize_jump_relu_threshold(self, sae: SparseAutoEncoder, activation_batch: Dict[str, Tensor]):
+        # TODO: add support for MixCoder
+        if sae.cfg.sae_type == "mixcoder":
+            warnings.warn("MixCoder is not supported for jump_relu_threshold initialization.")
+            return sae
+
         activation_in = activation_batch[sae.cfg.hook_point_in]
+        tokens = activation_batch["tokens"]
         batch_size = activation_in.size(0)
-        _, hidden_pre = sae.encode(activation_in, return_hidden_pre=True)
+        _, hidden_pre = sae.encode(activation_in, return_hidden_pre=True, tokens=tokens)
         hidden_pre = torch.clamp(hidden_pre, min=0.0)
         hidden_pre = hidden_pre.flatten()
         threshold = hidden_pre.topk(k=batch_size * sae.cfg.top_k).values[-1]
@@ -131,6 +146,7 @@ def initialize_sae_from_config(
         activation_stream: Iterable[dict[str, Tensor]] | None = None,
         activation_norm: dict[str, float] | None = None,
         device_mesh: DeviceMesh | None = None,
+        mixcoder_settings: dict[str, Any] | None = None,
     ):
         """
         Initialize the SAE from the SAE config.
@@ -141,14 +157,16 @@ def initialize_sae_from_config(
             device_mesh (DeviceMesh | None): The device mesh.
         """
         sae = None  # type: ignore
-        if isinstance(cfg, SAEConfig):
+        if cfg.sae_type == "sae":
             sae = SparseAutoEncoder.from_config(cfg)
+        elif cfg.sae_type == "mixcoder":
+            sae = MixCoder.from_config(cfg)
         else:
             # TODO: add support for different SAE config types, e.g. MixCoderConfig, CrossCoderConfig, etc.
             pass
         if self.cfg.state == "training":
             if cfg.sae_pretrained_name_or_path is None:
-                sae: SparseAutoEncoder = self.initialize_parameters(sae)
+                sae: SparseAutoEncoder = self.initialize_parameters(sae, mixcoder_settings=mixcoder_settings)
             if sae.cfg.norm_activation == "dataset-wise":
                 if activation_norm is None:
                     assert (
@@ -179,7 +197,8 @@ def initialize_sae_from_config(
179197 ), "Activation iterator must be provided for jump_relu_threshold initialization"
180198 activation_batch = next (iter (activation_stream ))
181199 self .initialize_jump_relu_threshold (sae , activation_batch )
182- sae .cfg .act_fn = "jumprelu"
200+ if cfg .sae_type != "mixcoder" : # TODO: add support for MixCoder
201+ sae .cfg .act_fn = "jumprelu"
183202
184203 sae = self .initialize_tensor_parallel (sae , device_mesh )
185204 return sae
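
For context, a minimal usage sketch of the mixcoder_settings argument introduced above. Only the initialize_sae_from_config signature, the required "model_name"/"tokenizer" keys, and the "tokens" entry in activation batches are taken from this commit; the InitializerConfig(state="training") construction, the lm_saes.initializer module path, the AutoTokenizer source, and the placeholder sae_cfg/activation_stream objects are illustrative assumptions.

    # Sketch only: config construction, module path, and tokenizer source are assumptions.
    from transformers import AutoTokenizer

    from lm_saes.config import InitializerConfig
    from lm_saes.initializer import Initializer  # assumed module path for Initializer

    initializer = Initializer(cfg=InitializerConfig(state="training"))  # assumed minimal config

    # MixCoder initialization requires both keys below so that
    # get_modality_indices(tokenizer, model_name) can build the modality index
    # mapping passed to sae.init_parameters(modality_indices=...).
    mixcoder_settings = {
        "model_name": "example/multimodal-model",  # hypothetical model name
        "tokenizer": AutoTokenizer.from_pretrained("example/multimodal-model"),
    }

    sae = initializer.initialize_sae_from_config(
        cfg=sae_cfg,  # a BaseSAEConfig with sae_type="mixcoder"; construction omitted here
        activation_stream=activation_stream,  # batches must include a "tokens" tensor
        mixcoder_settings=mixcoder_settings,
    )

Per the warnings and TODOs in this commit, a MixCoder SAE currently skips tensor-parallel setup and jump-ReLU threshold initialization, so the returned module stays on a single device with its configured activation function unchanged.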