diff --git a/src/lm_saes/mixcoder.py b/src/lm_saes/mixcoder.py
index eda771d4..69909455 100644
--- a/src/lm_saes/mixcoder.py
+++ b/src/lm_saes/mixcoder.py
@@ -159,12 +159,8 @@ def get_modality_index(self) -> dict[str, tuple[int, int]]:
         """
         return self.modality_index
-    def _get_modality_activation_mask(
+    def get_modality_token_mask(
         self,
-        activation: Union[
-            Float[torch.Tensor, "batch d_model"],
-            Float[torch.Tensor, "batch seq_len d_model"],
-        ],
         tokens: Union[
             Float[torch.Tensor, "batch d_model"],
             Float[torch.Tensor, "batch seq_len d_model"],
         ],
@@ -182,7 +178,7 @@ def _get_modality_activation_mask(
             The activation of the specified modality. The shape is the same as the input activation.
         """
         activation_mask = torch.isin(tokens, self.modality_indices[modality])
-        return activation_mask.unsqueeze(1)
+        return activation_mask
 
     @overload
     def encode(
@@ -266,7 +262,7 @@ def encode(
             if modality == "shared":
                 # shared modality is not encoded directly but summed up during other modalities' encoding
                 continue
-            activation_mask = self._get_modality_activation_mask(x, tokens, modality)
+            activation_mask = self.get_modality_token_mask(tokens, modality).unsqueeze(1)
             if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
                 modality_bias = (
                     self.decoder[modality].bias.to_local()  # TODO: check if this is correct  # type: ignore
diff --git a/src/lm_saes/trainer.py b/src/lm_saes/trainer.py
index 28a78a77..ba7d355b 100644
--- a/src/lm_saes/trainer.py
+++ b/src/lm_saes/trainer.py
@@ -187,13 +187,26 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor])
         if sae.cfg.sae_type == "mixcoder":
             assert isinstance(sae, MixCoder)
             for modality, (start, end) in sae.modality_index.items():
+                if modality == "shared":
+                    continue
+                shared_start, shared_end = sae.modality_index["shared"]
+                mask = sae.get_modality_token_mask(batch["tokens"], modality)
+                token_num = mask.sum().item()
                 wandb_log_dict.update(
                     {
-                        f"metrics/{modality}_l0": (log_info["feature_acts"][:, start:end] > 0)
+                        f"l0_metrics/{modality}_l0": (log_info["feature_acts"][mask][:, start:end] > 0)
                         .float()
                         .sum(-1)
                         .mean()
                         .item(),
+                        f"l0_metrics/{modality}_shared_l0": (
+                            log_info["feature_acts"][mask][:, shared_start:shared_end] > 0
+                        )
+                        .float()
+                        .sum(-1)
+                        .mean()
+                        .item(),
+                        f"l0_metrics/{modality}_token_num": token_num,
                     }
                 )
 
@@ -260,7 +273,7 @@ def fit(
             activation_in, activation_out = batch[sae.cfg.hook_point_in], batch[sae.cfg.hook_point_out]
 
             if self.wandb_logger is not None:
-                self._log(sae, log_info, {"input": activation_in, "output": activation_out})
+                self._log(sae, log_info, {"input": activation_in, "output": activation_out, "tokens": batch["tokens"]})
 
             if eval_fn is not None and (self.cur_step + 1) % self.cfg.eval_frequency == 0:
                 eval_fn(sae)
diff --git a/tests/unit/test_mixcoder.py b/tests/unit/test_mixcoder.py
index 41308e53..4aaa051b 100644
--- a/tests/unit/test_mixcoder.py
+++ b/tests/unit/test_mixcoder.py
@@ -106,16 +106,14 @@ def test_encode_decode(mixcoder, config):
 
 def test_get_modality_activation_mask(mixcoder, config):
     """Test the _get_modality_activation method."""
-    batch_size = 8
-    x = torch.ones(batch_size, config.d_model)
     tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
 
     # Test text modality
-    text_activation_mask = mixcoder._get_modality_activation_mask(x, tokens, "text")
-    assert torch.all(text_activation_mask[0, :4] == 1)  # First 4 positions should be 1
-    assert torch.all(text_activation_mask[0, 4:] == 0)  # Last 4 positions should be 0
+    text_activation_mask = mixcoder.get_modality_token_mask(tokens, "text")
+    assert torch.all(text_activation_mask[:4] == 1)  # First 4 positions should be 1
+    assert torch.all(text_activation_mask[4:] == 0)  # Last 4 positions should be 0
 
     # Test image modality
-    image_activation_mask = mixcoder._get_modality_activation_mask(x, tokens, "image")
-    assert torch.all(image_activation_mask[1, :4] == 0)  # First 4 positions should be 0
-    assert torch.all(image_activation_mask[1, 4:] == 1)  # Last 4 positions should be 1
+    image_activation_mask = mixcoder.get_modality_token_mask(tokens, "image")
+    assert torch.all(image_activation_mask[:4] == 0)  # First 4 positions should be 0
+    assert torch.all(image_activation_mask[4:] == 1)  # Last 4 positions should be 1
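
For reference, a minimal sketch of what the renamed helper returns and how the new l0_metrics/* entries are computed from it. The token-id split, tensor shapes, and feature ranges below are illustrative assumptions modeled on the unit test above, not values read from the repository.

import torch

# Assume token ids 1-4 are "text" and 5-8 are "image", as in test_mixcoder.py.
tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
text_ids = torch.tensor([1, 2, 3, 4])  # hypothetical stand-in for modality_indices["text"]

# get_modality_token_mask reduces to torch.isin over the modality's token ids,
# so the mask now has the same shape as `tokens` (no extra unsqueezed dimension).
text_mask = torch.isin(tokens, text_ids)  # shape (8,), dtype bool

# encode() re-adds the broadcast dimension itself:
#     activation_mask = self.get_modality_token_mask(tokens, modality).unsqueeze(1)

# The trainer's new metrics select the masked tokens first, then slice the
# modality's feature range, mirroring the updated _log():
feature_acts = torch.rand(8, 16)  # hypothetical (tokens, d_sae) activations
start, end = 0, 8                 # hypothetical feature range for "text"
text_l0 = (feature_acts[text_mask][:, start:end] > 0).float().sum(-1).mean().item()
text_token_num = text_mask.sum().item()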