8 changes: 3 additions & 5 deletions src/lm_saes/config.py

@@ -104,13 +104,11 @@ class CrossCoderConfig(BaseSAEConfig):
 
 class MixCoderConfig(BaseSAEConfig):
     sae_type: Literal["sae", "crosscoder", "mixcoder"] = "mixcoder"
-    d_single_modal: int
-    d_shared: int
-    n_modalities: int = 2
+    modalities: dict[str, int]
 
     @property
     def d_sae(self) -> int:
-        return self.d_single_modal * self.n_modalities + self.d_shared
+        return sum(self.modalities.values())
 
 
 class InitializerConfig(BaseConfig):
@@ -131,7 +129,7 @@ class TrainerConfig(BaseConfig):
     k_warmup_steps: int | float = 0.1
     use_batch_norm_mse: bool = True
 
-    lr: float = 0.0004
+    lr: float | dict[str, float] = 0.0004
     betas: Tuple[float, float] = (0.9, 0.999)
     lr_scheduler_name: Literal[
         "constant",
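
For orientation, a minimal sketch of what the reworked MixCoderConfig and the widened lr type imply. The modality names and learning-rate values below are illustrative assumptions, not taken from this PR; only the "d_sae is the sum of the modality sizes" rule and the float-or-dict lr annotation come from the diff above.

# Hypothetical modality layout; keys and sizes are made up for illustration.
modalities = {"text": 8192, "image": 8192, "shared": 16384}

old_d_sae = 8192 * 2 + 16384          # old scheme: d_single_modal * n_modalities + d_shared
new_d_sae = sum(modalities.values())  # new scheme: sum over the modalities dict
assert old_d_sae == new_d_sae == 32768

# TrainerConfig.lr now accepts either a single float or, presumably, a per-group dict:
lr: float | dict[str, float] = {"text": 4e-4, "image": 4e-4, "shared": 8e-4}
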
6 changes: 6 additions & 0 deletions src/lm_saes/crosscoder.py

@@ -40,6 +40,7 @@ def encode(
             Float[torch.Tensor, "batch seq_len d_model"],
         ],
         return_hidden_pre: Literal[False] = False,
+        **kwargs,
     ) -> Union[Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"]]: ...
 
     @overload
@@ -50,6 +51,7 @@ def encode(
             Float[torch.Tensor, "batch seq_len d_model"],
         ],
         return_hidden_pre: Literal[True],
+        **kwargs,
     ) -> tuple[
         Union[
             Float[torch.Tensor, "batch d_sae"],
@@ -68,6 +70,7 @@ def encode(
             Float[torch.Tensor, "batch seq_len d_model"],
         ],
         return_hidden_pre: bool = False,
+        **kwargs,
     ) -> Union[
         Float[torch.Tensor, "batch d_sae"],
         Float[torch.Tensor, "batch seq_len d_sae"],
@@ -130,6 +133,7 @@ def compute_loss(
         use_batch_norm_mse: bool = False,
         lp: int = 1,
         return_aux_data: Literal[True] = True,
+        **kwargs,
     ) -> tuple[
         Float[torch.Tensor, " batch"],
         tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
@@ -143,6 +147,7 @@ def compute_loss(
         use_batch_norm_mse: bool = False,
         lp: int = 1,
         return_aux_data: Literal[False],
+        **kwargs,
     ) -> Float[torch.Tensor, " batch"]: ...
 
     def compute_loss(
@@ -159,6 +164,7 @@ def compute_loss(
         use_batch_norm_mse: bool = False,
         lp: int = 1,
         return_aux_data: bool = True,
+        **kwargs,
     ) -> Union[
         Float[torch.Tensor, " batch"],
         tuple[
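
The added **kwargs make the CrossCoder's encode/compute_loss overloads tolerant of extra per-batch keyword arguments (such as the tokens the MixCoder needs), so callers can invoke any coder uniformly instead of special-casing each type. A toy sketch of the pattern, using a stand-in class rather than the real CrossCoder:

import torch

class ToyCoder:
    # Illustration only: accepting and ignoring extra keyword arguments lets code
    # written for a tokens-aware coder call this one without special-casing.
    def encode(self, x: torch.Tensor, return_hidden_pre: bool = False, **kwargs):
        hidden_pre = x @ torch.ones(x.shape[-1], 4)  # stand-in for the real encoder
        feature_acts = torch.relu(hidden_pre)
        return (feature_acts, hidden_pre) if return_hidden_pre else feature_acts

coder = ToyCoder()
acts = coder.encode(torch.randn(2, 8), tokens=torch.arange(2))  # `tokens` is simply swallowed here
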
18 changes: 5 additions & 13 deletions src/lm_saes/evaluator.py

@@ -6,8 +6,7 @@
 from transformer_lens import HookedTransformer
 from wandb.sdk.wandb_run import Run
 
-from lm_saes.config import EvalConfig, MixCoderConfig, SAEConfig
-from lm_saes.mixcoder import MixCoder
+from lm_saes.config import EvalConfig
 from lm_saes.sae import SparseAutoEncoder
 
 
@@ -55,11 +54,12 @@ def log_metric(metric: str, value: float) -> None:
     # 2. Get activations and compute reconstructions
     activation_in = activation_dict[sae.cfg.hook_point_in][useful_token_mask]
     activation_out = activation_dict[sae.cfg.hook_point_out][useful_token_mask]
-    feature_acts = sae.encode(activation_in)
+    tokens = activation_dict["tokens"][useful_token_mask]
+    feature_acts = sae.encode(activation_in, tokens=tokens)
     reconstructed = (
         log_info.pop("reconstructed")[useful_token_mask]
         if "reconstructed" in log_info
-        else sae.forward(activation_in)
+        else sae.forward(activation_in, tokens=tokens)
     )
 
     # 3. Compute sparsity metrics
@@ -156,15 +156,7 @@ def _evaluate_tokens(
         names_filter=[sae.cfg.hook_point_in, sae.cfg.hook_point_out],
         return_cache_object=False,
     )
-    reconstructed_activations: Tensor | None = None
-    if isinstance(sae.cfg, SAEConfig):
-        reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in])
-
-    elif isinstance(sae.cfg, MixCoderConfig):
-        assert isinstance(sae, MixCoder)
-        reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in], tokens=input_ids)
-
-    assert reconstructed_activations is not None
+    reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in], tokens=input_ids)
 
     def replace_hook(activations: Tensor, hook_point: str) -> Tensor:
         return torch.where(useful_token_mask, reconstructed_activations, activations)
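
replace_hook above splices the reconstruction back in only at useful tokens. A small self-contained sketch of that torch.where pattern; the shapes and the trailing mask dimension here are illustrative assumptions:

import torch

batch, seq_len, d_model = 2, 5, 3
activations = torch.randn(batch, seq_len, d_model)
reconstructed = torch.zeros(batch, seq_len, d_model)  # stand-in for sae.forward(...)
useful_token_mask = torch.tensor(
    [[1, 1, 0, 1, 0],
     [1, 0, 1, 1, 1]], dtype=torch.bool
).unsqueeze(-1)  # add a trailing dim so the mask broadcasts over d_model

# Reconstructions are used where the mask is True; original activations are kept elsewhere.
patched = torch.where(useful_token_mask, reconstructed, activations)
assert patched.shape == activations.shape
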
71 changes: 45 additions & 26 deletions src/lm_saes/initializer.py

@@ -1,4 +1,5 @@
-from typing import Dict, Iterable, List
+import warnings
+from typing import Any, Dict, Iterable, List
 
 import torch
 from torch import Tensor
@@ -10,28 +11,30 @@
     parallelize_module,
 )
 
-from lm_saes.config import BaseSAEConfig, InitializerConfig, SAEConfig
+from lm_saes.config import BaseSAEConfig, InitializerConfig
+from lm_saes.mixcoder import MixCoder
 from lm_saes.sae import SparseAutoEncoder
-from lm_saes.utils.misc import calculate_activation_norm
+from lm_saes.utils.misc import calculate_activation_norm, get_modality_indices
 
 
 class Initializer:
     def __init__(self, cfg: InitializerConfig):
         self.cfg = cfg
 
     @torch.no_grad()
-    def initialize_parameters(self, sae: SparseAutoEncoder):
+    def initialize_parameters(self, sae: SparseAutoEncoder, mixcoder_settings: dict[str, Any] | None = None):
         """Initialize the parameters of the SAE.
         Only used when the state is "training" to initialize sae.
         """
-        torch.nn.init.kaiming_uniform_(sae.encoder.weight)
-        torch.nn.init.kaiming_uniform_(sae.decoder.weight)
-        torch.nn.init.zeros_(sae.encoder.bias)
-        if sae.cfg.use_decoder_bias:
-            torch.nn.init.zeros_(sae.decoder.bias)
-        if sae.cfg.use_glu_encoder:
-            torch.nn.init.kaiming_uniform_(sae.encoder_glu.weight)
-            torch.nn.init.zeros_(sae.encoder_glu.bias)
+        if sae.cfg.sae_type == "mixcoder":
+            assert mixcoder_settings is not None
+            assert "model_name" in mixcoder_settings and "tokenizer" in mixcoder_settings
+            modality_indices = get_modality_indices(mixcoder_settings["tokenizer"], mixcoder_settings["model_name"])
+            sae.init_parameters(modality_indices=modality_indices)
+        else:
+            sae.init_parameters()
 
         if self.cfg.init_decoder_norm:
             sae.set_decoder_to_fixed_norm(self.cfg.init_decoder_norm, force_exact=True)
@@ -48,14 +51,18 @@ def initialize_parameters(self, sae: SparseAutoEncoder):
     def initialize_tensor_parallel(self, sae: SparseAutoEncoder, device_mesh: DeviceMesh | None = None):
         if not device_mesh or device_mesh["model"].size(0) == 1:
             return sae
-        sae.device_mesh = device_mesh
-        plan = {
-            "encoder": ColwiseParallel(output_layouts=Replicate()),
-            "decoder": RowwiseParallel(input_layouts=Replicate()),
-        }
-        if sae.cfg.use_glu_encoder:
-            plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
-        sae = parallelize_module(sae, device_mesh=device_mesh["model"], parallelize_plan=plan)  # type: ignore
+        if sae.cfg.sae_type == "sae":
+            sae.device_mesh = device_mesh
+            plan = {
+                "encoder": ColwiseParallel(output_layouts=Replicate()),
+                "decoder": RowwiseParallel(input_layouts=Replicate()),
+            }
+            if sae.cfg.use_glu_encoder:
+                plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate())
+            sae = parallelize_module(sae, device_mesh=device_mesh["model"], parallelize_plan=plan)  # type: ignore
+        elif sae.cfg.sae_type == "mixcoder":
+            warnings.warn("MixCoder is not supported for tensor parallel initialization.")
         return sae
 
     @torch.no_grad()
@@ -67,6 +74,7 @@ def initialization_search(self, sae: SparseAutoEncoder, activation_batch: Dict[s
             activation_batch[sae.cfg.hook_point_in],
             activation_batch[sae.cfg.hook_point_out],
         )
+        tokens = activation_batch["tokens"]
        if self.cfg.init_decoder_norm is None:
            assert sae.cfg.sparsity_include_decoder_norm, "Decoder norm must be included in sparsity loss"
            if not self.cfg.init_encoder_with_decoder_transpose or sae.cfg.hook_point_in != sae.cfg.hook_point_out:
@@ -80,7 +88,7 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
                     sae.init_encoder_with_decoder_transpose()
                 if sae.cfg.sae_type == "crosscoder":
                     sae.initialize_with_same_weight_across_layers()
-                mse = sae.compute_loss(activation_batch)[1][0]["l_rec"].mean().item()
+                mse = sae.compute_loss(activation_batch, tokens=tokens)[1][0]["l_rec"].mean().item()
                 losses[norm] = mse
             best_norm = min(losses, key=losses.get)  # type: ignore
             return best_norm
@@ -97,7 +105,8 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
 
         sae.set_decoder_to_fixed_norm(best_norm_fine_grained, force_exact=True)
 
-        if self.cfg.bias_init_method == "geometric_median":
+        if self.cfg.bias_init_method == "geometric_median" and sae.cfg.sae_type != "mixcoder":
+            # TODO: add support for MixCoder
             sae.decoder.bias.data = (
                 sae.compute_norm_factor(activation_out, hook_point=sae.cfg.hook_point_out) * activation_out
             ).mean(0)
@@ -116,9 +125,15 @@ def grid_search_best_init_norm(search_range: List[float]) -> float:
 
     @torch.no_grad()
     def initialize_jump_relu_threshold(self, sae: SparseAutoEncoder, activation_batch: Dict[str, Tensor]):
+        # TODO: add support for MixCoder
+        if sae.cfg.sae_type == "mixcoder":
+            warnings.warn("MixCoder is not supported for jump_relu_threshold initialization.")
+            return sae
+
         activation_in = activation_batch[sae.cfg.hook_point_in]
+        tokens = activation_batch["tokens"]
         batch_size = activation_in.size(0)
-        _, hidden_pre = sae.encode(activation_in, return_hidden_pre=True)
+        _, hidden_pre = sae.encode(activation_in, return_hidden_pre=True, tokens=tokens)
         hidden_pre = torch.clamp(hidden_pre, min=0.0)
         hidden_pre = hidden_pre.flatten()
         threshold = hidden_pre.topk(k=batch_size * sae.cfg.top_k).values[-1]
@@ -131,6 +146,7 @@ def initialize_sae_from_config(
         activation_stream: Iterable[dict[str, Tensor]] | None = None,
         activation_norm: dict[str, float] | None = None,
         device_mesh: DeviceMesh | None = None,
+        mixcoder_settings: dict[str, Any] | None = None,
     ):
         """
         Initialize the SAE from the SAE config.
@@ -141,14 +157,16 @@ def initialize_sae_from_config(
             device_mesh (DeviceMesh | None): The device mesh.
         """
         sae = None  # type: ignore
-        if isinstance(cfg, SAEConfig):
+        if cfg.sae_type == "sae":
             sae = SparseAutoEncoder.from_config(cfg)
+        elif cfg.sae_type == "mixcoder":
+            sae = MixCoder.from_config(cfg)
         else:
             # TODO: add support for different SAE config types, e.g. MixCoderConfig, CrossCoderConfig, etc.
             pass
         if self.cfg.state == "training":
             if cfg.sae_pretrained_name_or_path is None:
-                sae: SparseAutoEncoder = self.initialize_parameters(sae)
+                sae: SparseAutoEncoder = self.initialize_parameters(sae, mixcoder_settings=mixcoder_settings)
             if sae.cfg.norm_activation == "dataset-wise":
                 if activation_norm is None:
                     assert (
@@ -179,7 +197,8 @@ def initialize_sae_from_config(
                 ), "Activation iterator must be provided for jump_relu_threshold initialization"
                 activation_batch = next(iter(activation_stream))
                 self.initialize_jump_relu_threshold(sae, activation_batch)
-                sae.cfg.act_fn = "jumprelu"
+                if cfg.sae_type != "mixcoder":  # TODO: add support for MixCoder
+                    sae.cfg.act_fn = "jumprelu"
 
         sae = self.initialize_tensor_parallel(sae, device_mesh)
         return sae
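
A sketch of how the new mixcoder_settings argument might be threaded through at a call site. The tokenizer, model name, and the surrounding config objects are placeholders assumed to be defined elsewhere; only the Initializer API and the required "model_name"/"tokenizer" keys come from the diff above.

from transformers import AutoTokenizer

# Hypothetical call site; the model name is an arbitrary example.
model_name = "llava-hf/llava-1.5-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
mixcoder_settings = {
    "model_name": model_name,   # used by get_modality_indices(...)
    "tokenizer": tokenizer,
}

initializer = Initializer(initializer_cfg)            # initializer_cfg assumed defined elsewhere
sae = initializer.initialize_sae_from_config(
    sae_cfg,                                           # assumed MixCoderConfig with sae_type="mixcoder"
    activation_stream=activation_stream,               # assumed iterable of activation dicts
    mixcoder_settings=mixcoder_settings,               # required when sae_type == "mixcoder"
)
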