
Commit 1cbcd25

Frankstein73 authored and dest1n1s committed
feat(mixcoder): implemented mixcoder
1 parent 4faf9f6 commit 1cbcd25


12 files changed, +780 -75 lines changed


src/lm_saes/config.py

Lines changed: 3 additions & 5 deletions
@@ -104,13 +104,11 @@ class CrossCoderConfig(BaseSAEConfig):
 
 class MixCoderConfig(BaseSAEConfig):
     sae_type: Literal["sae", "crosscoder", "mixcoder"] = "mixcoder"
-    d_single_modal: int
-    d_shared: int
-    n_modalities: int = 2
+    modalities: dict[str, int]
 
     @property
     def d_sae(self) -> int:
-        return self.d_single_modal * self.n_modalities + self.d_shared
+        return sum(self.modalities.values())
 
 
 class InitializerConfig(BaseConfig):
@@ -131,7 +129,7 @@ class TrainerConfig(BaseConfig):
     k_warmup_steps: int | float = 0.1
     use_batch_norm_mse: bool = True
 
-    lr: float = 0.0004
+    lr: float | dict[str, float] = 0.0004
     betas: Tuple[float, float] = (0.9, 0.999)
     lr_scheduler_name: Literal[
         "constant",

src/lm_saes/crosscoder.py

Lines changed: 6 additions & 0 deletions
@@ -40,6 +40,7 @@ def encode(
             Float[torch.Tensor, "batch seq_len d_model"],
         ],
         return_hidden_pre: Literal[False] = False,
+        **kwargs,
     ) -> Union[Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"]]: ...
 
     @overload
@@ -50,6 +51,7 @@ def encode(
             Float[torch.Tensor, "batch seq_len d_model"],
         ],
         return_hidden_pre: Literal[True],
+        **kwargs,
     ) -> tuple[
         Union[
             Float[torch.Tensor, "batch d_sae"],
@@ -68,6 +70,7 @@ def encode(
             Float[torch.Tensor, "batch seq_len d_model"],
         ],
         return_hidden_pre: bool = False,
+        **kwargs,
     ) -> Union[
         Float[torch.Tensor, "batch d_sae"],
         Float[torch.Tensor, "batch seq_len d_sae"],
@@ -130,6 +133,7 @@ def compute_loss(
         use_batch_norm_mse: bool = False,
         lp: int = 1,
         return_aux_data: Literal[True] = True,
+        **kwargs,
     ) -> tuple[
         Float[torch.Tensor, " batch"],
         tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
@@ -143,6 +147,7 @@ def compute_loss(
         use_batch_norm_mse: bool = False,
         lp: int = 1,
         return_aux_data: Literal[False],
+        **kwargs,
     ) -> Float[torch.Tensor, " batch"]: ...
 
     def compute_loss(
@@ -159,6 +164,7 @@ def compute_loss(
         use_batch_norm_mse: bool = False,
         lp: int = 1,
         return_aux_data: bool = True,
+        **kwargs,
     ) -> Union[
         Float[torch.Tensor, " batch"],
         tuple[
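The **kwargs added to every encode/compute_loss overload lets CrossCoder silently accept MixCoder-specific keyword arguments such as tokens=..., so call sites can pass them unconditionally. A minimal sketch of such a call site, assuming sae is any SparseAutoEncoder variant and batch is an activation dict carrying a "tokens" entry (as in the evaluator and initializer diffs below):

# Sketch: one code path for SAE, CrossCoder and MixCoder alike; the tokens
# kwarg is used by MixCoder and absorbed via **kwargs by the other variants.
tokens = batch["tokens"]
feature_acts = sae.encode(batch[sae.cfg.hook_point_in], tokens=tokens)
loss, (loss_dict, aux_data) = sae.compute_loss(batch, tokens=tokens)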

src/lm_saes/evaluator.py

Lines changed: 5 additions & 13 deletions
@@ -6,8 +6,7 @@
 from transformer_lens import HookedTransformer
 from wandb.sdk.wandb_run import Run
 
-from lm_saes.config import EvalConfig, MixCoderConfig, SAEConfig
-from lm_saes.mixcoder import MixCoder
+from lm_saes.config import EvalConfig
 from lm_saes.sae import SparseAutoEncoder
 
 
@@ -55,11 +54,12 @@ def log_metric(metric: str, value: float) -> None:
         # 2. Get activations and compute reconstructions
         activation_in = activation_dict[sae.cfg.hook_point_in][useful_token_mask]
         activation_out = activation_dict[sae.cfg.hook_point_out][useful_token_mask]
-        feature_acts = sae.encode(activation_in)
+        tokens = activation_dict["tokens"][useful_token_mask]
+        feature_acts = sae.encode(activation_in, tokens=tokens)
         reconstructed = (
             log_info.pop("reconstructed")[useful_token_mask]
             if "reconstructed" in log_info
-            else sae.forward(activation_in)
+            else sae.forward(activation_in, tokens=tokens)
         )
 
         # 3. Compute sparsity metrics
@@ -156,15 +156,7 @@ def _evaluate_tokens(
             names_filter=[sae.cfg.hook_point_in, sae.cfg.hook_point_out],
             return_cache_object=False,
         )
-        reconstructed_activations: Tensor | None = None
-        if isinstance(sae.cfg, SAEConfig):
-            reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in])
-
-        elif isinstance(sae.cfg, MixCoderConfig):
-            assert isinstance(sae, MixCoder)
-            reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in], tokens=input_ids)
-
-        assert reconstructed_activations is not None
+        reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in], tokens=input_ids)
 
         def replace_hook(activations: Tensor, hook_point: str) -> Tensor:
             return torch.where(useful_token_mask, reconstructed_activations, activations)
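With the type-specific branch gone, _evaluate_tokens threads tokens=input_ids into forward unconditionally and splices the reconstruction back into the model through the replace_hook shown above. A sketch of how that hook would plug into the HookedTransformer run; the actual run_with_hooks call sits outside this hunk, so its arguments here are assumed.

# Assumed wiring: substitute reconstructions at the SAE's output hook point
# wherever useful_token_mask is set, then rerun the model with the hook attached.
logits = model.run_with_hooks(
    input_ids,
    fwd_hooks=[(sae.cfg.hook_point_out, replace_hook)],
)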

src/lm_saes/initializer.py

Lines changed: 45 additions & 26 deletions
@@ -1,4 +1,5 @@
-from typing import Dict, Iterable, List
+import warnings
+from typing import Any, Dict, Iterable, List
 
 import torch
 from torch import Tensor
@@ -10,28 +11,30 @@
     parallelize_module,
 )
 
-from lm_saes.config import BaseSAEConfig, InitializerConfig, SAEConfig
+from lm_saes.config import BaseSAEConfig, InitializerConfig
+from lm_saes.mixcoder import MixCoder
 from lm_saes.sae import SparseAutoEncoder
-from lm_saes.utils.misc import calculate_activation_norm
+from lm_saes.utils.misc import calculate_activation_norm, get_modality_indices
 
 
 class Initializer:
     def __init__(self, cfg: InitializerConfig):
         self.cfg = cfg
 
     @torch.no_grad()
-    def initialize_parameters(self, sae: SparseAutoEncoder):
+    def initialize_parameters(self, sae: SparseAutoEncoder, mixcoder_settings: dict[str, Any] | None = None):
         """Initialize the parameters of the SAE.
         Only used when the state is "training" to initialize sae.
         """
-        torch.nn.init.kaiming_uniform_(sae.encoder.weight)
-        torch.nn.init.kaiming_uniform_(sae.decoder.weight)
-        torch.nn.init.zeros_(sae.encoder.bias)
-        if sae.cfg.use_decoder_bias:
-            torch.nn.init.zeros_(sae.decoder.bias)
-        if sae.cfg.use_glu_encoder:
-            torch.nn.init.kaiming_uniform_(sae.encoder_glu.weight)
-            torch.nn.init.zeros_(sae.encoder_glu.bias)
+
+        if sae.cfg.sae_type == "mixcoder":
+            assert mixcoder_settings is not None
+            assert "model_name" in mixcoder_settings and "tokenizer" in mixcoder_settings
+            modality_indices = get_modality_indices(mixcoder_settings["tokenizer"], mixcoder_settings["model_name"])
+            sae.init_parameters(modality_indices=modality_indices)
+
+        else:
+            sae.init_parameters()
 
         if self.cfg.init_decoder_norm:
             sae.set_decoder_to_fixed_norm(self.cfg.init_decoder_norm, force_exact=True)
@@ -48,14 +51,18 @@ def initialize_parameters(self, sae: SparseAutoEncoder):
     def initialize_tensor_parallel(self, sae: SparseAutoEncoder, device_mesh: DeviceMesh | None = None):
         if not device_mesh or device_mesh["model"].size(0) == 1:
             return sae
-        sae.device_mesh = device_mesh
-        plan = {
-            "encoder": ColwiseParallel(output_layouts=Replicate()),
-            "decoder": RowwiseParallel(input_layouts=Replicate()),
-        }
-        if sae.cfg.use_glu_encoder:
-            plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
-        sae = parallelize_module(sae, device_mesh=device_mesh["model"], parallelize_plan=plan)  # type: ignore
+        if sae.cfg.sae_type == "sae":
+            sae.device_mesh = device_mesh
+            plan = {
+                "encoder": ColwiseParallel(output_layouts=Replicate()),
+                "decoder": RowwiseParallel(input_layouts=Replicate()),
+            }
+            if sae.cfg.use_glu_encoder:
+                plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
+            sae = parallelize_module(sae, device_mesh=device_mesh["model"], parallelize_plan=plan)  # type: ignore
+
+        elif sae.cfg.sae_type == "mixcoder":
+            warnings.warn("MixCoder is not supported for tensor parallel initialization.")
         return sae
 
     @torch.no_grad()
@@ -67,6 +74,7 @@ def initialization_search(self, sae: SparseAutoEncoder, activation_batch: Dict[s
             activation_batch[sae.cfg.hook_point_in],
             activation_batch[sae.cfg.hook_point_out],
         )
+        tokens = activation_batch["tokens"]
         if self.cfg.init_decoder_norm is None:
            assert sae.cfg.sparsity_include_decoder_norm, "Decoder norm must be included in sparsity loss"
            if not self.cfg.init_encoder_with_decoder_transpose or sae.cfg.hook_point_in != sae.cfg.hook_point_out:
@@ -80,7 +88,7 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
                 sae.init_encoder_with_decoder_transpose()
                 if sae.cfg.sae_type == "crosscoder":
                     sae.initialize_with_same_weight_across_layers()
-                mse = sae.compute_loss(activation_batch)[1][0]["l_rec"].mean().item()
+                mse = sae.compute_loss(activation_batch, tokens=tokens)[1][0]["l_rec"].mean().item()
                 losses[norm] = mse
             best_norm = min(losses, key=losses.get)  # type: ignore
             return best_norm
@@ -97,7 +105,8 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
 
            sae.set_decoder_to_fixed_norm(best_norm_fine_grained, force_exact=True)
 
-        if self.cfg.bias_init_method == "geometric_median":
+        if self.cfg.bias_init_method == "geometric_median" and sae.cfg.sae_type != "mixcoder":
+            # TODO: add support for MixCoder
             sae.decoder.bias.data = (
                 sae.compute_norm_factor(activation_out, hook_point=sae.cfg.hook_point_out) * activation_out
             ).mean(0)
@@ -116,9 +125,15 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
 
     @torch.no_grad()
     def initialize_jump_relu_threshold(self, sae: SparseAutoEncoder, activation_batch: Dict[str, Tensor]):
+        # TODO: add support for MixCoder
+        if sae.cfg.sae_type == "mixcoder":
+            warnings.warn("MixCoder is not supported for jump_relu_threshold initialization.")
+            return sae
+
         activation_in = activation_batch[sae.cfg.hook_point_in]
+        tokens = activation_batch["tokens"]
         batch_size = activation_in.size(0)
-        _, hidden_pre = sae.encode(activation_in, return_hidden_pre=True)
+        _, hidden_pre = sae.encode(activation_in, return_hidden_pre=True, tokens=tokens)
         hidden_pre = torch.clamp(hidden_pre, min=0.0)
         hidden_pre = hidden_pre.flatten()
         threshold = hidden_pre.topk(k=batch_size * sae.cfg.top_k).values[-1]
@@ -131,6 +146,7 @@ def initialize_sae_from_config(
         activation_stream: Iterable[dict[str, Tensor]] | None = None,
         activation_norm: dict[str, float] | None = None,
         device_mesh: DeviceMesh | None = None,
+        mixcoder_settings: dict[str, Any] | None = None,
     ):
         """
         Initialize the SAE from the SAE config.
@@ -141,14 +157,16 @@
             device_mesh (DeviceMesh | None): The device mesh.
         """
         sae = None  # type: ignore
-        if isinstance(cfg, SAEConfig):
+        if cfg.sae_type == "sae":
             sae = SparseAutoEncoder.from_config(cfg)
+        elif cfg.sae_type == "mixcoder":
+            sae = MixCoder.from_config(cfg)
         else:
             # TODO: add support for different SAE config types, e.g. MixCoderConfig, CrossCoderConfig, etc.
             pass
         if self.cfg.state == "training":
             if cfg.sae_pretrained_name_or_path is None:
-                sae: SparseAutoEncoder = self.initialize_parameters(sae)
+                sae: SparseAutoEncoder = self.initialize_parameters(sae, mixcoder_settings=mixcoder_settings)
             if sae.cfg.norm_activation == "dataset-wise":
                 if activation_norm is None:
                     assert (
@@ -179,7 +197,8 @@
                 ), "Activation iterator must be provided for jump_relu_threshold initialization"
                 activation_batch = next(iter(activation_stream))
                 self.initialize_jump_relu_threshold(sae, activation_batch)
-                sae.cfg.act_fn = "jumprelu"
+                if cfg.sae_type != "mixcoder":  # TODO: add support for MixCoder
+                    sae.cfg.act_fn = "jumprelu"
 
         sae = self.initialize_tensor_parallel(sae, device_mesh)
         return sae
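Taken together, the initializer now expects MixCoder callers to supply a mixcoder_settings dict carrying the tokenizer and model name that get_modality_indices needs. A minimal usage sketch; the model/tokenizer name, init_cfg, mixcoder_cfg, and activation_stream are placeholders, and only the "model_name"/"tokenizer" keys are grounded in the asserts above.

# Sketch: placeholder config objects and model name; only the dict keys
# ("model_name", "tokenizer") are taken from the diff above.
from transformers import AutoTokenizer

initializer = Initializer(init_cfg)
sae = initializer.initialize_sae_from_config(
    mixcoder_cfg,
    activation_stream=activation_stream,
    mixcoder_settings={
        "model_name": "some-multimodal-lm",  # hypothetical
        "tokenizer": AutoTokenizer.from_pretrained("some-multimodal-lm"),
    },
)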
