diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py index 9b3187fb..d409a3ec 100644 --- a/src/lm_saes/config.py +++ b/src/lm_saes/config.py @@ -104,13 +104,11 @@ class CrossCoderConfig(BaseSAEConfig): class MixCoderConfig(BaseSAEConfig): sae_type: Literal["sae", "crosscoder", "mixcoder"] = "mixcoder" - d_single_modal: int - d_shared: int - n_modalities: int = 2 + modalities: dict[str, int] @property def d_sae(self) -> int: - return self.d_single_modal * self.n_modalities + self.d_shared + return sum(self.modalities.values()) class InitializerConfig(BaseConfig): @@ -131,7 +129,7 @@ class TrainerConfig(BaseConfig): k_warmup_steps: int | float = 0.1 use_batch_norm_mse: bool = True - lr: float = 0.0004 + lr: float | dict[str, float] = 0.0004 betas: Tuple[float, float] = (0.9, 0.999) lr_scheduler_name: Literal[ "constant", diff --git a/src/lm_saes/crosscoder.py b/src/lm_saes/crosscoder.py index 83286e5e..75a53c5d 100644 --- a/src/lm_saes/crosscoder.py +++ b/src/lm_saes/crosscoder.py @@ -40,6 +40,7 @@ def encode( Float[torch.Tensor, "batch seq_len d_model"], ], return_hidden_pre: Literal[False] = False, + **kwargs, ) -> Union[Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"]]: ... @overload @@ -50,6 +51,7 @@ def encode( Float[torch.Tensor, "batch seq_len d_model"], ], return_hidden_pre: Literal[True], + **kwargs, ) -> tuple[ Union[ Float[torch.Tensor, "batch d_sae"], @@ -68,6 +70,7 @@ def encode( Float[torch.Tensor, "batch seq_len d_model"], ], return_hidden_pre: bool = False, + **kwargs, ) -> Union[ Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"], @@ -130,6 +133,7 @@ def compute_loss( use_batch_norm_mse: bool = False, lp: int = 1, return_aux_data: Literal[True] = True, + **kwargs, ) -> tuple[ Float[torch.Tensor, " batch"], tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], @@ -143,6 +147,7 @@ def compute_loss( use_batch_norm_mse: bool = False, lp: int = 1, return_aux_data: Literal[False], + **kwargs, ) -> Float[torch.Tensor, " batch"]: ... def compute_loss( @@ -159,6 +164,7 @@ def compute_loss( use_batch_norm_mse: bool = False, lp: int = 1, return_aux_data: bool = True, + **kwargs, ) -> Union[ Float[torch.Tensor, " batch"], tuple[ diff --git a/src/lm_saes/evaluator.py b/src/lm_saes/evaluator.py index c58ce48d..58393387 100644 --- a/src/lm_saes/evaluator.py +++ b/src/lm_saes/evaluator.py @@ -6,8 +6,7 @@ from transformer_lens import HookedTransformer from wandb.sdk.wandb_run import Run -from lm_saes.config import EvalConfig, MixCoderConfig, SAEConfig -from lm_saes.mixcoder import MixCoder +from lm_saes.config import EvalConfig from lm_saes.sae import SparseAutoEncoder @@ -55,11 +54,12 @@ def log_metric(metric: str, value: float) -> None: # 2. Get activations and compute reconstructions activation_in = activation_dict[sae.cfg.hook_point_in][useful_token_mask] activation_out = activation_dict[sae.cfg.hook_point_out][useful_token_mask] - feature_acts = sae.encode(activation_in) + tokens = activation_dict["tokens"][useful_token_mask] + feature_acts = sae.encode(activation_in, tokens=tokens) reconstructed = ( log_info.pop("reconstructed")[useful_token_mask] if "reconstructed" in log_info - else sae.forward(activation_in) + else sae.forward(activation_in, tokens=tokens) ) # 3. 
Compute sparsity metrics @@ -156,15 +156,7 @@ def _evaluate_tokens( names_filter=[sae.cfg.hook_point_in, sae.cfg.hook_point_out], return_cache_object=False, ) - reconstructed_activations: Tensor | None = None - if isinstance(sae.cfg, SAEConfig): - reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in]) - - elif isinstance(sae.cfg, MixCoderConfig): - assert isinstance(sae, MixCoder) - reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in], tokens=input_ids) - - assert reconstructed_activations is not None + reconstructed_activations = sae.forward(cache[sae.cfg.hook_point_in], tokens=input_ids) def replace_hook(activations: Tensor, hook_point: str) -> Tensor: return torch.where(useful_token_mask, reconstructed_activations, activations) diff --git a/src/lm_saes/initializer.py b/src/lm_saes/initializer.py index 1dffa291..fcae98b6 100644 --- a/src/lm_saes/initializer.py +++ b/src/lm_saes/initializer.py @@ -1,4 +1,5 @@ -from typing import Dict, Iterable, List +import warnings +from typing import Any, Dict, Iterable, List import torch from torch import Tensor @@ -10,9 +11,10 @@ parallelize_module, ) -from lm_saes.config import BaseSAEConfig, InitializerConfig, SAEConfig +from lm_saes.config import BaseSAEConfig, InitializerConfig +from lm_saes.mixcoder import MixCoder from lm_saes.sae import SparseAutoEncoder -from lm_saes.utils.misc import calculate_activation_norm +from lm_saes.utils.misc import calculate_activation_norm, get_modality_indices class Initializer: @@ -20,18 +22,19 @@ def __init__(self, cfg: InitializerConfig): self.cfg = cfg @torch.no_grad() - def initialize_parameters(self, sae: SparseAutoEncoder): + def initialize_parameters(self, sae: SparseAutoEncoder, mixcoder_settings: dict[str, Any] | None = None): """Initialize the parameters of the SAE. Only used when the state is "training" to initialize sae. 
""" - torch.nn.init.kaiming_uniform_(sae.encoder.weight) - torch.nn.init.kaiming_uniform_(sae.decoder.weight) - torch.nn.init.zeros_(sae.encoder.bias) - if sae.cfg.use_decoder_bias: - torch.nn.init.zeros_(sae.decoder.bias) - if sae.cfg.use_glu_encoder: - torch.nn.init.kaiming_uniform_(sae.encoder_glu.weight) - torch.nn.init.zeros_(sae.encoder_glu.bias) + + if sae.cfg.sae_type == "mixcoder": + assert mixcoder_settings is not None + assert "model_name" in mixcoder_settings and "tokenizer" in mixcoder_settings + modality_indices = get_modality_indices(mixcoder_settings["tokenizer"], mixcoder_settings["model_name"]) + sae.init_parameters(modality_indices=modality_indices) + + else: + sae.init_parameters() if self.cfg.init_decoder_norm: sae.set_decoder_to_fixed_norm(self.cfg.init_decoder_norm, force_exact=True) @@ -48,14 +51,18 @@ def initialize_parameters(self, sae: SparseAutoEncoder): def initialize_tensor_parallel(self, sae: SparseAutoEncoder, device_mesh: DeviceMesh | None = None): if not device_mesh or device_mesh["model"].size(0) == 1: return sae - sae.device_mesh = device_mesh - plan = { - "encoder": ColwiseParallel(output_layouts=Replicate()), - "decoder": RowwiseParallel(input_layouts=Replicate()), - } - if sae.cfg.use_glu_encoder: - plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate()) - sae = parallelize_module(sae, device_mesh=device_mesh["model"], parallelize_plan=plan) # type: ignore + if sae.cfg.sae_type == "sae": + sae.device_mesh = device_mesh + plan = { + "encoder": ColwiseParallel(output_layouts=Replicate()), + "decoder": RowwiseParallel(input_layouts=Replicate()), + } + if sae.cfg.use_glu_encoder: + plan["encoder_glu"] = ColwiseParallel(output_layouts=Replicate()) + sae = parallelize_module(sae, device_mesh=device_mesh["model"], parallelize_plan=plan) # type: ignore + + elif sae.cfg.sae_type == "mixcoder": + warnings.warn("MixCoder is not supported for tensor parallel initialization.") return sae @torch.no_grad() @@ -67,6 +74,7 @@ def initialization_search(self, sae: SparseAutoEncoder, activation_batch: Dict[s activation_batch[sae.cfg.hook_point_in], activation_batch[sae.cfg.hook_point_out], ) + tokens = activation_batch["tokens"] if self.cfg.init_decoder_norm is None: assert sae.cfg.sparsity_include_decoder_norm, "Decoder norm must be included in sparsity loss" if not self.cfg.init_encoder_with_decoder_transpose or sae.cfg.hook_point_in != sae.cfg.hook_point_out: @@ -80,7 +88,7 @@ def grid_search_best_init_norm(search_range: List[float]) -> float: sae.init_encoder_with_decoder_transpose() if sae.cfg.sae_type == "crosscoder": sae.initialize_with_same_weight_across_layers() - mse = sae.compute_loss(activation_batch)[1][0]["l_rec"].mean().item() + mse = sae.compute_loss(activation_batch, tokens=tokens)[1][0]["l_rec"].mean().item() losses[norm] = mse best_norm = min(losses, key=losses.get) # type: ignore return best_norm @@ -97,7 +105,8 @@ def grid_search_best_init_norm(search_range: List[float]) -> float: sae.set_decoder_to_fixed_norm(best_norm_fine_grained, force_exact=True) - if self.cfg.bias_init_method == "geometric_median": + if self.cfg.bias_init_method == "geometric_median" and sae.cfg.sae_type != "mixcoder": + # TODO: add support for MixCoder sae.decoder.bias.data = ( sae.compute_norm_factor(activation_out, hook_point=sae.cfg.hook_point_out) * activation_out ).mean(0) @@ -116,9 +125,15 @@ def grid_search_best_init_norm(search_range: List[float]) -> float: @torch.no_grad() def initialize_jump_relu_threshold(self, sae: SparseAutoEncoder, activation_batch: 
Dict[str, Tensor]): + # TODO: add support for MixCoder + if sae.cfg.sae_type == "mixcoder": + warnings.warn("MixCoder is not supported for jump_relu_threshold initialization.") + return sae + activation_in = activation_batch[sae.cfg.hook_point_in] + tokens = activation_batch["tokens"] batch_size = activation_in.size(0) - _, hidden_pre = sae.encode(activation_in, return_hidden_pre=True) + _, hidden_pre = sae.encode(activation_in, return_hidden_pre=True, tokens=tokens) hidden_pre = torch.clamp(hidden_pre, min=0.0) hidden_pre = hidden_pre.flatten() threshold = hidden_pre.topk(k=batch_size * sae.cfg.top_k).values[-1] @@ -131,6 +146,7 @@ def initialize_sae_from_config( activation_stream: Iterable[dict[str, Tensor]] | None = None, activation_norm: dict[str, float] | None = None, device_mesh: DeviceMesh | None = None, + mixcoder_settings: dict[str, Any] | None = None, ): """ Initialize the SAE from the SAE config. @@ -141,14 +157,16 @@ def initialize_sae_from_config( device_mesh (DeviceMesh | None): The device mesh. """ sae = None # type: ignore - if isinstance(cfg, SAEConfig): + if cfg.sae_type == "sae": sae = SparseAutoEncoder.from_config(cfg) + elif cfg.sae_type == "mixcoder": + sae = MixCoder.from_config(cfg) else: # TODO: add support for different SAE config types, e.g. MixCoderConfig, CrossCoderConfig, etc. pass if self.cfg.state == "training": if cfg.sae_pretrained_name_or_path is None: - sae: SparseAutoEncoder = self.initialize_parameters(sae) + sae: SparseAutoEncoder = self.initialize_parameters(sae, mixcoder_settings=mixcoder_settings) if sae.cfg.norm_activation == "dataset-wise": if activation_norm is None: assert ( @@ -179,7 +197,8 @@ def initialize_sae_from_config( ), "Activation iterator must be provided for jump_relu_threshold initialization" activation_batch = next(iter(activation_stream)) self.initialize_jump_relu_threshold(sae, activation_batch) - sae.cfg.act_fn = "jumprelu" + if cfg.sae_type != "mixcoder": # TODO: add support for MixCoder + sae.cfg.act_fn = "jumprelu" sae = self.initialize_tensor_parallel(sae, device_mesh) return sae diff --git a/src/lm_saes/mixcoder.py b/src/lm_saes/mixcoder.py index c34a9d21..68777946 100644 --- a/src/lm_saes/mixcoder.py +++ b/src/lm_saes/mixcoder.py @@ -1,8 +1,370 @@ +import math +from typing import Literal, MutableMapping, Union, cast, overload + +import torch +from jaxtyping import Float +from torch import nn +from torch.distributed.tensor import DTensor +from transformer_lens.hook_points import HookPoint + from lm_saes.config import MixCoderConfig from lm_saes.sae import SparseAutoEncoder class MixCoder(SparseAutoEncoder): def __init__(self, cfg: MixCoderConfig): + """A multi-modal sparse autoencoder that handles different modalities with separate encoder/decoder pairs. + + This class extends the base SparseAutoEncoder to support multiple modalities, where each modality + has its own encoder and decoder, plus a shared component. The architecture allows for modality-specific + feature extraction while maintaining shared representations. 
+ + Attributes: + modality_index (dict[str, tuple[int, int]]): Maps modality names to their index ranges in the feature space + modality_indices (dict[str, torch.Tensor]): Maps modality names to their token indices + encoder (MutableMapping[str, nn.Linear]): Dictionary of encoders for each modality + decoder (MutableMapping[str, nn.Linear]): Dictionary of decoders for each modality + encoder_glu (MutableMapping[str, nn.Linear]): Optional GLU gates for encoders + """ super().__init__(cfg) - pass + # remove encoder and decoder initialized by super() + del self.encoder + del self.decoder + if cfg.use_glu_encoder: + del self.encoder_glu + + # initialize new encoder and decoder + self.cfg = cfg + self.modality_index = {} + self.modality_indices = {} + self.encoder = cast(MutableMapping[str, nn.Linear], nn.ModuleDict()) + self.decoder = cast(MutableMapping[str, nn.Linear], nn.ModuleDict()) + self.encoder_glu = cast(MutableMapping[str, nn.Linear], nn.ModuleDict()) + self.hook_hidden_pre = HookPoint() + self.hook_feature_acts = HookPoint() + self.hook_reconstructed = HookPoint() + for modality, d_modality in cfg.modalities.items(): + self.encoder[modality] = nn.Linear(cfg.d_model, d_modality, bias=True, device=cfg.device, dtype=cfg.dtype) + self.decoder[modality] = nn.Linear( + d_modality, cfg.d_model, bias=cfg.use_decoder_bias, device=cfg.device, dtype=cfg.dtype + ) + if cfg.use_glu_encoder: + self.encoder_glu[modality] = nn.Linear( + cfg.d_model, d_modality, bias=True, device=cfg.device, dtype=cfg.dtype + ) + + index = 0 + for modality, d_modality in cfg.modalities.items(): + if modality == "shared": + continue + self.modality_index[modality] = (index, index + d_modality) + index += d_modality + + assert index + cfg.modalities["shared"] == cfg.d_sae + self.modality_index["shared"] = (index, cfg.d_sae) + + @torch.no_grad() + def set_decoder_to_fixed_norm(self, value: float, force_exact: bool): + for modality in self.cfg.modalities.keys(): + self._set_decoder_to_fixed_norm(self.decoder[modality], value, force_exact) + + @torch.no_grad() + def set_encoder_to_fixed_norm(self, value: float): + for modality in self.cfg.modalities.keys(): + self._set_encoder_to_fixed_norm(self.encoder[modality], value) + + @torch.no_grad() + def _get_full_state_dict(self): + state_dict = self.state_dict() + if self.device_mesh and self.device_mesh["model"].size(0) > 1: + state_dict = {k: v.full_tensor() if isinstance(v, DTensor) else v for k, v in state_dict.items()} + + if self.dataset_average_activation_norm is not None: + for hook_point, value in self.dataset_average_activation_norm.items(): + state_dict[f"dataset_average_activation_norm.{hook_point}"] = torch.tensor(value) + + for modality, indices in self.modality_indices.items(): + state_dict[f"modality_indices.{modality}"] = indices + + if not self.cfg.sparsity_include_decoder_norm: + for modality in self.cfg.modalities.keys(): + state_dict[f"decoder.{modality}.weight"] = self.decoder[modality].weight.data.clone() + decoder_norm = torch.norm(state_dict[f"decoder.{modality}.weight"], p=2, dim=0, keepdim=True) + state_dict[f"decoder.{modality}.weight"] = state_dict[f"decoder.{modality}.weight"] / decoder_norm + + return cast(dict[str, torch.Tensor], state_dict) + + @torch.no_grad() + def _load_full_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None: + super()._load_full_state_dict(state_dict) + modality_indices_keys = [k for k in state_dict.keys() if k.startswith("modality_indices.")] + assert len(modality_indices_keys) == len(self.cfg.modalities) - 1 # 
shared modality is not included + self.modality_indices = {key.split(".", 1)[1]: state_dict[key] for key in modality_indices_keys} + + @torch.no_grad() + def transform_to_unit_decoder_norm(self): + for modality in self.cfg.modalities.keys(): + self._transform_to_unit_decoder_norm(self.encoder[modality], self.decoder[modality]) + + @torch.no_grad() + def standardize_parameters_of_dataset_norm(self, dataset_average_activation_norm: dict[str, float] | None): + assert self.cfg.norm_activation == "dataset-wise" + assert self.dataset_average_activation_norm is not None or dataset_average_activation_norm is not None + if dataset_average_activation_norm is not None: + self.set_dataset_average_activation_norm(dataset_average_activation_norm) + assert self.dataset_average_activation_norm is not None + input_norm_factor: float = ( + math.sqrt(self.cfg.d_model) / self.dataset_average_activation_norm[self.cfg.hook_point_in] + ) + output_norm_factor: float = ( + math.sqrt(self.cfg.d_model) / self.dataset_average_activation_norm[self.cfg.hook_point_out] + ) + + for modality in self.cfg.modalities.keys(): + self.encoder[modality].bias.data = self.encoder[modality].bias.data / input_norm_factor + if self.cfg.use_decoder_bias: + self.decoder[modality].bias.data = self.decoder[modality].bias.data / output_norm_factor + self.decoder[modality].weight.data = ( + self.decoder[modality].weight.data * input_norm_factor / output_norm_factor + ) + self.cfg.norm_activation = "inference" + + @torch.no_grad() + def init_encoder_with_decoder_transpose(self): + for modality in self.cfg.modalities.keys(): + self._init_encoder_with_decoder_transpose(self.encoder[modality], self.decoder[modality]) + + @torch.no_grad() + def log_statistics(self): + log_dict = {} + for modality in self.cfg.modalities.keys(): + log_dict[f"metrics/{modality}.encoder_norm"] = self._encoder_norm(self.encoder[modality]).mean().item() + log_dict[f"metrics/{modality}.encoder_bias_norm"] = self.encoder[modality].bias.norm().item() + log_dict[f"metrics/{modality}.decoder_norm"] = self._decoder_norm(self.decoder[modality]).mean().item() + if self.cfg.use_decoder_bias: + log_dict[f"metrics/{modality}.decoder_bias_norm"] = self.decoder[modality].bias.norm().item() + if "topk" in self.cfg.act_fn: + log_dict["sparsity/k"] = self.current_k + else: + log_dict["sparsity/l1_coefficient"] = self.current_l1_coefficient + return log_dict + + def get_modality_index(self) -> dict[str, tuple[int, int]]: + """Get the mapping from modality names to their index ranges in the feature space. + + Returns: + dict[str, tuple[int, int]]: A dictionary mapping modality names (e.g. 'text', 'image') + to tuples of (start_idx, end_idx) that define the index range in the feature space. + The shared modality should be the last one. + """ + return self.modality_index + + def _get_modality_activation( + self, + activation: Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ], + tokens: Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ], + modality: str, + ) -> Union[Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"]]: + """Get the activation of a specific modality. + + Args: + activation: The activation tensor to be masked + tokens: The token tensor to use for masking + modality: The name of the modality to get the activation for + + Returns: + The activation of the specified modality. The shape is the same as the input activation. 
+ """ + activation_mask = torch.isin(tokens, self.modality_indices[modality]) + return activation_mask.unsqueeze(1) * activation + + @overload + def encode( + self, + x: Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ], + return_hidden_pre: Literal[False] = False, + **kwargs, + ) -> Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + ]: ... + + @overload + def encode( + self, + x: Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ], + return_hidden_pre: Literal[True], + **kwargs, + ) -> tuple[ + Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + ], + Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + ], + ]: ... + + def encode( + self, + x: Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ], + return_hidden_pre: bool = False, + **kwargs, + ) -> Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + dict[str, Union[Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"]]], + tuple[ + Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + ], + Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + ], + ], + ]: + """Encodes input tensors into sparse feature representations for each modality. + + The encoding process: + 1. Separates input by modality using token masks + 2. Applies modality-specific encoding to each modality activation ('shared' modality included) + 3. Combines modality-specific and shared features + + Args: + x: Input tensor to encode + return_hidden_pre: If True, returns both feature activations and pre-activation values + **kwargs: Must contain 'tokens' for modality masking + + Returns: + Either feature activations alone, or tuple of (feature_acts, hidden_pre) if return_hidden_pre=True + """ + assert "tokens" in kwargs + tokens = kwargs["tokens"] + feature_acts = torch.zeros(x.shape[0], self.cfg.d_sae, device=x.device, dtype=x.dtype) + hidden_pre = torch.zeros(x.shape[0], self.cfg.d_sae, device=x.device, dtype=x.dtype) + input_norm_factor = self.compute_norm_factor(x, hook_point=self.cfg.hook_point_in) + x = x * input_norm_factor + for modality, (start, end) in self.modality_index.items(): + if modality == "shared": + # shared modality is not encoded directly but summed up during other modalities' encoding + continue + x_modality = self._get_modality_activation(x, tokens, modality) + if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder: + modality_bias = ( + self.decoder[modality].bias.to_local() # TODO: check if this is correct # type: ignore + if isinstance(self.decoder[modality].bias, DTensor) + else self.decoder[modality].bias + ) + shared_bias = ( + self.decoder["shared"].bias.to_local() + if isinstance(self.decoder["shared"].bias, DTensor) + else self.decoder["shared"].bias + ) + x_modality = x_modality - modality_bias - shared_bias + + hidden_pre_modality = self.encoder[modality](x_modality) + hidden_pre_shared = self.encoder["shared"](x_modality) + + if self.cfg.use_glu_encoder: + hidden_pre_modality_glu = torch.sigmoid(self.encoder_glu[modality](x_modality)) + hidden_pre_modality = hidden_pre_modality * hidden_pre_modality_glu + hidden_pre_shared_glu = torch.sigmoid(self.encoder_glu["shared"](x_modality)) + hidden_pre_shared = hidden_pre_shared 
* hidden_pre_shared_glu + + if self.cfg.sparsity_include_decoder_norm: + true_feature_acts_modality = hidden_pre_modality * self._decoder_norm(decoder=self.decoder[modality]) + true_feature_acts_shared = hidden_pre_shared * self._decoder_norm(decoder=self.decoder["shared"]) + else: + true_feature_acts_modality = hidden_pre_modality + true_feature_acts_shared = hidden_pre_shared + + true_feature_acts_concat = torch.cat([true_feature_acts_modality, true_feature_acts_shared], dim=1) + activation_mask_concat = self.activation_function(true_feature_acts_concat) + feature_acts_concat = true_feature_acts_concat * activation_mask_concat + + feature_acts_modality = feature_acts_concat[:, : self.cfg.modalities[modality]] + feature_acts_shared = feature_acts_concat[:, self.cfg.modalities[modality] :] + assert feature_acts_shared.shape[1] == self.cfg.modalities["shared"] + + feature_acts[:, start:end] += feature_acts_modality + hidden_pre[:, start:end] += hidden_pre_modality + + shared_start, shared_end = self.modality_index["shared"] + feature_acts[:, shared_start:shared_end] += feature_acts_shared + hidden_pre[:, shared_start:shared_end] += hidden_pre_shared + + hidden_pre = self.hook_hidden_pre(hidden_pre) + feature_acts = self.hook_feature_acts(feature_acts) + if return_hidden_pre: + return feature_acts, hidden_pre + return feature_acts + + def decode( + self, + feature_acts: Union[ + Float[torch.Tensor, "batch d_sae"], + Float[torch.Tensor, "batch seq_len d_sae"], + ], + **kwargs, + ) -> Union[ + Float[torch.Tensor, "batch d_model"], + Float[torch.Tensor, "batch seq_len d_model"], + ]: + reconstructed = torch.zeros( + feature_acts.shape[0], self.cfg.d_model, device=feature_acts.device, dtype=feature_acts.dtype + ) + for modality, (start, end) in self.modality_index.items(): + feature_acts_modality = feature_acts[:, start:end] # batch x d_modality + reconstructed_modality = self.decoder[modality](feature_acts_modality) # batch x d_model + reconstructed += reconstructed_modality + reconstructed = self.hook_reconstructed(reconstructed) + return reconstructed + + @torch.no_grad() + def init_parameters(self, **kwargs): + assert "modality_indices" in kwargs + modality_indices: dict[str, torch.Tensor] = kwargs["modality_indices"] + for modality in self.cfg.modalities.keys(): + torch.nn.init.kaiming_uniform_(self.encoder[modality].weight) + torch.nn.init.kaiming_uniform_(self.decoder[modality].weight) + torch.nn.init.zeros_(self.encoder[modality].bias) + if self.cfg.use_decoder_bias: + torch.nn.init.zeros_(self.decoder[modality].bias) + if self.cfg.use_glu_encoder: + torch.nn.init.kaiming_uniform_(self.encoder_glu[modality].weight) + torch.nn.init.zeros_(self.encoder_glu[modality].bias) + + for modality, indices in modality_indices.items(): + self.modality_indices[modality] = indices.to(self.cfg.device) + + @classmethod + def from_pretrained(cls, pretrained_name_or_path: str, strict_loading: bool = True, **kwargs): + cfg = MixCoderConfig.from_pretrained(pretrained_name_or_path, strict_loading=strict_loading, **kwargs) + return cls.from_config(cfg) + + def get_parameters(self): + params = [] + for modality in self.cfg.modalities.keys(): + modality_params = list(self.encoder[modality].parameters()) + list(self.decoder[modality].parameters()) + if self.cfg.use_glu_encoder: + modality_params += list(self.encoder_glu[modality].parameters()) + params.append({"params": modality_params, "modality": modality}) + return params diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index 7a8e667a..99ec3b73 100644 --- 
a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -6,6 +6,7 @@ import safetensors.torch as safe import torch +from typing import Any from jaxtyping import Float from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor @@ -107,8 +108,9 @@ def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Ten def topk_activation(x: torch.Tensor): x = torch.clamp(x, min=0.0) k = x.shape[-1] - self.current_k + 1 - k_th_value, _ = torch.kthvalue(x, k=k, dim=-1, keepdim=True) - return x.ge(k_th_value).to(x.dtype) + k_th_value, _ = torch.kthvalue(x, k=k, dim=-1) + k_th_value = k_th_value.unsqueeze(dim=-1) + return x.ge(k_th_value) return topk_activation @@ -121,7 +123,7 @@ def topk_activation(x: torch.Tensor): x = torch.clamp(x, min=0.0) k = x.numel() - self.current_k * batch_size + 1 k_th_value, _ = torch.kthvalue(x.flatten(), k=k, dim=-1) - return x.ge(k_th_value).to(x.dtype) + return x.ge(k_th_value) return topk_activation @@ -325,6 +327,7 @@ def encode( Float[torch.Tensor, "batch seq_len d_model"], ], return_hidden_pre: Literal[False] = False, + **kwargs, ) -> Union[Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"]]: ... @overload @@ -335,6 +338,7 @@ def encode( Float[torch.Tensor, "batch seq_len d_model"], ], return_hidden_pre: Literal[True], + **kwargs, ) -> tuple[ Union[ Float[torch.Tensor, "batch d_sae"], @@ -353,6 +357,7 @@ def encode( Float[torch.Tensor, "batch seq_len d_model"], ], return_hidden_pre: bool = False, + **kwargs, ) -> Union[ Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"], @@ -396,6 +401,7 @@ def decode( Float[torch.Tensor, "batch d_sae"], Float[torch.Tensor, "batch seq_len d_sae"], ], + **kwargs, ) -> Union[ Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"], @@ -416,8 +422,8 @@ def forward( Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"], ]: - feature_acts = self.encode(x) - reconstructed = self.decode(feature_acts) + feature_acts = self.encode(x, **kwargs) + reconstructed = self.decode(feature_acts, **kwargs) return reconstructed @overload @@ -428,6 +434,7 @@ def compute_loss( use_batch_norm_mse: bool = False, lp: int = 1, return_aux_data: Literal[True] = True, + **kwargs, ) -> tuple[ Float[torch.Tensor, " batch"], tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], @@ -441,6 +448,7 @@ def compute_loss( use_batch_norm_mse: bool = False, lp: int = 1, return_aux_data: Literal[False], + **kwargs, ) -> Float[torch.Tensor, " batch"]: ...
def compute_loss( @@ -457,6 +465,7 @@ def compute_loss( use_batch_norm_mse: bool = False, lp: int = 1, return_aux_data: bool = True, + **kwargs, ) -> Union[ Float[torch.Tensor, " batch"], tuple[ @@ -466,8 +475,8 @@ def compute_loss( ]: # may be overridden by subclasses x: torch.Tensor = batch[self.cfg.hook_point_in] label: torch.Tensor = batch[self.cfg.hook_point_out] - feature_acts, hidden_pre = self.encode(x, return_hidden_pre=True) - reconstructed = self.decode(feature_acts) + feature_acts, hidden_pre = self.encode(x, return_hidden_pre=True, **kwargs) + reconstructed = self.decode(feature_acts, **kwargs) label_norm_factor: torch.Tensor = self.compute_norm_factor(label, hook_point=self.cfg.hook_point_out) label_normed = label * label_norm_factor l_rec = (reconstructed - label_normed).pow(2) @@ -510,7 +519,7 @@ def _load_full_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None: self.load_state_dict(state_dict, strict=self.cfg.strict_loading) @classmethod - def from_config(cls, cfg: SAEConfig) -> "SparseAutoEncoder": + def from_config(cls, cfg: BaseSAEConfig) -> "SparseAutoEncoder": if cfg.sae_pretrained_name_or_path is None: return cls(cfg) path = parse_pretrained_name_or_path(cfg.sae_pretrained_name_or_path) @@ -571,3 +580,17 @@ def _init_encoder_with_decoder_transpose(self, encoder: torch.nn.Linear, decoder @torch.no_grad() def init_encoder_with_decoder_transpose(self): self._init_encoder_with_decoder_transpose(self.encoder, self.decoder) + + @torch.no_grad() + def init_parameters(self, **kwargs): + torch.nn.init.kaiming_uniform_(self.encoder.weight) + torch.nn.init.kaiming_uniform_(self.decoder.weight) + torch.nn.init.zeros_(self.encoder.bias) + if self.cfg.use_decoder_bias: + torch.nn.init.zeros_(self.decoder.bias) + if self.cfg.use_glu_encoder: + torch.nn.init.kaiming_uniform_(self.encoder_glu.weight) + torch.nn.init.zeros_(self.encoder_glu.bias) + + def get_parameters(self) -> list[dict[str, Any]]: + return [{"params": self.parameters()}] diff --git a/src/lm_saes/trainer.py b/src/lm_saes/trainer.py index ec8dc610..4b554cf5 100644 --- a/src/lm_saes/trainer.py +++ b/src/lm_saes/trainer.py @@ -58,7 +58,17 @@ def calculate_warmup_steps(warmup_steps: float | int) -> int: self.wandb_logger = wandb_logger def _initialize_optimizer(self, sae: SparseAutoEncoder): - optimizer = Adam(sae.parameters(), lr=self.cfg.lr, betas=self.cfg.betas) + # TODO: check if this is correct + if isinstance(self.cfg.lr, float): + optimizer = Adam(sae.get_parameters(), lr=self.cfg.lr, betas=self.cfg.betas) + else: + assert isinstance(self.cfg.lr, dict) + assert sae.cfg.sae_type == "mixcoder" + params = sae.get_parameters() + assert len(params) == len(self.cfg.lr) + for param_group in params: + param_group["lr"] = self.cfg.lr[param_group["modality"]] + optimizer = Adam(params, betas=self.cfg.betas) scheduler = get_scheduler( scheduler_name=self.cfg.lr_scheduler_name, optimizer=optimizer, @@ -82,14 +92,13 @@ def _training_step( ) elif self.k_warmup_steps > 0: assert self.cfg.initial_k is not None, "initial_k must be provided" + assert self.cfg.initial_k >= sae.cfg.top_k, "initial_k must be greater than or equal to top_k" sae.set_current_k( - math.ceil( - max( - 1.0, - self.cfg.initial_k - + (1 - self.cfg.initial_k) / self.k_warmup_steps * self.cur_step, # d_model / top_k - ) - * sae.cfg.top_k + max( + sae.cfg.top_k, + math.ceil( + self.cfg.initial_k + (sae.cfg.top_k - self.cfg.initial_k) / self.k_warmup_steps * self.cur_step, + ), ) ) @@ -98,6 +107,7 @@ def _training_step( lp=self.cfg.lp, 
use_batch_norm_mse=self.cfg.use_batch_norm_mse, return_aux_data=True, + tokens=batch["tokens"], ) loss_dict = {"loss": loss, "batch_size": batch[sae.cfg.hook_point_in].shape[0]} | loss_data | aux_data return loss_dict @@ -189,6 +199,7 @@ def fit( self.optimizer.zero_grad() loss_dict = self._training_step(sae, batch) loss_dict["loss"].backward() + # TODO: add support for mixcoder to use different clip_grad_norm for each modality loss_dict["grad_norm"] = torch.nn.utils.clip_grad_norm_( sae.parameters(), max_norm=self.cfg.clip_grad_norm if self.cfg.clip_grad_norm > 0 else math.inf, diff --git a/src/lm_saes/utils/misc.py b/src/lm_saes/utils/misc.py index ae26a221..a9f99dbe 100644 --- a/src/lm_saes/utils/misc.py +++ b/src/lm_saes/utils/misc.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist from torch.distributed.nn.functional import all_reduce +from transformers import PreTrainedTokenizerBase def is_master() -> bool: @@ -143,3 +144,22 @@ def calculate_activation_norm( for key in activation_norm: activation_norm[key] = activation_norm[key].mean().item() return activation_norm + + +def get_modality_indices(tokenizer: PreTrainedTokenizerBase, model_name: str) -> dict[str, torch.Tensor]: + modality_indices = {} + if model_name == "facebook/chameleon-7b": + for token_name, token_id in tokenizer.get_vocab().items(): + if token_name.startswith("IMGIMG"): + modality_indices["image"] = ( + [token_id] if "image" not in modality_indices else modality_indices["image"] + [token_id] + ) + else: + modality_indices["text"] = ( + [token_id] if "text" not in modality_indices else modality_indices["text"] + [token_id] + ) + else: + raise ValueError(f"Unsupported model: {model_name}") + for modality in modality_indices: + modality_indices[modality] = torch.tensor(modality_indices[modality], dtype=torch.long) + return modality_indices diff --git a/tests/integration/test_train_sae.py b/tests/integration/test_train_sae.py index 99b8c22b..49052905 100644 --- a/tests/integration/test_train_sae.py +++ b/tests/integration/test_train_sae.py @@ -5,10 +5,9 @@ from pytest_mock import MockerFixture from torch.distributed.device_mesh import init_device_mesh -if not torch.cuda.is_available(): - pytest.skip("CUDA is not available", allow_module_level=True) - -from lm_saes.config import InitializerConfig, SAEConfig, TrainerConfig +# if not torch.cuda.is_available(): +# pytest.skip("CUDA is not available", allow_module_level=True) +from lm_saes.config import InitializerConfig, MixCoderConfig, SAEConfig, TrainerConfig from lm_saes.initializer import Initializer from lm_saes.trainer import Trainer @@ -20,7 +19,7 @@ def sae_config() -> SAEConfig: hook_point_out="out", d_model=2, expansion_factor=2, - device="cuda", + device="cpu", dtype=torch.bfloat16, # the precision of bfloat16 is not enough for the tests act_fn="topk", norm_activation="dataset-wise", @@ -29,6 +28,22 @@ def sae_config() -> SAEConfig: ) +@pytest.fixture +def mixcoder_config() -> MixCoderConfig: + return MixCoderConfig( + hook_point_in="in", + hook_point_out="out", + d_model=2, + expansion_factor=2, + device="cpu", + dtype=torch.bfloat16, # the precision of bfloat16 is not enough for the tests + act_fn="topk", + norm_activation="dataset-wise", + top_k=2, + modalities={"image": 2, "text": 2, "shared": 2}, + ) + + @pytest.fixture def initializer_config() -> InitializerConfig: return InitializerConfig( @@ -43,7 +58,7 @@ def trainer_config(tmp_path) -> TrainerConfig: # Remove tmp path os.rmdir(tmp_path) return TrainerConfig( - initial_k=2, + initial_k=3, 
total_training_tokens=400, log_frequency=10, eval_frequency=10, @@ -74,6 +89,7 @@ def test_train_sae( { "in": torch.randn(4, 2, dtype=sae_config.dtype, device=sae_config.device), "out": torch.randn(4, 2, dtype=sae_config.dtype, device=sae_config.device), + "tokens": torch.tensor([2, 3, 4, 5], dtype=torch.long, device=sae_config.device), } for _ in range(200) ] @@ -90,3 +106,59 @@ def test_train_sae( eval_fn=lambda x: None, wandb_logger=wandb_runner, ) + + +def test_train_mixcoder( + mixcoder_config: MixCoderConfig, + initializer_config: InitializerConfig, + trainer_config: TrainerConfig, + mocker: MockerFixture, + tmp_path, +) -> None: + wandb_runner = mocker.Mock() + wandb_runner.log = lambda *args, **kwargs: None + device_mesh = ( + init_device_mesh( + device_type="cuda", + mesh_shape=(int(os.environ.get("WORLD_SIZE", 1)), 1), + mesh_dim_names=("data", "model"), + ) + if os.environ.get("WORLD_SIZE") is not None + else None + ) + activation_stream = [ + { + "in": torch.randn(4, 2, dtype=mixcoder_config.dtype, device=mixcoder_config.device), + "out": torch.randn(4, 2, dtype=mixcoder_config.dtype, device=mixcoder_config.device), + "tokens": torch.tensor([2, 3, 4, 5], dtype=torch.long, device=mixcoder_config.device), + } + for _ in range(200) + ] + initializer = Initializer(initializer_config) + tokenizer = mocker.Mock() + tokenizer.get_vocab.return_value = { + "IMGIMG1": 1, + "IMGIMG2": 2, + "IMGIMG3": 3, + "IMGIMG4": 4, + "TEXT1": 5, + "TEXT2": 6, + "TEXT3": 7, + "TEXT4": 8, + } + model_name = "facebook/chameleon-7b" + + mixcoder_settings = {"tokenizer": tokenizer, "model_name": model_name} + mixcoder = initializer.initialize_sae_from_config( + mixcoder_config, + device_mesh=device_mesh, + activation_stream=activation_stream, + mixcoder_settings=mixcoder_settings, + ) + trainer = Trainer(trainer_config) + trainer.fit( + sae=mixcoder, + activation_stream=activation_stream, + eval_fn=lambda x: None, + wandb_logger=wandb_runner, + ) diff --git a/tests/unit/test_evaluator.py b/tests/unit/test_evaluator.py index 61624d3b..dd969924 100644 --- a/tests/unit/test_evaluator.py +++ b/tests/unit/test_evaluator.py @@ -49,6 +49,7 @@ def stream_generator(): (batch_size, sae_config.d_model), device=sae_config.device, dtype=sae_config.dtype ) * (i + 1), + "tokens": torch.arange(start=1, end=5, device=sae_config.device, dtype=torch.int64), } return stream_generator() diff --git a/tests/unit/test_initializer.py b/tests/unit/test_initializer.py index ab2436af..ce600123 100644 --- a/tests/unit/test_initializer.py +++ b/tests/unit/test_initializer.py @@ -2,7 +2,7 @@ import torch from pytest_mock import MockerFixture -from lm_saes.config import InitializerConfig, SAEConfig +from lm_saes.config import InitializerConfig, MixCoderConfig, SAEConfig from lm_saes.initializer import Initializer @@ -19,10 +19,16 @@ def sae_config() -> SAEConfig: @pytest.fixture -def generator(sae_config: SAEConfig) -> torch.Generator: - gen = torch.Generator(device=sae_config.device) - gen.manual_seed(42) - return gen +def mixcoder_config() -> MixCoderConfig: + return MixCoderConfig( + hook_point_in="in", + hook_point_out="out", + d_model=2, + expansion_factor=2, + device="cpu", + dtype=torch.float32, # the precision of bfloat16 is not enough for the tests + modalities={"text": 4, "image": 4, "shared": 2}, + ) @pytest.fixture @@ -32,8 +38,11 @@ def initializer_config() -> InitializerConfig: ) -def test_initialize_sae_from_config(sae_config: SAEConfig, initializer_config: InitializerConfig): +def test_initialize_sae_from_config( + 
sae_config: SAEConfig, mixcoder_config: MixCoderConfig, initializer_config: InitializerConfig, mocker: MockerFixture +): initializer = Initializer(initializer_config) + initializer_config.state = "training" sae_config.norm_activation = "token-wise" sae = initializer.initialize_sae_from_config(sae_config) sae_config.norm_activation = "dataset-wise" @@ -50,10 +59,47 @@ def test_initialize_sae_from_config(sae_config: SAEConfig, initializer_config: I sae_config.act_fn = "topk" sae_config.jump_relu_threshold = 2.0 sae = initializer.initialize_sae_from_config(sae_config) + assert sae.cfg.act_fn == "jumprelu" + + initializer_config.state = "training" + tokenizer = mocker.Mock() + tokenizer.get_vocab.return_value = { + "IMGIMG1": 1, + "IMGIMG2": 2, + "IMGIMG3": 3, + "IMGIMG4": 4, + "TEXT1": 5, + "TEXT2": 6, + "TEXT3": 7, + "TEXT4": 8, + } + model_name = "facebook/chameleon-7b" + + mixcoder_settings = {"tokenizer": tokenizer, "model_name": model_name} + + mixcoder_config.norm_activation = "token-wise" + mixcoder = initializer.initialize_sae_from_config(mixcoder_config, mixcoder_settings=mixcoder_settings) + mixcoder_config.norm_activation = "dataset-wise" + mixcoder = initializer.initialize_sae_from_config( + mixcoder_config, activation_norm={"in": 3.0, "out": 2.0}, mixcoder_settings=mixcoder_settings + ) + assert mixcoder.dataset_average_activation_norm == {"in": 3.0, "out": 2.0} + + initializer_config.state = "inference" + mixcoder_config.norm_activation = "dataset-wise" + initializer = Initializer(initializer_config) + mixcoder = initializer.initialize_sae_from_config(mixcoder_config, activation_norm={"in": 3.0, "out": 2.0}) + assert mixcoder.cfg.norm_activation == "inference" + assert mixcoder.dataset_average_activation_norm == {"in": 3.0, "out": 2.0} + mixcoder_config.sparsity_include_decoder_norm = False + mixcoder_config.act_fn = "topk" + mixcoder_config.jump_relu_threshold = 2.0 + mixcoder = initializer.initialize_sae_from_config(mixcoder_config) + assert mixcoder.cfg.act_fn == "topk" def test_initialize_search( - mocker: MockerFixture, sae_config: SAEConfig, initializer_config: InitializerConfig, generator: torch.Generator + mocker: MockerFixture, sae_config: SAEConfig, mixcoder_config: MixCoderConfig, initializer_config: InitializerConfig ): def stream_generator(): # Create 10 batches of activations @@ -61,12 +107,12 @@ def stream_generator(): yield { "in": torch.ones(4, sae_config.d_model), # norm will be sqrt(16) "out": torch.ones(4, sae_config.d_model) * 2, # norm will be sqrt(16) * 2 + "tokens": torch.tensor([2, 3, 4, 5]), } sae_config.hook_point_out = sae_config.hook_point_in initializer_config.init_search = True initializer_config.l1_coefficient = 0.0008 - activation_stream_iter = mocker.Mock() activation_stream_iter = stream_generator() initializer = Initializer(initializer_config) sae = initializer.initialize_sae_from_config(sae_config, activation_stream=activation_stream_iter) @@ -77,3 +123,46 @@ def stream_generator(): sae_config.apply_decoder_bias_to_pre_encoder = False sae = initializer.initialize_sae_from_config(sae_config, activation_stream=activation_stream_iter) assert torch.allclose(sae._decoder_norm(sae.decoder), sae._decoder_norm(sae.decoder).mean(), atol=1e-4, rtol=1e-5) + + tokenizer = mocker.Mock() + tokenizer.get_vocab.return_value = { + "IMGIMG1": 1, + "IMGIMG2": 2, + "IMGIMG3": 3, + "IMGIMG4": 4, + "TEXT1": 5, + "TEXT2": 6, + "TEXT3": 7, + "TEXT4": 8, + } + model_name = "facebook/chameleon-7b" + + mixcoder_settings = {"tokenizer": tokenizer, "model_name": 
model_name} + + mixcoder_config.hook_point_out = mixcoder_config.hook_point_in + initializer_config.init_search = True + initializer_config.l1_coefficient = 0.0008 + activation_stream_iter = stream_generator() + initializer = Initializer(initializer_config) + mixcoder = initializer.initialize_sae_from_config( + mixcoder_config, mixcoder_settings=mixcoder_settings, activation_stream=activation_stream_iter + ) + assert torch.allclose( + mixcoder._decoder_norm(mixcoder.decoder["image"]), + mixcoder._decoder_norm(mixcoder.decoder["image"]).mean(), + atol=1e-4, + rtol=1e-5, + ) + + initializer_config.bias_init_method = "geometric_median" + initializer_config.init_encoder_with_decoder_transpose = True + mixcoder_config.apply_decoder_bias_to_pre_encoder = False + mixcoder = initializer.initialize_sae_from_config( + mixcoder_config, mixcoder_settings=mixcoder_settings, activation_stream=activation_stream_iter + ) + assert torch.allclose( + mixcoder._decoder_norm(mixcoder.decoder["text"]), + mixcoder._decoder_norm(mixcoder.decoder["text"]).mean(), + atol=1e-4, + rtol=1e-5, + ) diff --git a/tests/unit/test_mixcoder.py b/tests/unit/test_mixcoder.py new file mode 100644 index 00000000..5d342efe --- /dev/null +++ b/tests/unit/test_mixcoder.py @@ -0,0 +1,112 @@ +import pytest +import torch + +from lm_saes.config import MixCoderConfig +from lm_saes.mixcoder import MixCoder + + +@pytest.fixture +def config(): + return MixCoderConfig( + d_model=3, + modalities={"text": 2, "image": 3, "shared": 4}, + device="cpu", + dtype=torch.float32, + use_glu_encoder=False, + use_decoder_bias=True, + hook_point_in="hook_point_in", + hook_point_out="hook_point_out", + expansion_factor=1.0, + ) + + +@pytest.fixture +def modality_indices(): + return { + "text": torch.tensor([1, 2, 3, 4]), + "image": torch.tensor([5, 6, 7, 8]), + } + + +@pytest.fixture +def mixcoder(config, modality_indices): + model = MixCoder(config) + model.init_parameters(modality_indices=modality_indices) + return model + + +def test_init_parameters(mixcoder, config): + assert mixcoder.modality_index == {"text": (0, 2), "image": (2, 5), "shared": (5, 9)} + assert torch.allclose(mixcoder.modality_indices["text"], torch.tensor([1, 2, 3, 4])) + assert torch.allclose(mixcoder.modality_indices["image"], torch.tensor([5, 6, 7, 8])) + + +def test_encode_decode(mixcoder, config): + """Test the encoding and decoding process.""" + mixcoder.set_dataset_average_activation_norm({"hook_point_in": 1.0, "hook_point_out": 1.0}) + batch_size = 8 + x = torch.randn(batch_size, config.d_model) # batch, d_model + tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) + x_text = torch.cat([x[:4, :], torch.zeros(4, config.d_model)], dim=0) + x_image = torch.cat([torch.zeros(4, config.d_model), x[4:, :]], dim=0) + tokens_text = torch.tensor([1, 2, 3, 4, 0, 0, 0, 0]) + tokens_image = torch.tensor([0, 0, 0, 0, 5, 6, 7, 8]) + # Test encode + feature_acts = mixcoder.encode(x, tokens=tokens) + assert feature_acts.shape == (batch_size, config.d_sae) # batch, d_sae + + feature_acts_text = mixcoder.encode(x_text, tokens=tokens_text) + assert feature_acts_text.shape == (batch_size, config.d_sae) + feature_acts_image = mixcoder.encode(x_image, tokens=tokens_image) + assert feature_acts_image.shape == (batch_size, config.d_sae) + modality_index = mixcoder.get_modality_index() + + assert torch.allclose( + feature_acts_text[:4, slice(*modality_index["text"])], + feature_acts[:4, slice(*modality_index["text"])], + ) + assert torch.allclose( + feature_acts_image[4:, 
slice(*modality_index["image"])], + feature_acts[4:, slice(*modality_index["image"])], + ) + + assert torch.allclose( + torch.cat( + [ + feature_acts_text[:4, slice(*modality_index["shared"])], + feature_acts_image[4:, slice(*modality_index["shared"])], + ], + dim=0, + ), + feature_acts[:, slice(*modality_index["shared"])], + ) + + # Test decode + reconstructed = mixcoder.decode(feature_acts) + assert reconstructed.shape == (batch_size, config.d_model) + + reconstructed_text = mixcoder.decode(feature_acts_text) + assert reconstructed_text.shape == (batch_size, config.d_model) + + reconstructed_image = mixcoder.decode(feature_acts_image) + assert reconstructed_image.shape == (batch_size, config.d_model) + + assert torch.allclose(reconstructed_text[:4, :], reconstructed[:4, :]) + assert torch.allclose(reconstructed_image[4:, :], reconstructed[4:, :]) + + +def test_get_modality_activation(mixcoder, config): + """Test the _get_modality_activation method.""" + batch_size = 8 + x = torch.ones(batch_size, config.d_model) + tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) + + # Test text modality + text_activation = mixcoder._get_modality_activation(x, tokens, "text") + assert torch.all(text_activation[0, :4] == 1) # First 4 positions should be 1 + assert torch.all(text_activation[0, 4:] == 0) # Last 4 positions should be 0 + + # Test image modality + image_activation = mixcoder._get_modality_activation(x, tokens, "image") + assert torch.all(image_activation[1, :4] == 0) # First 4 positions should be 0 + assert torch.all(image_activation[1, 4:] == 1) # Last 4 positions should be 1
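The sketches below annotate the patch; none of them are part of the diff itself. First, a minimal illustration of how the new `MixCoderConfig.modalities` dict replaces `d_single_modal`/`d_shared`/`n_modalities`: `d_sae` is now the sum of the per-modality widths, and `MixCoder.__init__` lays the feature blocks out contiguously with the `shared` block last. The toy sizes here are the ones used in `test_mixcoder.py`.

```python
# Sketch of the feature-space layout implied by MixCoderConfig.modalities.
# Mirrors the index-building loop in MixCoder.__init__.
modalities = {"text": 2, "image": 3, "shared": 4}
d_sae = sum(modalities.values())  # replaces d_single_modal * n_modalities + d_shared

modality_index: dict[str, tuple[int, int]] = {}
index = 0
for modality, d_modality in modalities.items():
    if modality == "shared":
        continue  # the shared block is always appended last
    modality_index[modality] = (index, index + d_modality)
    index += d_modality
modality_index["shared"] = (index, d_sae)

assert modality_index == {"text": (0, 2), "image": (2, 5), "shared": (5, 9)}
```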
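`MixCoder.encode` routes each position to its modality's encoder via token-id masks. A self-contained sketch of what `_get_modality_activation` computes, using the toy token ids from the tests rather than real Chameleon ids:

```python
import torch

activation = torch.randn(8, 3)                    # (batch, d_model)
tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])   # (batch,)
image_token_ids = torch.tensor([5, 6, 7, 8])      # e.g. self.modality_indices["image"]

# Positions whose token id belongs to the modality keep their activation;
# all other positions are zeroed, so each encoder only sees "its" tokens.
mask = torch.isin(tokens, image_token_ids)        # (batch,) boolean
image_activation = mask.unsqueeze(1) * activation

assert torch.all(image_activation[:4] == 0)
assert torch.equal(image_activation[4:], activation[4:])
```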
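A subtlety in `MixCoder.encode`: the modality-specific and shared pre-activations are concatenated before the activation function runs, so a top-k activation selects jointly across both blocks, and the result is split back apart afterwards. A worked example with hypothetical sizes (3 modality features, 4 shared, k = 2):

```python
import torch

# Toy pre-activations for one token: 3 modality-specific + 4 shared features.
hidden_pre_modality = torch.tensor([[0.9, 0.0, 0.2]])
hidden_pre_shared = torch.tensor([[0.1, 1.5, 0.0, 0.3]])
k = 2

# Concatenate so top-k competes jointly across both blocks, then split back,
# mirroring the true_feature_acts_concat logic in MixCoder.encode.
concat = torch.cat([hidden_pre_modality, hidden_pre_shared], dim=1)
kth = concat.kthvalue(concat.shape[-1] - k + 1, dim=-1).values
acts = concat * concat.ge(kth.unsqueeze(-1))

acts_modality, acts_shared = acts[:, :3], acts[:, 3:]
assert torch.equal(acts_modality, torch.tensor([[0.9, 0.0, 0.0]]))
assert torch.equal(acts_shared, torch.tensor([[0.0, 1.5, 0.0, 0.0]]))
```

The joint selection means a strongly active shared feature can displace a weak modality-specific one, rather than each block keeping its own quota.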
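The rewritten `topk_activation` in `sae.py` is semantically the old `keepdim=True` version split into two steps, now returning the boolean mask uncast (multiplying activations by a bool mask upcasts automatically, and `unsqueeze(dim=-1)` keeps the broadcast correct for both 2-D and 3-D inputs). A worked example assuming `current_k = 2`:

```python
import torch

x = torch.tensor([[0.5, -1.0, 2.0, 0.1]])
current_k = 2

x = torch.clamp(x, min=0.0)
k = x.shape[-1] - current_k + 1               # k-th smallest == current_k-th largest
k_th_value, _ = torch.kthvalue(x, k=k, dim=-1)
mask = x.ge(k_th_value.unsqueeze(dim=-1))     # boolean mask over the top-k entries

assert mask.tolist() == [[True, False, True, False]]
assert torch.equal(x * mask, torch.tensor([[0.5, 0.0, 2.0, 0.0]]))
```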
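The k-warmup fix in `Trainer._training_step` changes `current_k` from a multiplier on `top_k` to an absolute value annealed linearly from `initial_k` down to `top_k`, which is why the new assertion requires `initial_k >= top_k` and the test config moves `initial_k` from 2 to 3. Roughly:

```python
import math

def current_k(step: int, initial_k: int, top_k: int, k_warmup_steps: int) -> int:
    # Linear anneal from initial_k down to top_k, clamped at top_k.
    return max(top_k, math.ceil(initial_k + (top_k - initial_k) / k_warmup_steps * step))

assert current_k(0, initial_k=3, top_k=2, k_warmup_steps=100) == 3
assert current_k(100, initial_k=3, top_k=2, k_warmup_steps=100) == 2
assert current_k(200, initial_k=3, top_k=2, k_warmup_steps=100) == 2  # stays clamped
```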
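Finally, `TrainerConfig.lr` may now be a dict keyed by modality: `MixCoder.get_parameters()` tags each param group with its modality name, and `_initialize_optimizer` copies the matching rate into the group before constructing Adam. A sketch with stand-in parameters (the rates here are made up):

```python
import torch
from torch.optim import Adam

lr = {"text": 4e-4, "image": 2e-4, "shared": 1e-4}

# Stand-ins for the groups returned by MixCoder.get_parameters(); extra keys
# such as "modality" are preserved on torch param groups.
params = [
    {"params": [torch.nn.Parameter(torch.randn(2, 2))], "modality": name}
    for name in lr
]
for group in params:
    group["lr"] = lr[group["modality"]]       # per-modality learning rate
optimizer = Adam(params, betas=(0.9, 0.999))  # each group carries its own lr

assert [g["lr"] for g in optimizer.param_groups] == [4e-4, 2e-4, 1e-4]
```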