Merged
2 changes: 2 additions & 0 deletions src/lm_saes/__init__.py

```diff
@@ -6,6 +6,7 @@
     ActivationFactoryDatasetSource,
     ActivationFactoryTarget,
     ActivationWriterConfig,
+    CrossCoderConfig,
     DatasetConfig,
     FeatureAnalyzerConfig,
     InitializerConfig,
@@ -29,6 +30,7 @@
 __all__ = [
     "ActivationFactory",
     "ActivationWriter",
+    "CrossCoderConfig",
     "LanguageModelConfig",
     "DatasetConfig",
     "ActivationFactoryActivationsSource",
```
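With this re-export added to both the import block and `__all__`, the new config class is importable from the package root like its siblings:

```python
from lm_saes import CrossCoderConfig
```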
4 changes: 2 additions & 2 deletions src/lm_saes/activation/processors/cached_activation.py

```diff
@@ -265,10 +265,10 @@ def _process_chunks(self, hook_chunks: dict[str, list[ChunkInfo]], total_chunks:
             done, futures = wait(futures, return_when=FIRST_COMPLETED)
             pbar.set_postfix({"Active chunks": len(futures)})
             # Process completed chunks in order
-            for future in tqdm(done, desc="Processing chunks", smoothing=0.001, leave=False):
+            for future in tqdm(done, desc="Processing chunks", smoothing=0.001, leave=False, disable=True):
                 chunk_data = future.result()
                 chunk_data = {
-                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in chunk_data.items()
+                    k: v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else v for k, v in chunk_data.items()
                 }
                 yield chunk_data
                 pbar.update(1)
```
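Two small performance tweaks here: the inner progress bar is disabled outright, and device transfers request `non_blocking=True`. Note that a non-blocking host-to-device copy only overlaps with other work when the source tensor is in pinned (page-locked) memory; otherwise PyTorch falls back to a synchronous copy. A minimal sketch of the pattern (illustrative only, not code from this repository):

```python
import torch

# Pin the host tensor so the copy can run asynchronously on a CUDA stream.
x_cpu = torch.randn(1024, 1024, pin_memory=torch.cuda.is_available())

if torch.cuda.is_available():
    x_gpu = x_cpu.to("cuda", non_blocking=True)  # returns immediately
    # ... other CPU work can proceed while the copy is in flight ...
    torch.cuda.synchronize()  # wait for the transfer before relying on x_gpu
```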
2 changes: 1 addition & 1 deletion src/lm_saes/config.py

```diff
@@ -57,7 +57,7 @@ class BaseSAEConfig(BaseModelConfig):
     use_glu_encoder: bool = False
     act_fn: Literal["relu", "jumprelu", "topk", "batchtopk"] = "relu"
     jump_relu_threshold: float = 0.0
-    apply_decoder_bias_to_pre_encoder: bool = True
+    apply_decoder_bias_to_pre_encoder: bool = False
     norm_activation: str = "dataset-wise"
     sparsity_include_decoder_norm: bool = True
    top_k: int = 50
```
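In the usual SAE formulation, this flag decides whether the decoder bias is subtracted from the input before encoding, so flipping the default from `True` to `False` changes the forward pass of any model that does not set it explicitly (the updated expected values in `tests/unit/test_sae.py` at the end of this diff are consistent with such a change). A hedged sketch of the convention, with assumed names rather than lm_saes internals:

```python
import torch

def sae_forward(x, W_enc, b_enc, W_dec, b_dec, apply_decoder_bias_to_pre_encoder=False):
    # Optionally center the input on the decoder bias before encoding.
    x_enc = x - b_dec if apply_decoder_bias_to_pre_encoder else x
    feature_acts = torch.relu(x_enc @ W_enc + b_enc)
    reconstruction = feature_acts @ W_dec + b_dec
    return reconstruction, feature_acts
```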
9 changes: 8 additions & 1 deletion src/lm_saes/crosscoder.py

```diff
@@ -203,7 +203,10 @@ def compute_loss(
         }

         # l_l1: (batch,)
-        feature_acts = feature_acts * self.decoder_norm(local_only=False, aggregate="mean")
+        feature_acts = feature_acts * self._decoder_norm(
+            decoder=self.decoder,
+            local_only=True,
+        )

         if "topk" not in self.cfg.act_fn:
             l_lp = torch.norm(feature_acts, p=lp, dim=-1)
@@ -224,6 +227,10 @@

         return loss

+    @torch.no_grad()
+    def log_statistics(self):
+        return {}
+
     def initialize_with_same_weight_across_layers(self):
         self.encoder.weight.data = get_tensor_from_specific_rank(self.encoder.weight.data.clone(), src=0)
         self.encoder.bias.data = get_tensor_from_specific_rank(self.encoder.bias.data.clone(), src=0)
```
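The sparsity term weights each feature's activation by its decoder norm before the Lp reduction, so a feature pays a penalty proportional to the magnitude it can write back out; this hunk switches that weighting from a cross-layer mean (`local_only=False, aggregate="mean"`) to the per-shard local norm. The new `log_statistics` stub presumably satisfies an interface the trainer calls on SAE variants; it reports nothing yet. A standalone sketch of norm-weighted L1 under assumed shapes (`_decoder_norm` itself is internal and may differ):

```python
import torch

batch, d_sae, d_model = 8, 4096, 768
feature_acts = torch.rand(batch, d_sae)     # post-activation feature values
W_dec = torch.randn(d_sae, d_model)         # decoder, one row per feature

decoder_norms = W_dec.norm(dim=-1)          # (d_sae,) per-feature output scale
l_l1 = (feature_acts * decoder_norms).norm(p=1, dim=-1)  # (batch,) weighted penalty
```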
5 changes: 4 additions & 1 deletion src/lm_saes/initializer.py

```diff
@@ -12,6 +12,7 @@
 )

 from lm_saes.config import BaseSAEConfig, InitializerConfig
+from lm_saes.crosscoder import CrossCoder
 from lm_saes.mixcoder import MixCoder
 from lm_saes.sae import SparseAutoEncoder
 from lm_saes.utils.misc import calculate_activation_norm, get_modality_indices
@@ -161,9 +162,11 @@ def initialize_sae_from_config(
             sae = SparseAutoEncoder.from_config(cfg)
         elif cfg.sae_type == "mixcoder":
             sae = MixCoder.from_config(cfg)
+        elif cfg.sae_type == "crosscoder":
+            sae = CrossCoder.from_config(cfg)
         else:
             # TODO: add support for different SAE config types, e.g. MixCoderConfig, CrossCoderConfig, etc.
-            pass
+            raise ValueError(f'SAE type {cfg.sae_type} not supported.')
         if self.cfg.state == "training":
             if cfg.sae_pretrained_name_or_path is None:
                 sae: SparseAutoEncoder = self.initialize_parameters(sae, mixcoder_settings=mixcoder_settings)
```
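The `else` branch previously ended in a bare `pass`, so an unrecognized `sae_type` left `sae` unbound and surfaced later as an opaque `UnboundLocalError`; raising a `ValueError` at the dispatch point names the bad value immediately. The shape of the fixed pattern, as a toy standalone function (not lm_saes code):

```python
def build_sae(sae_type: str) -> str:
    # Fail fast: unknown types raise here, at the point of dispatch,
    # instead of leaving the result unbound for a later crash.
    if sae_type == "sae":
        return "SparseAutoEncoder"
    elif sae_type == "mixcoder":
        return "MixCoder"
    elif sae_type == "crosscoder":
        return "CrossCoder"
    raise ValueError(f"SAE type {sae_type} not supported.")
```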
10 changes: 5 additions & 5 deletions src/lm_saes/runner.py

```diff
@@ -15,12 +15,12 @@
     ActivationFactoryDatasetSource,
     ActivationFactoryTarget,
     ActivationWriterConfig,
+    BaseSAEConfig,
     DatasetConfig,
     FeatureAnalyzerConfig,
     InitializerConfig,
     LanguageModelConfig,
     MongoDBConfig,
-    SAEConfig,
     TrainerConfig,
     WandbConfig,
 )
@@ -223,7 +223,7 @@ def generate_activations(settings: GenerateActivationsSettings) -> None:
 class TrainSAESettings(BaseSettings):
     """Settings for training a Sparse Autoencoder (SAE)."""

-    sae: SAEConfig
+    sae: BaseSAEConfig
     """Configuration for the SAE model architecture and parameters"""

     sae_name: str
@@ -272,6 +272,7 @@ def train_sae(settings: TrainSAESettings) -> None:
         if settings.data_parallel_size > 1 or settings.model_parallel_size > 1
         else None
     )
+
     activation_factory = ActivationFactory(settings.activation_factory)
     activations_stream = activation_factory.process()
     initializer = Initializer(settings.initializer)
@@ -288,8 +289,7 @@
             settings=wandb.Settings(x_disable_stats=True),
             mode=os.getenv("WANDB_MODE", "online"),
         )
-        if settings.wandb is not None and (device_mesh is None or device_mesh.get_rank() == 0)
-        else None
+        if settings.wandb is not None and (device_mesh is None or device_mesh.get_rank() == 0) else None
     )
     if wandb_logger is not None:
         wandb_logger.watch(sae, log="all")
@@ -310,7 +310,7 @@


 class AnalyzeSAESettings(BaseSettings):
-    sae: SAEConfig
+    sae: BaseSAEConfig
     """Configuration for the SAE model architecture and parameters"""

     sae_name: str
```
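Widening the annotation from the concrete `SAEConfig` to the base class is what lets `TrainSAESettings` and `AnalyzeSAESettings` carry a `CrossCoderConfig` (or any other subclass) through the existing entry points. In miniature (toy classes, not the real hierarchy):

```python
class BaseSAEConfig: ...
class SAEConfig(BaseSAEConfig): ...
class CrossCoderConfig(BaseSAEConfig): ...

def train_sae(sae: BaseSAEConfig) -> None:
    # Accepts SAEConfig, CrossCoderConfig, or any future subclass.
    print(f"training with {type(sae).__name__}")

train_sae(CrossCoderConfig())  # fine under the widened annotation
```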
4 changes: 2 additions & 2 deletions src/lm_saes/sae.py

```diff
@@ -2,11 +2,10 @@
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Callable, Literal, Union, cast, overload
+from typing import Any, Callable, Literal, Union, cast, overload

 import safetensors.torch as safe
 import torch
-from fsspec.spec import Any
 from jaxtyping import Float
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor
@@ -110,6 +109,7 @@ def topk_activation(x: torch.Tensor):
             k = x.shape[-1] - self.current_k + 1
             k_th_value, _ = torch.kthvalue(x, k=k, dim=-1)
             k_th_value = k_th_value.unsqueeze(dim=1)
+            print()
             return x.ge(k_th_value)

         return topk_activation
```
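Besides moving `Any` to its proper home in `typing` (the old `from fsspec.spec import Any` only worked because that module happens to expose the same object), this hunk adds a bare `print()` inside `topk_activation`, which looks like leftover debug output. For readers skimming the hunk: the function finds the `current_k`-th largest value per row via `torch.kthvalue`, which returns the k-th *smallest*, hence the `n - current_k + 1` index, then thresholds against it:

```python
import torch

x = torch.tensor([[0.1, 0.7, 0.3, 0.9]])
current_k = 2

# kthvalue returns the k-th smallest entry, so the
# (n - current_k + 1)-th smallest is the current_k-th largest.
k = x.shape[-1] - current_k + 1
k_th_value, _ = torch.kthvalue(x, k=k, dim=-1)
mask = x.ge(k_th_value.unsqueeze(dim=1))
print(mask)  # tensor([[False, True, False, True]]): the top-2 positions
```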
11 changes: 10 additions & 1 deletion src/lm_saes/trainer.py

```diff
@@ -12,6 +12,7 @@
 from lm_saes.config import TrainerConfig
 from lm_saes.optim import get_scheduler
 from lm_saes.sae import SparseAutoEncoder
+from lm_saes.utils.misc import all_reduce_tensor


 class Trainer:
@@ -85,7 +86,7 @@ def _training_step(
         sae: SparseAutoEncoder,
         batch: dict[str, Tensor],
     ) -> dict[str, Tensor]:
-        if (not sae.cfg.act_fn == "topk") and self.l1_coefficient_warmup_steps > 0:
+        if "topk" not in sae.cfg.act_fn and self.l1_coefficient_warmup_steps > 0:
             assert self.cfg.l1_coefficient is not None
             sae.set_current_l1_coefficient(
                 min(1.0, self.cur_step / self.l1_coefficient_warmup_steps) * self.cfg.l1_coefficient
@@ -133,6 +134,14 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor])
             "sparsity/below_1e-5": (feature_sparsity < 1e-5).sum().item(),
             "sparsity/below_1e-6": (feature_sparsity < 1e-6).sum().item(),
         }
+        if sae.cfg.sae_type == 'crosscoder':
+            wandb_log_dict.update({
+                "sparsity/overall_above_1e-1": (all_reduce_tensor(feature_sparsity, aggregate='max') > 1e-1).sum().item(),
+                "sparsity/overall_above_1e-2": (all_reduce_tensor(feature_sparsity, aggregate='max') > 1e-2).sum().item(),
+                "sparsity/overall_below_1e-5": (all_reduce_tensor(feature_sparsity, aggregate='max') < 1e-5).sum().item(),
+                "sparsity/overall_below_1e-6": (all_reduce_tensor(feature_sparsity, aggregate='max') < 1e-6).sum().item(),
+            })
+
         self.wandb_logger.log(wandb_log_dict, step=self.cur_step + 1)
         log_info["act_freq_scores"] = torch.zeros_like(log_info["act_freq_scores"])
         log_info["n_frac_active_tokens"] = torch.zeros_like(log_info["n_frac_active_tokens"])
```

The warmup condition also changes from an exact string comparison to a substring test, so `"batchtopk"` is now treated like `"topk"` and skips L1 warmup as well.
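When a crosscoder is trained with one shard per rank, each rank's `feature_sparsity` is local; aggregating with a max across ranks yields the "overall" statistic, counting a feature as alive if it fires anywhere. The helper's actual implementation lives in `lm_saes.utils.misc`; a plausible sketch under that assumption:

```python
import torch
import torch.distributed as dist

def all_reduce_tensor(t: torch.Tensor, aggregate: str = "sum") -> torch.Tensor:
    # Reduce a tensor across all ranks; "mean" is a sum divided by world size.
    ops = {
        "sum": dist.ReduceOp.SUM,
        "mean": dist.ReduceOp.SUM,
        "max": dist.ReduceOp.MAX,
        "min": dist.ReduceOp.MIN,
    }
    out = t.clone()
    dist.all_reduce(out, op=ops[aggregate])
    if aggregate == "mean":
        out = out / dist.get_world_size()
    return out
```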
2 changes: 1 addition & 1 deletion tests/unit/test_sae.py

```diff
@@ -229,4 +229,4 @@ def test_forward(sae_config: SAEConfig, sae: SparseAutoEncoder):
         {"in": 2.0 * math.sqrt(sae_config.d_model), "out": 1.0 * math.sqrt(sae_config.d_model)}
     )
     output = sae.forward(torch.tensor([[4.0, 4.0]], device=sae_config.device, dtype=sae_config.dtype))
-    assert torch.allclose(output, torch.tensor([[69.0, 146.0]], device=sae_config.device, dtype=sae_config.dtype))
+    assert torch.allclose(output, torch.tensor([[212.0, 449.0]], device=sae_config.device, dtype=sae_config.dtype))
```