From 4f5b8cafa5bb5ac01783cbc1f372b0b50bfffd4e Mon Sep 17 00:00:00 2001 From: frankstein Date: Thu, 23 Jan 2025 13:21:14 +0800 Subject: [PATCH 1/7] fix(runner): support mixcoder training --- src/lm_saes/__init__.py | 2 ++ src/lm_saes/runner.py | 28 ++++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/lm_saes/__init__.py b/src/lm_saes/__init__.py index 657a94b0..0f78b5a2 100644 --- a/src/lm_saes/__init__.py +++ b/src/lm_saes/__init__.py @@ -11,6 +11,7 @@ FeatureAnalyzerConfig, InitializerConfig, LanguageModelConfig, + MixCoderConfig, MongoDBConfig, SAEConfig, TrainerConfig, @@ -54,4 +55,5 @@ "FeatureAnalyzerConfig", "MongoDBConfig", "MongoClient", + "MixCoderConfig", ] diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index c16f3e68..8cacdb7b 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -6,6 +6,7 @@ from pydantic import model_validator from pydantic_settings import BaseSettings, SettingsConfigDict from torch.distributed.device_mesh import init_device_mesh +from transformers import AutoTokenizer from lm_saes.activation.factory import ActivationFactory from lm_saes.activation.writer import ActivationWriter @@ -256,6 +257,9 @@ class TrainSAESettings(BaseSettings): mongo: Optional[MongoDBConfig] = None """Configuration for MongoDB""" + model_name: Optional[str] = None + """Name of the tokenizer to load. Mixcoder requires a tokenizer to get the modality indices.""" + def train_sae(settings: TrainSAESettings) -> None: """Train a SAE model. @@ -276,9 +280,24 @@ def train_sae(settings: TrainSAESettings) -> None: activation_factory = ActivationFactory(settings.activation_factory) activations_stream = activation_factory.process() initializer = Initializer(settings.initializer) - sae = initializer.initialize_sae_from_config( - settings.sae, activation_stream=activations_stream, device_mesh=device_mesh - ) + + if settings.sae.sae_type == "mixcoder": + assert settings.model_name is not None, "Model name is required for mixcoder SAE" + tokenizer = AutoTokenizer.from_pretrained(settings.model_name, trust_remote_code=True) + mixcoder_settings = { + "model_name": settings.model_name, + "tokenizer": tokenizer, + } + sae = initializer.initialize_sae_from_config( + settings.sae, + activation_stream=activations_stream, + device_mesh=device_mesh, + mixcoder_settings=mixcoder_settings, + ) + else: + sae = initializer.initialize_sae_from_config( + settings.sae, activation_stream=activations_stream, device_mesh=device_mesh + ) wandb_logger = ( wandb.init( @@ -289,7 +308,8 @@ def train_sae(settings: TrainSAESettings) -> None: settings=wandb.Settings(x_disable_stats=True), mode=os.getenv("WANDB_MODE", "online"), ) - if settings.wandb is not None and (device_mesh is None or device_mesh.get_rank() == 0) else None + if settings.wandb is not None and (device_mesh is None or device_mesh.get_rank() == 0) + else None ) if wandb_logger is not None: wandb_logger.watch(sae, log="all") From 47aaace396499553b43dd6b458a32e36c148b2f1 Mon Sep 17 00:00:00 2001 From: frankstein Date: Thu, 23 Jan 2025 13:21:14 +0800 Subject: [PATCH 2/7] fix(runner): support mixcoder training --- src/lm_saes/__init__.py | 2 ++ src/lm_saes/runner.py | 28 ++++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/lm_saes/__init__.py b/src/lm_saes/__init__.py index 657a94b0..0f78b5a2 100644 --- a/src/lm_saes/__init__.py +++ b/src/lm_saes/__init__.py @@ -11,6 +11,7 @@ FeatureAnalyzerConfig, InitializerConfig, LanguageModelConfig, + MixCoderConfig, 
MongoDBConfig, SAEConfig, TrainerConfig, @@ -54,4 +55,5 @@ "FeatureAnalyzerConfig", "MongoDBConfig", "MongoClient", + "MixCoderConfig", ] diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py index c16f3e68..8cacdb7b 100644 --- a/src/lm_saes/runner.py +++ b/src/lm_saes/runner.py @@ -6,6 +6,7 @@ from pydantic import model_validator from pydantic_settings import BaseSettings, SettingsConfigDict from torch.distributed.device_mesh import init_device_mesh +from transformers import AutoTokenizer from lm_saes.activation.factory import ActivationFactory from lm_saes.activation.writer import ActivationWriter @@ -256,6 +257,9 @@ class TrainSAESettings(BaseSettings): mongo: Optional[MongoDBConfig] = None """Configuration for MongoDB""" + model_name: Optional[str] = None + """Name of the tokenizer to load. Mixcoder requires a tokenizer to get the modality indices.""" + def train_sae(settings: TrainSAESettings) -> None: """Train a SAE model. @@ -276,9 +280,24 @@ def train_sae(settings: TrainSAESettings) -> None: activation_factory = ActivationFactory(settings.activation_factory) activations_stream = activation_factory.process() initializer = Initializer(settings.initializer) - sae = initializer.initialize_sae_from_config( - settings.sae, activation_stream=activations_stream, device_mesh=device_mesh - ) + + if settings.sae.sae_type == "mixcoder": + assert settings.model_name is not None, "Model name is required for mixcoder SAE" + tokenizer = AutoTokenizer.from_pretrained(settings.model_name, trust_remote_code=True) + mixcoder_settings = { + "model_name": settings.model_name, + "tokenizer": tokenizer, + } + sae = initializer.initialize_sae_from_config( + settings.sae, + activation_stream=activations_stream, + device_mesh=device_mesh, + mixcoder_settings=mixcoder_settings, + ) + else: + sae = initializer.initialize_sae_from_config( + settings.sae, activation_stream=activations_stream, device_mesh=device_mesh + ) wandb_logger = ( wandb.init( @@ -289,7 +308,8 @@ def train_sae(settings: TrainSAESettings) -> None: settings=wandb.Settings(x_disable_stats=True), mode=os.getenv("WANDB_MODE", "online"), ) - if settings.wandb is not None and (device_mesh is None or device_mesh.get_rank() == 0) else None + if settings.wandb is not None and (device_mesh is None or device_mesh.get_rank() == 0) + else None ) if wandb_logger is not None: wandb_logger.watch(sae, log="all") From 722ebf2345ee062bf149e32f817abe5001c7dcda Mon Sep 17 00:00:00 2001 From: frankstein Date: Thu, 23 Jan 2025 14:30:37 +0800 Subject: [PATCH 3/7] fix(mixcoder): fix topk activation func --- src/lm_saes/mixcoder.py | 21 ++++++++++++--------- src/lm_saes/sae.py | 1 - tests/unit/test_mixcoder.py | 25 +++++++++++++++++-------- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/lm_saes/mixcoder.py b/src/lm_saes/mixcoder.py index 68777946..eda771d4 100644 --- a/src/lm_saes/mixcoder.py +++ b/src/lm_saes/mixcoder.py @@ -159,7 +159,7 @@ def get_modality_index(self) -> dict[str, tuple[int, int]]: """ return self.modality_index - def _get_modality_activation( + def _get_modality_activation_mask( self, activation: Union[ Float[torch.Tensor, "batch d_model"], @@ -182,7 +182,7 @@ def _get_modality_activation( The activation of the specified modality. The shape is the same as the input activation. 
""" activation_mask = torch.isin(tokens, self.modality_indices[modality]) - return activation_mask.unsqueeze(1) * activation + return activation_mask.unsqueeze(1) @overload def encode( @@ -266,7 +266,7 @@ def encode( if modality == "shared": # shared modality is not encoded directly but summed up during other modalities' encoding continue - x_modality = self._get_modality_activation(x, tokens, modality) + activation_mask = self._get_modality_activation_mask(x, tokens, modality) if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder: modality_bias = ( self.decoder[modality].bias.to_local() # TODO: check if this is correct # type: ignore @@ -278,15 +278,15 @@ def encode( if isinstance(self.decoder["shared"].bias, DTensor) else self.decoder["shared"].bias ) - x_modality = x_modality - modality_bias - shared_bias + x = x - modality_bias - shared_bias - hidden_pre_modality = self.encoder[modality](x_modality) - hidden_pre_shared = self.encoder["shared"](x_modality) + hidden_pre_modality = self.encoder[modality](x) + hidden_pre_shared = self.encoder["shared"](x) if self.cfg.use_glu_encoder: - hidden_pre_modality_glu = torch.sigmoid(self.encoder_glu[modality](x_modality)) + hidden_pre_modality_glu = torch.sigmoid(self.encoder_glu[modality](x)) hidden_pre_modality = hidden_pre_modality * hidden_pre_modality_glu - hidden_pre_shared_glu = torch.sigmoid(self.encoder_glu["shared"](x_modality)) + hidden_pre_shared_glu = torch.sigmoid(self.encoder_glu["shared"](x)) hidden_pre_shared = hidden_pre_shared * hidden_pre_shared_glu if self.cfg.sparsity_include_decoder_norm: @@ -296,7 +296,9 @@ def encode( true_feature_acts_modality = hidden_pre_modality true_feature_acts_shared = hidden_pre_shared - true_feature_acts_concat = torch.cat([true_feature_acts_modality, true_feature_acts_shared], dim=1) + true_feature_acts_concat = ( + torch.cat([true_feature_acts_modality, true_feature_acts_shared], dim=1) * activation_mask + ) activation_mask_concat = self.activation_function(true_feature_acts_concat) feature_acts_concat = true_feature_acts_concat * activation_mask_concat @@ -313,6 +315,7 @@ def encode( hidden_pre = self.hook_hidden_pre(hidden_pre) feature_acts = self.hook_feature_acts(feature_acts) + # assert torch.all((feature_acts > 0).sum(-1) <= self.current_k) if return_hidden_pre: return feature_acts, hidden_pre return feature_acts diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py index f59022a7..8148a3a7 100644 --- a/src/lm_saes/sae.py +++ b/src/lm_saes/sae.py @@ -109,7 +109,6 @@ def topk_activation(x: torch.Tensor): k = x.shape[-1] - self.current_k + 1 k_th_value, _ = torch.kthvalue(x, k=k, dim=-1) k_th_value = k_th_value.unsqueeze(dim=1) - print() return x.ge(k_th_value) return topk_activation diff --git a/tests/unit/test_mixcoder.py b/tests/unit/test_mixcoder.py index 5d342efe..41308e53 100644 --- a/tests/unit/test_mixcoder.py +++ b/tests/unit/test_mixcoder.py @@ -12,11 +12,13 @@ def config(): modalities={"text": 2, "image": 3, "shared": 4}, device="cpu", dtype=torch.float32, - use_glu_encoder=False, + use_glu_encoder=True, use_decoder_bias=True, hook_point_in="hook_point_in", hook_point_out="hook_point_out", expansion_factor=1.0, + top_k=2, + act_fn="topk", ) @@ -32,6 +34,12 @@ def modality_indices(): def mixcoder(config, modality_indices): model = MixCoder(config) model.init_parameters(modality_indices=modality_indices) + model.decoder["text"].bias.data = torch.rand_like(model.decoder["text"].bias.data) + model.decoder["image"].bias.data = 
torch.rand_like(model.decoder["image"].bias.data) + model.decoder["shared"].bias.data = torch.rand_like(model.decoder["shared"].bias.data) + model.encoder["text"].bias.data = torch.rand_like(model.encoder["text"].bias.data) + model.encoder["image"].bias.data = torch.rand_like(model.encoder["image"].bias.data) + model.encoder["shared"].bias.data = torch.rand_like(model.encoder["shared"].bias.data) return model @@ -80,6 +88,7 @@ def test_encode_decode(mixcoder, config): ), feature_acts[:, slice(*modality_index["shared"])], ) + print(feature_acts) # Test decode reconstructed = mixcoder.decode(feature_acts) @@ -95,18 +104,18 @@ def test_encode_decode(mixcoder, config): assert torch.allclose(reconstructed_image[4:, :], reconstructed[4:, :]) -def test_get_modality_activation(mixcoder, config): +def test_get_modality_activation_mask(mixcoder, config): """Test the _get_modality_activation method.""" batch_size = 8 x = torch.ones(batch_size, config.d_model) tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # Test text modality - text_activation = mixcoder._get_modality_activation(x, tokens, "text") - assert torch.all(text_activation[0, :4] == 1) # First 4 positions should be 1 - assert torch.all(text_activation[0, 4:] == 0) # Last 4 positions should be 0 + text_activation_mask = mixcoder._get_modality_activation_mask(x, tokens, "text") + assert torch.all(text_activation_mask[0, :4] == 1) # First 4 positions should be 1 + assert torch.all(text_activation_mask[0, 4:] == 0) # Last 4 positions should be 0 # Test image modality - image_activation = mixcoder._get_modality_activation(x, tokens, "image") - assert torch.all(image_activation[1, :4] == 0) # First 4 positions should be 0 - assert torch.all(image_activation[1, 4:] == 1) # Last 4 positions should be 1 + image_activation_mask = mixcoder._get_modality_activation_mask(x, tokens, "image") + assert torch.all(image_activation_mask[1, :4] == 0) # First 4 positions should be 0 + assert torch.all(image_activation_mask[1, 4:] == 1) # Last 4 positions should be 1 From ba95de8c47f5d03b63535c2343daf233f35e1ada Mon Sep 17 00:00:00 2001 From: frankstein Date: Thu, 23 Jan 2025 15:22:33 +0800 Subject: [PATCH 4/7] feat(trainer): add extra log info for mixcoder --- src/lm_saes/trainer.py | 38 +++++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/src/lm_saes/trainer.py b/src/lm_saes/trainer.py index 1a75fc17..28a78a77 100644 --- a/src/lm_saes/trainer.py +++ b/src/lm_saes/trainer.py @@ -10,6 +10,7 @@ from wandb.sdk.wandb_run import Run from lm_saes.config import TrainerConfig +from lm_saes.mixcoder import MixCoder from lm_saes.optim import get_scheduler from lm_saes.sae import SparseAutoEncoder from lm_saes.utils.misc import all_reduce_tensor @@ -134,13 +135,23 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor]) "sparsity/below_1e-5": (feature_sparsity < 1e-5).sum().item(), "sparsity/below_1e-6": (feature_sparsity < 1e-6).sum().item(), } - if sae.cfg.sae_type == 'crosscoder': - wandb_log_dict.update({ - "sparsity/overall_above_1e-1": (all_reduce_tensor(feature_sparsity, aggregate='max') > 1e-1).sum().item(), - "sparsity/overall_above_1e-2": (all_reduce_tensor(feature_sparsity, aggregate='max') > 1e-2).sum().item(), - "sparsity/overall_below_1e-5": (all_reduce_tensor(feature_sparsity, aggregate='max') < 1e-5).sum().item(), - "sparsity/overall_below_1e-6": (all_reduce_tensor(feature_sparsity, aggregate='max') < 1e-6).sum().item(), - }) + if sae.cfg.sae_type == "crosscoder": + 
wandb_log_dict.update( + { + "sparsity/overall_above_1e-1": (all_reduce_tensor(feature_sparsity, aggregate="max") > 1e-1) + .sum() + .item(), + "sparsity/overall_above_1e-2": (all_reduce_tensor(feature_sparsity, aggregate="max") > 1e-2) + .sum() + .item(), + "sparsity/overall_below_1e-5": (all_reduce_tensor(feature_sparsity, aggregate="max") < 1e-5) + .sum() + .item(), + "sparsity/overall_below_1e-6": (all_reduce_tensor(feature_sparsity, aggregate="max") < 1e-6) + .sum() + .item(), + } + ) self.wandb_logger.log(wandb_log_dict, step=self.cur_step + 1) log_info["act_freq_scores"] = torch.zeros_like(log_info["act_freq_scores"]) @@ -173,6 +184,19 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor]) "details/n_training_tokens": self.cur_tokens, } wandb_log_dict.update(sae.log_statistics()) + if sae.cfg.sae_type == "mixcoder": + assert isinstance(sae, MixCoder) + for modality, (start, end) in sae.modality_index.items(): + wandb_log_dict.update( + { + f"metrics/{modality}_l0": (log_info["feature_acts"][:, start:end] > 0) + .float() + .sum(-1) + .mean() + .item(), + } + ) + self.wandb_logger.log(wandb_log_dict, step=self.cur_step + 1) def _save_checkpoint(self, sae: SparseAutoEncoder): From 0ca8f7ec9d4cc313dd6b0890cf5db8b8a53d11ec Mon Sep 17 00:00:00 2001 From: frankstein Date: Thu, 23 Jan 2025 23:18:52 +0800 Subject: [PATCH 5/7] feat(trainer): add some new extra log info for mixcoder --- src/lm_saes/mixcoder.py | 10 +++------- src/lm_saes/trainer.py | 15 ++++++++++++++- tests/unit/test_mixcoder.py | 14 ++++++-------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/lm_saes/mixcoder.py b/src/lm_saes/mixcoder.py index eda771d4..69909455 100644 --- a/src/lm_saes/mixcoder.py +++ b/src/lm_saes/mixcoder.py @@ -159,12 +159,8 @@ def get_modality_index(self) -> dict[str, tuple[int, int]]: """ return self.modality_index - def _get_modality_activation_mask( + def get_modality_token_mask( self, - activation: Union[ - Float[torch.Tensor, "batch d_model"], - Float[torch.Tensor, "batch seq_len d_model"], - ], tokens: Union[ Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"], @@ -182,7 +178,7 @@ def _get_modality_activation_mask( The activation of the specified modality. The shape is the same as the input activation. 
""" activation_mask = torch.isin(tokens, self.modality_indices[modality]) - return activation_mask.unsqueeze(1) + return activation_mask @overload def encode( @@ -266,7 +262,7 @@ def encode( if modality == "shared": # shared modality is not encoded directly but summed up during other modalities' encoding continue - activation_mask = self._get_modality_activation_mask(x, tokens, modality) + activation_mask = self.get_modality_token_mask(tokens, modality).unsqueeze(1) if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder: modality_bias = ( self.decoder[modality].bias.to_local() # TODO: check if this is correct # type: ignore diff --git a/src/lm_saes/trainer.py b/src/lm_saes/trainer.py index 28a78a77..a1bc7339 100644 --- a/src/lm_saes/trainer.py +++ b/src/lm_saes/trainer.py @@ -187,13 +187,26 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor]) if sae.cfg.sae_type == "mixcoder": assert isinstance(sae, MixCoder) for modality, (start, end) in sae.modality_index.items(): + if modality == "shared": + continue + shared_start, shared_end = sae.modality_index["shared"] + mask = sae.get_modality_token_mask(batch["tokens"], modality) + token_num = mask.sum().item() wandb_log_dict.update( { - f"metrics/{modality}_l0": (log_info["feature_acts"][:, start:end] > 0) + f"l0_metrics/{modality}_l0": (log_info["feature_acts"][mask][:, start:end] > 0) .float() .sum(-1) .mean() .item(), + f"l0_metrics/{modality}_shared_l0": ( + log_info["feature_acts"][mask][:, shared_start:shared_end] > 0 + ) + .float() + .sum(-1) + .mean() + .item(), + f"l0_metrics/{modality}_token_num": token_num, } ) diff --git a/tests/unit/test_mixcoder.py b/tests/unit/test_mixcoder.py index 41308e53..4aaa051b 100644 --- a/tests/unit/test_mixcoder.py +++ b/tests/unit/test_mixcoder.py @@ -106,16 +106,14 @@ def test_encode_decode(mixcoder, config): def test_get_modality_activation_mask(mixcoder, config): """Test the _get_modality_activation method.""" - batch_size = 8 - x = torch.ones(batch_size, config.d_model) tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # Test text modality - text_activation_mask = mixcoder._get_modality_activation_mask(x, tokens, "text") - assert torch.all(text_activation_mask[0, :4] == 1) # First 4 positions should be 1 - assert torch.all(text_activation_mask[0, 4:] == 0) # Last 4 positions should be 0 + text_activation_mask = mixcoder.get_modality_token_mask(tokens, "text") + assert torch.all(text_activation_mask[:4] == 1) # First 4 positions should be 1 + assert torch.all(text_activation_mask[4:] == 0) # Last 4 positions should be 0 # Test image modality - image_activation_mask = mixcoder._get_modality_activation_mask(x, tokens, "image") - assert torch.all(image_activation_mask[1, :4] == 0) # First 4 positions should be 0 - assert torch.all(image_activation_mask[1, 4:] == 1) # Last 4 positions should be 1 + image_activation_mask = mixcoder.get_modality_token_mask(tokens, "image") + assert torch.all(image_activation_mask[:4] == 0) # First 4 positions should be 0 + assert torch.all(image_activation_mask[4:] == 1) # Last 4 positions should be 1 From f68e3fa35d86b47043de8c7214b5228e0efa07b9 Mon Sep 17 00:00:00 2001 From: frankstein Date: Thu, 23 Jan 2025 23:18:52 +0800 Subject: [PATCH 6/7] feat(trainer): add some new extra log info for mixcoder --- src/lm_saes/mixcoder.py | 10 +++------- src/lm_saes/trainer.py | 15 ++++++++++++++- tests/unit/test_mixcoder.py | 14 ++++++-------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/lm_saes/mixcoder.py 
b/src/lm_saes/mixcoder.py index eda771d4..69909455 100644 --- a/src/lm_saes/mixcoder.py +++ b/src/lm_saes/mixcoder.py @@ -159,12 +159,8 @@ def get_modality_index(self) -> dict[str, tuple[int, int]]: """ return self.modality_index - def _get_modality_activation_mask( + def get_modality_token_mask( self, - activation: Union[ - Float[torch.Tensor, "batch d_model"], - Float[torch.Tensor, "batch seq_len d_model"], - ], tokens: Union[ Float[torch.Tensor, "batch d_model"], Float[torch.Tensor, "batch seq_len d_model"], @@ -182,7 +178,7 @@ def _get_modality_activation_mask( The activation of the specified modality. The shape is the same as the input activation. """ activation_mask = torch.isin(tokens, self.modality_indices[modality]) - return activation_mask.unsqueeze(1) + return activation_mask @overload def encode( @@ -266,7 +262,7 @@ def encode( if modality == "shared": # shared modality is not encoded directly but summed up during other modalities' encoding continue - activation_mask = self._get_modality_activation_mask(x, tokens, modality) + activation_mask = self.get_modality_token_mask(tokens, modality).unsqueeze(1) if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder: modality_bias = ( self.decoder[modality].bias.to_local() # TODO: check if this is correct # type: ignore diff --git a/src/lm_saes/trainer.py b/src/lm_saes/trainer.py index 28a78a77..a1bc7339 100644 --- a/src/lm_saes/trainer.py +++ b/src/lm_saes/trainer.py @@ -187,13 +187,26 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor]) if sae.cfg.sae_type == "mixcoder": assert isinstance(sae, MixCoder) for modality, (start, end) in sae.modality_index.items(): + if modality == "shared": + continue + shared_start, shared_end = sae.modality_index["shared"] + mask = sae.get_modality_token_mask(batch["tokens"], modality) + token_num = mask.sum().item() wandb_log_dict.update( { - f"metrics/{modality}_l0": (log_info["feature_acts"][:, start:end] > 0) + f"l0_metrics/{modality}_l0": (log_info["feature_acts"][mask][:, start:end] > 0) .float() .sum(-1) .mean() .item(), + f"l0_metrics/{modality}_shared_l0": ( + log_info["feature_acts"][mask][:, shared_start:shared_end] > 0 + ) + .float() + .sum(-1) + .mean() + .item(), + f"l0_metrics/{modality}_token_num": token_num, } ) diff --git a/tests/unit/test_mixcoder.py b/tests/unit/test_mixcoder.py index 41308e53..4aaa051b 100644 --- a/tests/unit/test_mixcoder.py +++ b/tests/unit/test_mixcoder.py @@ -106,16 +106,14 @@ def test_encode_decode(mixcoder, config): def test_get_modality_activation_mask(mixcoder, config): """Test the _get_modality_activation method.""" - batch_size = 8 - x = torch.ones(batch_size, config.d_model) tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # Test text modality - text_activation_mask = mixcoder._get_modality_activation_mask(x, tokens, "text") - assert torch.all(text_activation_mask[0, :4] == 1) # First 4 positions should be 1 - assert torch.all(text_activation_mask[0, 4:] == 0) # Last 4 positions should be 0 + text_activation_mask = mixcoder.get_modality_token_mask(tokens, "text") + assert torch.all(text_activation_mask[:4] == 1) # First 4 positions should be 1 + assert torch.all(text_activation_mask[4:] == 0) # Last 4 positions should be 0 # Test image modality - image_activation_mask = mixcoder._get_modality_activation_mask(x, tokens, "image") - assert torch.all(image_activation_mask[1, :4] == 0) # First 4 positions should be 0 - assert torch.all(image_activation_mask[1, 4:] == 1) # Last 4 positions should be 1 + 
image_activation_mask = mixcoder.get_modality_token_mask(tokens, "image") + assert torch.all(image_activation_mask[:4] == 0) # First 4 positions should be 0 + assert torch.all(image_activation_mask[4:] == 1) # Last 4 positions should be 1 From 948f52052ee4ea2c025b943207ac7b134827a8fc Mon Sep 17 00:00:00 2001 From: frankstein Date: Thu, 23 Jan 2025 23:38:24 +0800 Subject: [PATCH 7/7] feat(trainer): add some new extra log info for mixcoder --- src/lm_saes/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lm_saes/trainer.py b/src/lm_saes/trainer.py index a1bc7339..ba7d355b 100644 --- a/src/lm_saes/trainer.py +++ b/src/lm_saes/trainer.py @@ -273,7 +273,7 @@ def fit( activation_in, activation_out = batch[sae.cfg.hook_point_in], batch[sae.cfg.hook_point_out] if self.wandb_logger is not None: - self._log(sae, log_info, {"input": activation_in, "output": activation_out}) + self._log(sae, log_info, {"input": activation_in, "output": activation_out, "tokens": batch["tokens"]}) if eval_fn is not None and (self.cur_step + 1) % self.cfg.eval_frequency == 0: eval_fn(sae)
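
Note on the top-k activation fix (patch 3): the sae.py hunk keeps the kthvalue-based masking and only drops a stray print(). Below is a minimal standalone sketch of that masking for reference; the free function and its current_k argument are illustrative stand-ins for the topk_activation closure and self.current_k in the diff, not the library API.

import torch


def topk_activation_mask(x: torch.Tensor, current_k: int) -> torch.Tensor:
    """Boolean mask keeping the current_k largest values along the last dim."""
    # torch.kthvalue returns the k-th *smallest* value, so count from the other end.
    k = x.shape[-1] - current_k + 1
    k_th_value, _ = torch.kthvalue(x, k=k, dim=-1)
    k_th_value = k_th_value.unsqueeze(dim=1)  # broadcast over the feature dimension
    return x.ge(k_th_value)


feature_acts = torch.randn(4, 16)
mask = topk_activation_mask(feature_acts, current_k=2)
# With distinct values, exactly 2 features per row survive the mask, in the
# spirit of the commented-out assert left in MixCoder.encode.
assert torch.all(mask.sum(-1) == 2)

After the fix, the per-token modality mask is applied to the concatenated modality/shared pre-activations rather than to the encoder input, so off-modality tokens contribute nothing and each token's top-k budget is spent only on its own modality block plus the shared block.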