
Commit 5c38684

feat: conversion methods between lm-saes and saelens
Co-authored-by: Guancheng Zhou <[email protected]>
1 parent 242f7a9 commit 5c38684

3 files changed: +148 −6 lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion

@@ -113,6 +113,8 @@ npu = [
 
 triton = ["triton"]
 
+sae_lens = ["sae-lens>=6.22.3"]
+
 [[tool.uv.index]]
 name = "torch-cpu"
 url = "https://download.pytorch.org/whl/cpu"
@@ -188,4 +190,4 @@ requires-dist = ["torch", "einops"]
 
 [tool.uv.sources.transformer-lens]
 path = "./TransformerLens"
-editable = true
+editable = true
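Because sae-lens is only pulled in through this optional extra, code that calls the new conversion methods should guard against it being absent. A minimal sketch of such a guard, using only the standard library (the error message text is illustrative):

import importlib.util

# sae-lens is installed only when the project is set up with the `sae_lens` extra,
# so check for it before calling the conversion methods.
if importlib.util.find_spec("sae_lens") is None:
    raise ImportError(
        "Converting between lm-saes and SAE Lens requires the optional "
        "`sae_lens` extra (sae-lens>=6.22.3)."
    )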

src/lm_saes/config.py

Lines changed: 4 additions & 3 deletions

@@ -54,7 +54,7 @@ class BaseSAEConfig(BaseModelConfig, ABC):
 
     sae_type: Literal["sae", "crosscoder", "clt", "lorsa", "molt"]
     d_model: int
-    expansion_factor: int
+    expansion_factor: float
     use_decoder_bias: bool = True
     act_fn: Literal["relu", "jumprelu", "topk", "batchtopk", "batchlayertopk", "layertopk"] = "relu"
     norm_activation: str = "dataset-wise"
@@ -81,7 +81,8 @@ class BaseSAEConfig(BaseModelConfig, ABC):
 
     @property
    def d_sae(self) -> int:
-        return self.d_model * self.expansion_factor
+        d_sae = int(self.d_model * self.expansion_factor)
+        return d_sae
 
     @classmethod
     def from_pretrained(cls, pretrained_name_or_path: str, strict_loading: bool = True, **kwargs):
@@ -267,7 +268,7 @@ def generate_rank_assignments(self) -> list[int]:
         assert self.rank_distribution, "rank_distribution cannot be empty"
 
         # Calculate base d_sae
-        base_d_sae = self.d_model * self.expansion_factor
+        base_d_sae = self.d_sae
 
         # For distributed training, use special logic to ensure consistency
         if self.model_parallel_size_training > 1:
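Relaxing expansion_factor from int to float lets the config describe converted SAEs whose dictionary size is not an integer multiple of d_model, with d_sae rounded back to an int by the property above. A small sketch of that arithmetic, with made-up dimensions:

# Example: an imported SAE with d_in=768 and d_sae=1152 has an expansion
# factor of 1152 / 768 = 1.5, which the old int field could not store.
d_model = 768
expansion_factor = 1152 / d_model  # 1.5, representable now that the field is a float
d_sae = int(d_model * expansion_factor)  # 1152, recovering the original dictionary size
assert d_sae == 1152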

src/lm_saes/sae.py

Lines changed: 141 additions & 2 deletions

@@ -10,6 +10,7 @@
 from transformer_lens.hook_points import HookPoint
 from typing_extensions import override
 
+from lm_saes.activation_functions import JumpReLU
 from lm_saes.utils.distributed import DimMap
 from lm_saes.utils.logging import get_distributed_logger
 
@@ -332,9 +333,9 @@ def decode(
         return reconstructed
 
     @classmethod
-    def from_pretrained(cls, pretrained_name_or_path: str, strict_loading: bool = True, **kwargs):
+    def from_pretrained(cls, pretrained_name_or_path: str, strict_loading: bool = True, fold_activation_scale: bool = True, **kwargs):
         cfg = SAEConfig.from_pretrained(pretrained_name_or_path, strict_loading=strict_loading, **kwargs)
-        return cls.from_config(cfg)
+        return cls.from_config(cfg, fold_activation_scale=fold_activation_scale)
 
     @torch.no_grad()
     def _init_encoder_with_decoder_transpose(
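Callers can now control activation-scale folding at load time. A minimal usage sketch, assuming the class defined in sae.py is exported as SparseAutoEncoder and that a local checkpoint path exists (both are assumptions not shown in this diff):

from lm_saes import SparseAutoEncoder  # assumed public import path

# Load with the default behaviour (fold_activation_scale=True).
sae = SparseAutoEncoder.from_pretrained("path/to/checkpoint")

# Or keep the activation scale unfolded by passing the new flag explicitly.
sae_unfolded = SparseAutoEncoder.from_pretrained("path/to/checkpoint", fold_activation_scale=False)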
@@ -417,3 +418,141 @@ def init_encoder_bias_with_mean_hidden_pre(self, activation_batch: dict[str, torch.Tensor]):
         x = self.prepare_input(activation_batch)[0]
         _, hidden_pre = self.encode(x, return_hidden_pre=True)
         self.b_E.copy_(-hidden_pre.mean(dim=0))
+
+    @classmethod
+    @torch.no_grad()
+    def from_saelens(cls, sae_saelens):
+        from sae_lens import JumpReLUSAE, StandardSAE, TopKSAE
+
+        # Check that the sae_lens configuration is compatible
+        assert isinstance(sae_saelens, (JumpReLUSAE, StandardSAE, TopKSAE)), f"Only JumpReLUSAE, StandardSAE and TopKSAE are supported, but got {type(sae_saelens)}"
+        assert sae_saelens.cfg.reshape_activations == "none", f"'reshape_activations' should be 'none', but got {sae_saelens.cfg.reshape_activations}."
+        assert not sae_saelens.cfg.apply_b_dec_to_input, f"'apply_b_dec_to_input' should be False, but got {sae_saelens.cfg.apply_b_dec_to_input}."
+        assert sae_saelens.cfg.normalize_activations == "none", f"'normalize_activations' should be 'none', but got {sae_saelens.cfg.normalize_activations}."
+
+        # Parse the sae_lens config
+        d_model = sae_saelens.cfg.d_in
+        d_sae = sae_saelens.cfg.d_sae
+        hook_name = sae_saelens.cfg.metadata.hook_name
+        assert isinstance(hook_name, str)
+        dtype = sae_saelens.W_enc.dtype
+
+        rescale_acts_by_decoder_norm = False
+        jumprelu_threshold_window = 0
+        k = 0
+        if isinstance(sae_saelens, StandardSAE):
+            activation_fn = "relu"
+        elif isinstance(sae_saelens, TopKSAE):
+            activation_fn = "topk"
+            k = sae_saelens.cfg.k
+            rescale_acts_by_decoder_norm = sae_saelens.cfg.rescale_acts_by_decoder_norm
+        elif isinstance(sae_saelens, JumpReLUSAE):
+            activation_fn = "jumprelu"
+            jumprelu_threshold_window = 2
+        else:
+            raise TypeError(f"Only JumpReLUSAE, StandardSAE and TopKSAE are supported, but got {type(sae_saelens)}")
+
+        # Create the lm-saes config
+        cfg = SAEConfig(
+            sae_type="sae",
+            hook_point_in=hook_name,
+            hook_point_out=hook_name,
+            dtype=dtype,
+            d_model=d_model,
+            act_fn=activation_fn,
+            jumprelu_threshold_window=jumprelu_threshold_window,
+            top_k=k,
+            expansion_factor=d_sae / d_model,
+            sparsity_include_decoder_norm=rescale_acts_by_decoder_norm,
+        )
+
+        model = cls.from_config(cfg, None)
+
+        model.W_D.copy_(sae_saelens.W_dec)
+        model.W_E.copy_(sae_saelens.W_enc)
+        model.b_D.copy_(sae_saelens.b_dec)
+        model.b_E.copy_(sae_saelens.b_enc)
+
+        if isinstance(sae_saelens, JumpReLUSAE):
+            assert isinstance(model.activation_function, JumpReLU)
+            model.activation_function.log_jumprelu_threshold.copy_(torch.log(sae_saelens.threshold.clone().detach()))
+
+        return model
+
+    @torch.no_grad()
+    def to_saelens(self, model_name: str = "unknown"):
+        from sae_lens import JumpReLUSAE, JumpReLUSAEConfig, StandardSAE, StandardSAEConfig, TopKSAE, TopKSAEConfig
+        from sae_lens.saes.sae import SAEMetadata
+
+
+        # Check that this SAE can be exported
+        assert self.cfg.hook_point_in == self.cfg.hook_point_out, "Transcoders are not supported yet."
+        assert not isinstance(self.b_D, DTensor), "Distributed settings are not supported yet."
+        assert not self.cfg.use_glu_encoder, "Cannot convert an SAE with use_glu_encoder=True to the SAE Lens format."
+
+        # Parse the lm-saes config
+        d_in = self.cfg.d_model
+        d_sae = self.cfg.d_sae
+        activation_fn = self.cfg.act_fn
+        hook_name = self.cfg.hook_point_in
+        dtype = self.cfg.dtype
+
+        # Create the sae_lens model
+        if activation_fn == "relu":
+            cfg_saelens = StandardSAEConfig(
+                d_in=d_in,
+                d_sae=d_sae,
+                dtype=str(dtype),
+                device="cpu",
+                apply_b_dec_to_input=False,
+                normalize_activations="none",
+                metadata=SAEMetadata(
+                    model_name=model_name,
+                    hook_name=hook_name,
+                ),
+            )
+            model = StandardSAE(cfg_saelens)
+        elif activation_fn == "jumprelu":
+            cfg_saelens = JumpReLUSAEConfig(
+                d_in=d_in,
+                d_sae=d_sae,
+                dtype=str(dtype),
+                device="cpu",
+                apply_b_dec_to_input=False,
+                normalize_activations="none",
+                metadata=SAEMetadata(
+                    model_name=model_name,
+                    hook_name=hook_name,
+                ),
+            )
+            model = JumpReLUSAE(cfg_saelens)
+        elif activation_fn == "topk":
+            cfg_saelens = TopKSAEConfig(
+                k=self.cfg.top_k,
+                d_in=d_in,
+                d_sae=d_sae,
+                dtype=str(dtype),
+                device="cpu",
+                apply_b_dec_to_input=False,
+                normalize_activations="none",
+                rescale_acts_by_decoder_norm=self.cfg.sparsity_include_decoder_norm,
+                metadata=SAEMetadata(
+                    model_name=model_name,
+                    hook_name=hook_name,
+                ),
+            )
+            model = TopKSAE(cfg_saelens)
+        else:
+            raise TypeError(f"Activation function {activation_fn} is not supported yet.")
+
+        # Copy weights
+        model.W_dec.copy_(self.W_D)
+        model.W_enc.copy_(self.W_E)
+        model.b_dec.copy_(self.b_D)
+        model.b_enc.copy_(self.b_E)
+
+        if isinstance(model, JumpReLUSAE):
+            assert isinstance(self.activation_function, JumpReLU)
+            model.threshold.copy_(self.activation_function.log_jumprelu_threshold.exp())
+
+        return model
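Taken together, the two methods give a round trip between the formats. A usage sketch under the same assumptions as above (the hosting class is taken to be SparseAutoEncoder, the sae_lens config values are illustrative, and default config fields on both sides are assumed compatible with the asserts in from_saelens):

import torch
from sae_lens import StandardSAE, StandardSAEConfig
from sae_lens.saes.sae import SAEMetadata

from lm_saes import SparseAutoEncoder  # assumed public import path

# Build a small sae_lens SAE (dimensions and hook name are illustrative only).
cfg_saelens = StandardSAEConfig(
    d_in=768,
    d_sae=1152,
    dtype="float32",
    device="cpu",
    apply_b_dec_to_input=False,
    normalize_activations="none",
    metadata=SAEMetadata(model_name="demo-model", hook_name="blocks.0.hook_resid_post"),
)
sae_saelens = StandardSAE(cfg_saelens)

# sae_lens -> lm-saes, then back again.
sae_lm = SparseAutoEncoder.from_saelens(sae_saelens)
sae_back = sae_lm.to_saelens(model_name="demo-model")

# The weights should survive the round trip unchanged.
torch.testing.assert_close(sae_back.W_enc, sae_saelens.W_enc)
torch.testing.assert_close(sae_back.W_dec, sae_saelens.W_dec)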
