diff --git a/src/lm_saes/__init__.py b/src/lm_saes/__init__.py
index 48402b53..657a94b0 100644
--- a/src/lm_saes/__init__.py
+++ b/src/lm_saes/__init__.py
@@ -6,6 +6,7 @@
     ActivationFactoryDatasetSource,
     ActivationFactoryTarget,
     ActivationWriterConfig,
+    CrossCoderConfig,
     DatasetConfig,
     FeatureAnalyzerConfig,
     InitializerConfig,
@@ -29,6 +30,7 @@
 __all__ = [
     "ActivationFactory",
     "ActivationWriter",
+    "CrossCoderConfig",
     "LanguageModelConfig",
     "DatasetConfig",
     "ActivationFactoryActivationsSource",
diff --git a/src/lm_saes/activation/processors/cached_activation.py b/src/lm_saes/activation/processors/cached_activation.py
index bea63c5e..6fce82ef 100644
--- a/src/lm_saes/activation/processors/cached_activation.py
+++ b/src/lm_saes/activation/processors/cached_activation.py
@@ -265,10 +265,10 @@ def _process_chunks(self, hook_chunks: dict[str, list[ChunkInfo]], total_chunks:
             done, futures = wait(futures, return_when=FIRST_COMPLETED)
             pbar.set_postfix({"Active chunks": len(futures)})
             # Process completed chunks in order
-            for future in tqdm(done, desc="Processing chunks", smoothing=0.001, leave=False):
+            for future in tqdm(done, desc="Processing chunks", smoothing=0.001, leave=False, disable=True):
                 chunk_data = future.result()
                 chunk_data = {
-                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in chunk_data.items()
+                    k: v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v for k, v in chunk_data.items()
                 }
                 yield chunk_data
                 pbar.update(1)
diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index d409a3ec..85a6748b 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -57,7 +57,7 @@ class BaseSAEConfig(BaseModelConfig):
     use_glu_encoder: bool = False
     act_fn: Literal["relu", "jumprelu", "topk", "batchtopk"] = "relu"
     jump_relu_threshold: float = 0.0
-    apply_decoder_bias_to_pre_encoder: bool = True
+    apply_decoder_bias_to_pre_encoder: bool = False
     norm_activation: str = "dataset-wise"
     sparsity_include_decoder_norm: bool = True
     top_k: int = 50
diff --git a/src/lm_saes/crosscoder.py b/src/lm_saes/crosscoder.py
index 75a53c5d..17747ed6 100644
--- a/src/lm_saes/crosscoder.py
+++ b/src/lm_saes/crosscoder.py
@@ -203,7 +203,10 @@ def compute_loss(
         }
 
         # l_l1: (batch,)
-        feature_acts = feature_acts * self.decoder_norm(local_only=False, aggregate="mean")
+        feature_acts = feature_acts * self._decoder_norm(
+            decoder=self.decoder,
+            local_only=True,
+        )
 
         if "topk" not in self.cfg.act_fn:
             l_lp = torch.norm(feature_acts, p=lp, dim=-1)
@@ -224,6 +227,10 @@
 
         return loss
 
+    @torch.no_grad()
+    def log_statistics(self):
+        return {}
+
     def initialize_with_same_weight_across_layers(self):
         self.encoder.weight.data = get_tensor_from_specific_rank(self.encoder.weight.data.clone(), src=0)
         self.encoder.bias.data = get_tensor_from_specific_rank(self.encoder.bias.data.clone(), src=0)
diff --git a/src/lm_saes/initializer.py b/src/lm_saes/initializer.py
index fcae98b6..215a8ce5 100644
--- a/src/lm_saes/initializer.py
+++ b/src/lm_saes/initializer.py
@@ -12,6 +12,7 @@
 )
 
 from lm_saes.config import BaseSAEConfig, InitializerConfig
+from lm_saes.crosscoder import CrossCoder
 from lm_saes.mixcoder import MixCoder
 from lm_saes.sae import SparseAutoEncoder
 from lm_saes.utils.misc import calculate_activation_norm, get_modality_indices
@@ -161,9 +162,11 @@ def initialize_sae_from_config(
             sae = SparseAutoEncoder.from_config(cfg)
         elif cfg.sae_type == "mixcoder":
            sae = MixCoder.from_config(cfg)
+        elif cfg.sae_type == "crosscoder":
+            sae = CrossCoder.from_config(cfg)
        else:
             # TODO: add support for different SAE config types, e.g. MixCoderConfig, CrossCoderConfig, etc.
-            pass
+            raise ValueError(f'SAE type {cfg.sae_type} not supported.')
         if self.cfg.state == "training":
             if cfg.sae_pretrained_name_or_path is None:
                 sae: SparseAutoEncoder = self.initialize_parameters(sae, mixcoder_settings=mixcoder_settings)
diff --git a/src/lm_saes/runner.py b/src/lm_saes/runner.py
index a5acc945..c16f3e68 100644
--- a/src/lm_saes/runner.py
+++ b/src/lm_saes/runner.py
@@ -15,12 +15,12 @@
     ActivationFactoryDatasetSource,
     ActivationFactoryTarget,
     ActivationWriterConfig,
+    BaseSAEConfig,
     DatasetConfig,
     FeatureAnalyzerConfig,
     InitializerConfig,
     LanguageModelConfig,
     MongoDBConfig,
-    SAEConfig,
     TrainerConfig,
     WandbConfig,
 )
@@ -223,7 +223,7 @@ def generate_activations(settings: GenerateActivationsSettings) -> None:
 class TrainSAESettings(BaseSettings):
     """Settings for training a Sparse Autoencoder (SAE)."""
 
-    sae: SAEConfig
+    sae: BaseSAEConfig
     """Configuration for the SAE model architecture and parameters"""
 
     sae_name: str
@@ -272,6 +272,7 @@ def train_sae(settings: TrainSAESettings) -> None:
         if settings.data_parallel_size > 1 or settings.model_parallel_size > 1
         else None
     )
+
     activation_factory = ActivationFactory(settings.activation_factory)
     activations_stream = activation_factory.process()
     initializer = Initializer(settings.initializer)
@@ -288,8 +289,7 @@
             settings=wandb.Settings(x_disable_stats=True),
             mode=os.getenv("WANDB_MODE", "online"),
         )
-        if settings.wandb is not None and (device_mesh is None or device_mesh.get_rank() == 0)
-        else None
+        if settings.wandb is not None and (device_mesh is None or device_mesh.get_rank() == 0) else None
     )
     if wandb_logger is not None:
         wandb_logger.watch(sae, log="all")
@@ -310,7 +310,7 @@ def train_sae(settings: TrainSAESettings) -> None:
 
 
 class AnalyzeSAESettings(BaseSettings):
-    sae: SAEConfig
+    sae: BaseSAEConfig
     """Configuration for the SAE model architecture and parameters"""
 
     sae_name: str
diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index 0430907d..f59022a7 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -2,11 +2,10 @@
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Callable, Literal, Union, cast, overload
+from typing import Any, Callable, Literal, Union, cast, overload
 
 import safetensors.torch as safe
 import torch
-from fsspec.spec import Any
 from jaxtyping import Float
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor
@@ -110,6 +109,7 @@ def topk_activation(x: torch.Tensor):
             k = x.shape[-1] - self.current_k + 1
             k_th_value, _ = torch.kthvalue(x, k=k, dim=-1)
             k_th_value = k_th_value.unsqueeze(dim=1)
+            print()
             return x.ge(k_th_value)
 
         return topk_activation
diff --git a/src/lm_saes/trainer.py b/src/lm_saes/trainer.py
index 4b554cf5..1a75fc17 100644
--- a/src/lm_saes/trainer.py
+++ b/src/lm_saes/trainer.py
@@ -12,6 +12,7 @@
 from lm_saes.config import TrainerConfig
 from lm_saes.optim import get_scheduler
 from lm_saes.sae import SparseAutoEncoder
+from lm_saes.utils.misc import all_reduce_tensor
 
 
 class Trainer:
@@ -85,7 +86,7 @@ def _training_step(
         self,
         sae: SparseAutoEncoder,
         batch: dict[str, Tensor],
     ) -> dict[str, Tensor]:
-        if (not sae.cfg.act_fn == "topk") and self.l1_coefficient_warmup_steps > 0:
+        if "topk" not in sae.cfg.act_fn and self.l1_coefficient_warmup_steps > 0:
             assert self.cfg.l1_coefficient is not None
             sae.set_current_l1_coefficient(
                 min(1.0, self.cur_step / self.l1_coefficient_warmup_steps) * self.cfg.l1_coefficient
@@ -133,6 +134,14 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor])
                 "sparsity/below_1e-5": (feature_sparsity < 1e-5).sum().item(),
                 "sparsity/below_1e-6": (feature_sparsity < 1e-6).sum().item(),
             }
+            if sae.cfg.sae_type == 'crosscoder':
+                wandb_log_dict.update({
+                    "sparsity/overall_above_1e-1": (all_reduce_tensor(feature_sparsity, aggregate='max') > 1e-1).sum().item(),
+                    "sparsity/overall_above_1e-2": (all_reduce_tensor(feature_sparsity, aggregate='max') > 1e-2).sum().item(),
+                    "sparsity/overall_below_1e-5": (all_reduce_tensor(feature_sparsity, aggregate='max') < 1e-5).sum().item(),
+                    "sparsity/overall_below_1e-6": (all_reduce_tensor(feature_sparsity, aggregate='max') < 1e-6).sum().item(),
+                })
+
             self.wandb_logger.log(wandb_log_dict, step=self.cur_step + 1)
             log_info["act_freq_scores"] = torch.zeros_like(log_info["act_freq_scores"])
             log_info["n_frac_active_tokens"] = torch.zeros_like(log_info["n_frac_active_tokens"])
diff --git a/tests/unit/test_sae.py b/tests/unit/test_sae.py
index 593a90cc..a32e4fc6 100644
--- a/tests/unit/test_sae.py
+++ b/tests/unit/test_sae.py
@@ -229,4 +229,4 @@ def test_forward(sae_config: SAEConfig, sae: SparseAutoEncoder):
         {"in": 2.0 * math.sqrt(sae_config.d_model), "out": 1.0 * math.sqrt(sae_config.d_model)}
     )
     output = sae.forward(torch.tensor([[4.0, 4.0]], device=sae_config.device, dtype=sae_config.dtype))
-    assert torch.allclose(output, torch.tensor([[69.0, 146.0]], device=sae_config.device, dtype=sae_config.dtype))
+    assert torch.allclose(output, torch.tensor([[212.0, 449.0]], device=sae_config.device, dtype=sae_config.dtype))
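A note on the `topk_activation` context shown in the `src/lm_saes/sae.py` hunk above: `torch.kthvalue` returns the k-th smallest entry along a dimension, so `k = x.shape[-1] - self.current_k + 1` selects the `current_k`-th largest activation, and `x.ge(k_th_value)` keeps only the top-k features per sample. The sketch below is a minimal standalone illustration of that masking step; the toy tensor and the local `current_k` variable are invented for the example and are not taken from the repository.

```python
import torch

# Toy activations: one sample with five features (values invented for illustration).
x = torch.tensor([[0.1, 0.9, 0.3, 0.7, 0.2]])
current_k = 2  # keep the two largest activations per sample

# torch.kthvalue counts from the smallest, so this index picks the current_k-th largest value.
k = x.shape[-1] - current_k + 1
k_th_value, _ = torch.kthvalue(x, k=k, dim=-1)

# Boolean mask that is True only for the top-k entries of each row,
# mirroring the x.ge(k_th_value) comparison in topk_activation.
mask = x.ge(k_th_value.unsqueeze(dim=1))
print(mask)  # tensor([[False,  True, False,  True, False]])
```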