2 changes: 1 addition & 1 deletion TransformerLens
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -113,6 +113,8 @@ npu = [

triton = ["triton"]

sae_lens = ["sae-lens>=6.22.3"]

[[tool.uv.index]]
name = "torch-cpu"
url = "https://download.pytorch.org/whl/cpu"
@@ -187,5 +189,4 @@ version = "2.7.4.post1"
requires-dist = ["torch", "einops"]

[tool.uv.sources.transformer-lens]
path = "./TransformerLens"
editable = true
path = "./TransformerLens"
7 changes: 4 additions & 3 deletions src/lm_saes/config.py
@@ -54,7 +54,7 @@ class BaseSAEConfig(BaseModelConfig, ABC):

sae_type: Literal["sae", "crosscoder", "clt", "lorsa", "molt"]
d_model: int
expansion_factor: int
expansion_factor: float
use_decoder_bias: bool = True
act_fn: Literal["relu", "jumprelu", "topk", "batchtopk", "batchlayertopk", "layertopk"] = "relu"
norm_activation: str = "dataset-wise"
@@ -81,7 +81,8 @@ class BaseSAEConfig(BaseModelConfig, ABC):

@property
def d_sae(self) -> int:
return self.d_model * self.expansion_factor
d_sae = int(self.d_model * self.expansion_factor)
return d_sae

@classmethod
def from_pretrained(cls, pretrained_name_or_path: str, strict_loading: bool = True, **kwargs):
@@ -267,7 +268,7 @@ def generate_rank_assignments(self) -> list[int]:
assert self.rank_distribution, "rank_distribution cannot be empty"

# Calculate base d_sae
base_d_sae = self.d_model * self.expansion_factor
base_d_sae = self.d_sae

# For distributed training, use special logic to ensure consistency
if self.model_parallel_size_training > 1:
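
For reference, a small worked example of the relaxed expansion_factor arithmetic (the numbers are illustrative, not taken from this PR): d_sae is now the truncated integer product of d_model and a possibly fractional expansion_factor.

d_model = 768
expansion_factor = 8.5  # may now be fractional
d_sae = int(d_model * expansion_factor)
assert d_sae == 6528  # int() truncates any fractional remainder
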
143 changes: 141 additions & 2 deletions src/lm_saes/sae.py
@@ -10,6 +10,7 @@
from transformer_lens.hook_points import HookPoint
from typing_extensions import override

from lm_saes.activation_functions import JumpReLU
from lm_saes.utils.distributed import DimMap
from lm_saes.utils.logging import get_distributed_logger

@@ -332,9 +333,9 @@ def decode(
return reconstructed

@classmethod
def from_pretrained(cls, pretrained_name_or_path: str, strict_loading: bool = True, **kwargs):
def from_pretrained(cls, pretrained_name_or_path: str, strict_loading: bool = True, fold_activation_scale: bool = True, **kwargs):
cfg = SAEConfig.from_pretrained(pretrained_name_or_path, strict_loading=strict_loading, **kwargs)
return cls.from_config(cfg)
return cls.from_config(cfg, fold_activation_scale=fold_activation_scale)
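
A minimal usage sketch for the new fold_activation_scale flag; the class name SparseAutoEncoder, the import path, and the checkpoint path are assumptions for illustration, not part of this diff.

from lm_saes.sae import SparseAutoEncoder  # assumed import path

# Load a pretrained SAE but keep the activation scale unfolded.
sae = SparseAutoEncoder.from_pretrained("path/to/checkpoint", fold_activation_scale=False)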

@torch.no_grad()
def _init_encoder_with_decoder_transpose(
@@ -417,3 +418,141 @@ def init_encoder_bias_with_mean_hidden_pre(self, activation_batch: dict[str, tor
x = self.prepare_input(activation_batch)[0]
_, hidden_pre = self.encode(x, return_hidden_pre=True)
self.b_E.copy_(-hidden_pre.mean(dim=0))

@classmethod
@torch.no_grad()
def from_saelens(cls, sae_saelens):
from sae_lens import JumpReLUSAE, StandardSAE, TopKSAE

# Check Configuration
assert isinstance(sae_saelens, (JumpReLUSAE, StandardSAE, TopKSAE)), f'Only JumpReLUSAE, StandardSAE and TopKSAE are supported, but got {type(sae_saelens)}'
assert sae_saelens.cfg.reshape_activations == 'none', f"The 'reshape_activations' should be 'none' but got {sae_saelens.cfg.reshape_activations}."
assert not sae_saelens.cfg.apply_b_dec_to_input, f"The 'apply_b_dec_to_input' should be False but got {sae_saelens.cfg.apply_b_dec_to_input}."
assert sae_saelens.cfg.normalize_activations == 'none', f"The 'normalize_activations' should be 'none' but got {sae_saelens.cfg.normalize_activations}."

# Parse
d_model = sae_saelens.cfg.d_in
d_sae = sae_saelens.cfg.d_sae
hook_name = sae_saelens.cfg.metadata.hook_name
assert isinstance(hook_name, str)
dtype = sae_saelens.W_enc.dtype

rescale_acts_by_decoder_norm = False
jumprelu_threshold_window = 0
k = 0
if isinstance(sae_saelens, StandardSAE):
activation_fn = 'relu'
elif isinstance(sae_saelens, TopKSAE):
activation_fn = 'topk'
k = sae_saelens.cfg.k
rescale_acts_by_decoder_norm = sae_saelens.cfg.rescale_acts_by_decoder_norm
elif isinstance(sae_saelens, JumpReLUSAE):
activation_fn = 'jumprelu'
jumprelu_threshold_window = 2
else:
raise TypeError(f'Only JumpReLUSAE, StandardSAE and TopKSAE are supported, but got {type(sae_saelens)}')

# Create config
cfg = SAEConfig(
sae_type="sae",
hook_point_in=hook_name,
hook_point_out=hook_name,
dtype=dtype,
d_model=d_model,
act_fn=activation_fn,
jumprelu_threshold_window=jumprelu_threshold_window,
top_k=k,
expansion_factor=d_sae / d_model,
sparsity_include_decoder_norm=rescale_acts_by_decoder_norm,
)

model = cls.from_config(cfg, None)

model.W_D.copy_(sae_saelens.W_dec)
model.W_E.copy_(sae_saelens.W_enc)
model.b_D.copy_(sae_saelens.b_dec)
model.b_E.copy_(sae_saelens.b_enc)

if isinstance(sae_saelens, JumpReLUSAE):
assert isinstance(model.activation_function, JumpReLU)
model.activation_function.log_jumprelu_threshold.copy_(torch.log(sae_saelens.threshold.clone().detach()))

return model
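
A minimal conversion sketch, mirroring the SAE Lens constructor arguments used in to_saelens below; the lm_saes class name SparseAutoEncoder and the shapes/hook name are illustrative assumptions, not part of this diff.

from sae_lens import StandardSAE, StandardSAEConfig
from sae_lens.saes.sae import SAEMetadata

from lm_saes.sae import SparseAutoEncoder  # assumed import path

# Build a plain ReLU SAE in SAE Lens format (768 -> 6144 features, illustrative sizes).
cfg_saelens = StandardSAEConfig(
    d_in=768,
    d_sae=6144,
    dtype="float32",
    device="cpu",
    apply_b_dec_to_input=False,
    normalize_activations="none",
    metadata=SAEMetadata(model_name="gpt2", hook_name="blocks.6.hook_resid_post"),
)
sae = SparseAutoEncoder.from_saelens(StandardSAE(cfg_saelens))
assert sae.cfg.d_sae == 6144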

@torch.no_grad()
def to_saelens(self, model_name: str = 'unknown'):
from sae_lens import JumpReLUSAE, JumpReLUSAEConfig, StandardSAE, StandardSAEConfig, TopKSAE, TopKSAEConfig
from sae_lens.saes.sae import SAEMetadata

# Check configuration
assert self.cfg.hook_point_in == self.cfg.hook_point_out, "Transcoders are not supported yet."
assert not isinstance(self.b_D, DTensor), "Distributed settings are not supported yet."
assert not self.cfg.use_glu_encoder, "Cannot convert an SAE with use_glu_encoder=True to SAE Lens format."

# Parse
d_in = self.cfg.d_model
d_sae = self.cfg.d_sae
activation_fn = self.cfg.act_fn
hook_name = self.cfg.hook_point_in
dtype = self.cfg.dtype

# Create model
if activation_fn == 'relu':
cfg_saelens = StandardSAEConfig(
d_in=d_in,
d_sae=d_sae,
dtype=str(dtype),
device="cpu",
apply_b_dec_to_input=False,
normalize_activations="none",
metadata=SAEMetadata(
model_name=model_name,
hook_name=hook_name,
),
)
model = StandardSAE(cfg_saelens)
elif activation_fn == 'jumprelu':
cfg_saelens = JumpReLUSAEConfig(
d_in=d_in,
d_sae=d_sae,
dtype=str(dtype),
device="cpu",
apply_b_dec_to_input=False,
normalize_activations="none",
metadata=SAEMetadata(
model_name=model_name,
hook_name=hook_name,
),
)
model = JumpReLUSAE(cfg_saelens)
elif activation_fn == 'topk':
cfg_saelens = TopKSAEConfig(
k=self.cfg.top_k,
d_in=d_in,
d_sae=d_sae,
dtype=str(dtype),
device="cpu",
apply_b_dec_to_input=False,
normalize_activations="none",
rescale_acts_by_decoder_norm=self.cfg.sparsity_include_decoder_norm,
metadata=SAEMetadata(
model_name=model_name,
hook_name=hook_name,
),
)
model = TopKSAE(cfg_saelens)
else:
raise TypeError("Not support such activation function yet.")

# Copy weights
model.W_dec.copy_(self.W_D)
model.W_enc.copy_(self.W_E)
model.b_dec.copy_(self.b_D)
model.b_enc.copy_(self.b_E)

if isinstance(model, JumpReLUSAE):
assert isinstance(self.activation_function, JumpReLU)
model.threshold.copy_(self.activation_function.log_jumprelu_threshold.exp())

return model
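
And the reverse direction, continuing from the sketch above under the same naming assumptions: export back to SAE Lens and check that the copied weights line up.

import torch

sae_lens_model = sae.to_saelens(model_name="gpt2")
assert torch.allclose(sae_lens_model.W_enc, sae.W_E)
assert torch.allclose(sae_lens_model.W_dec, sae.W_D)
assert torch.allclose(sae_lens_model.b_enc, sae.b_E)
assert torch.allclose(sae_lens_model.b_dec, sae.b_D)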