10 changes: 3 additions & 7 deletions src/lm_saes/mixcoder.py
@@ -159,12 +159,8 @@ def get_modality_index(self) -> dict[str, tuple[int, int]]:
         """
         return self.modality_index

-    def _get_modality_activation_mask(
+    def get_modality_token_mask(
         self,
-        activation: Union[
-            Float[torch.Tensor, "batch d_model"],
-            Float[torch.Tensor, "batch seq_len d_model"],
-        ],
         tokens: Union[
             Float[torch.Tensor, "batch d_model"],
             Float[torch.Tensor, "batch seq_len d_model"],
@@ -182,7 +178,7 @@ def _get_modality_activation_mask(
             The activation of the specified modality. The shape is the same as the input activation.
         """
         activation_mask = torch.isin(tokens, self.modality_indices[modality])
-        return activation_mask.unsqueeze(1)
+        return activation_mask

     @overload
     def encode(
@@ -266,7 +262,7 @@ def encode(
             if modality == "shared":
                 # shared modality is not encoded directly but summed up during other modalities' encoding
                 continue
-            activation_mask = self._get_modality_activation_mask(x, tokens, modality)
+            activation_mask = self.get_modality_token_mask(tokens, modality).unsqueeze(1)
             if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
                 modality_bias = (
                     self.decoder[modality].bias.to_local()  # TODO: check if this is correct  # type: ignore
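For readers skimming the diff: a minimal, self-contained sketch of what the renamed helper now returns and how the `encode()` call site compensates for the dropped `unsqueeze(1)`. The concrete token-id sets and the flat `(n_tokens,)` token shape below are assumptions for illustration, not code from the repository.

```python
import torch

# Hypothetical per-modality token-id sets, standing in for self.modality_indices.
modality_indices = {
    "text": torch.tensor([1, 2, 3, 4]),
    "image": torch.tensor([5, 6, 7, 8]),
}


def get_modality_token_mask(tokens: torch.Tensor, modality: str) -> torch.Tensor:
    # Boolean mask with the same shape as `tokens`: True where the token id
    # belongs to the requested modality. No extra dimension is added here anymore.
    return torch.isin(tokens, modality_indices[modality])


tokens = torch.tensor([1, 5, 2, 6])  # one token id per activation row
text_mask = get_modality_token_mask(tokens, "text")  # tensor([ True, False,  True, False])

# encode() now adds the feature dimension itself, so the mask still broadcasts
# against a (n_tokens, d_model) activation tensor:
activation_mask = text_mask.unsqueeze(1)  # shape (4, 1)
```

Returning the raw mask is what lets the trainer reuse the same helper for token counting, while `encode()` keeps its previous broadcasting behaviour by unsqueezing at the call site.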
17 changes: 15 additions & 2 deletions src/lm_saes/trainer.py
@@ -187,13 +187,26 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor])
         if sae.cfg.sae_type == "mixcoder":
             assert isinstance(sae, MixCoder)
             for modality, (start, end) in sae.modality_index.items():
+                if modality == "shared":
+                    continue
+                shared_start, shared_end = sae.modality_index["shared"]
+                mask = sae.get_modality_token_mask(batch["tokens"], modality)
+                token_num = mask.sum().item()
                 wandb_log_dict.update(
                     {
-                        f"metrics/{modality}_l0": (log_info["feature_acts"][:, start:end] > 0)
+                        f"l0_metrics/{modality}_l0": (log_info["feature_acts"][mask][:, start:end] > 0)
                         .float()
                         .sum(-1)
                         .mean()
                         .item(),
+                        f"l0_metrics/{modality}_shared_l0": (
+                            log_info["feature_acts"][mask][:, shared_start:shared_end] > 0
+                        )
+                        .float()
+                        .sum(-1)
+                        .mean()
+                        .item(),
+                        f"l0_metrics/{modality}_token_num": token_num,
                     }
                 )

@@ -260,7 +273,7 @@ def fit(
             activation_in, activation_out = batch[sae.cfg.hook_point_in], batch[sae.cfg.hook_point_out]

             if self.wandb_logger is not None:
-                self._log(sae, log_info, {"input": activation_in, "output": activation_out})
+                self._log(sae, log_info, {"input": activation_in, "output": activation_out, "tokens": batch["tokens"]})

             if eval_fn is not None and (self.cur_step + 1) % self.cfg.eval_frequency == 0:
                 eval_fn(sae)
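Taken together, the two hunks above make per-modality sparsity observable: `fit()` now passes the batch's tokens into `_log`, and `_log` restricts feature activations to each modality's tokens before measuring L0. The function below is a hedged paraphrase of that logging logic, not the repository's actual code; the shapes (`feature_acts` as `(n_tokens, d_sae)`, `tokens` as `(n_tokens,)`) and the small demo values are assumptions.

```python
import torch


def modality_l0_metrics(feature_acts, tokens, modality_index, get_mask):
    # Per-modality L0 statistics, mirroring the keys introduced in this diff.
    metrics = {}
    shared_start, shared_end = modality_index["shared"]
    for modality, (start, end) in modality_index.items():
        if modality == "shared":
            continue
        mask = get_mask(tokens, modality)  # (n_tokens,) bool
        acts = feature_acts[mask]          # keep only this modality's tokens
        # L0 over the modality's own feature slice.
        metrics[f"l0_metrics/{modality}_l0"] = (
            (acts[:, start:end] > 0).float().sum(-1).mean().item()
        )
        # L0 over the shared feature slice, measured on the same tokens.
        metrics[f"l0_metrics/{modality}_shared_l0"] = (
            (acts[:, shared_start:shared_end] > 0).float().sum(-1).mean().item()
        )
        # How many tokens of this modality were present in the batch.
        # (If a modality is absent, mean() over zero rows yields NaN; whether the
        # real trainer guards against this is not shown in the diff.)
        metrics[f"l0_metrics/{modality}_token_num"] = mask.sum().item()
    return metrics


# Tiny demo with made-up numbers (two tokens, d_sae = 4, shared slice = features 2-3):
feature_acts = torch.tensor([[0.0, 1.0, 0.0, 2.0],
                             [3.0, 0.0, 0.0, 0.0]])
tokens = torch.tensor([1, 5])
modality_index = {"text": (0, 1), "image": (1, 2), "shared": (2, 4)}
modality_ids = {"text": torch.tensor([1]), "image": torch.tensor([5])}


def get_mask(tokens, modality):
    return torch.isin(tokens, modality_ids[modality])


print(modality_l0_metrics(feature_acts, tokens, modality_index, get_mask))
# {'l0_metrics/text_l0': 0.0, 'l0_metrics/text_shared_l0': 1.0, 'l0_metrics/text_token_num': 1, ...}
```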
14 changes: 6 additions & 8 deletions tests/unit/test_mixcoder.py
@@ -106,16 +106,14 @@ def test_encode_decode(mixcoder, config):

 def test_get_modality_activation_mask(mixcoder, config):
     """Test the _get_modality_activation method."""
-    batch_size = 8
-    x = torch.ones(batch_size, config.d_model)
     tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])

     # Test text modality
-    text_activation_mask = mixcoder._get_modality_activation_mask(x, tokens, "text")
-    assert torch.all(text_activation_mask[0, :4] == 1)  # First 4 positions should be 1
-    assert torch.all(text_activation_mask[0, 4:] == 0)  # Last 4 positions should be 0
+    text_activation_mask = mixcoder.get_modality_token_mask(tokens, "text")
+    assert torch.all(text_activation_mask[:4] == 1)  # First 4 positions should be 1
+    assert torch.all(text_activation_mask[4:] == 0)  # Last 4 positions should be 0

     # Test image modality
-    image_activation_mask = mixcoder._get_modality_activation_mask(x, tokens, "image")
-    assert torch.all(image_activation_mask[1, :4] == 0)  # First 4 positions should be 0
-    assert torch.all(image_activation_mask[1, 4:] == 1)  # Last 4 positions should be 1
+    image_activation_mask = mixcoder.get_modality_token_mask(tokens, "image")
+    assert torch.all(image_activation_mask[:4] == 0)  # First 4 positions should be 0
+    assert torch.all(image_activation_mask[4:] == 1)  # Last 4 positions should be 1
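A quick sanity check of what the rewritten assertions expect, assuming the test fixture assigns token ids 1-4 to the text modality and 5-8 to the image modality (inferred from the asserts above, not confirmed against the fixture code):

```python
import torch

tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
text_ids = torch.tensor([1, 2, 3, 4])    # hypothetical text token ids for the fixture
image_ids = torch.tensor([5, 6, 7, 8])   # hypothetical image token ids for the fixture

print(torch.isin(tokens, text_ids))   # tensor([ True,  True,  True,  True, False, False, False, False])
print(torch.isin(tokens, image_ids))  # tensor([False, False, False, False,  True,  True,  True,  True])
```

Because the mask is now one-dimensional over tokens (rather than unsqueezed per batch row), the assertions index positions directly instead of indexing a batch dimension first.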