
Commit 5240735

fix(topk activation): add keepdim=True to enable broadcasting; make dtype consistent without hardcode
1 parent 50585a7 commit 5240735

File tree

3 files changed: +20 -21 lines

3 files changed

+20
-21
lines changed

src/lm_saes/config.py

Lines changed: 5 additions & 5 deletions
@@ -95,15 +95,15 @@ def save_hyperparameters(self, sae_path: Path | str, remove_loading_info: bool =
 
 
 class SAEConfig(BaseSAEConfig):
-    sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'sae'
-
+    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "sae"
+
 
 class CrossCoderConfig(BaseSAEConfig):
-    sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'crosscoder'
-
+    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "crosscoder"
+
 
 class MixCoderConfig(BaseSAEConfig):
-    sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'mixcoder'
+    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "mixcoder"
     d_single_modal: int
     d_shared: int
     n_modalities: int = 2
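Side note on the pattern above (not part of the diff): each config subclass pins sae_type to a single Literal value and makes it the default, so a serialized config can be routed back to the class that produced it. A minimal, self-contained sketch of that idea, using hypothetical stand-in classes rather than the repo's actual config classes:

# Hypothetical stand-ins (BaseCfg/SAECfg/CrossCoderCfg are not repo classes); the point
# is only that the sae_type default doubles as a discriminator when reloading configs.
from dataclasses import dataclass, asdict
from typing import Literal

@dataclass
class BaseCfg:
    sae_type: Literal["sae", "crosscoder", "mixcoder"]

@dataclass
class SAECfg(BaseCfg):
    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "sae"

@dataclass
class CrossCoderCfg(BaseCfg):
    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "crosscoder"

REGISTRY = {"sae": SAECfg, "crosscoder": CrossCoderCfg}

saved = asdict(CrossCoderCfg())           # e.g. what a hyperparameter dump might contain
restored = REGISTRY[saved["sae_type"]]()  # dispatch on the discriminator string
assert isinstance(restored, CrossCoderCfg)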

src/lm_saes/crosscoder.py

Lines changed: 5 additions & 12 deletions
@@ -20,24 +20,18 @@ class CrossCoder(SparseAutoEncoder):
     def __init__(self, cfg: BaseSAEConfig):
         super(CrossCoder, self).__init__(cfg)
 
-    def _decoder_norm(
-        self,
-        decoder: torch.nn.Linear,
-        keepdim: bool = False,
-        local_only=True,
-        aggregate="none"
-    ):
+    def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False, local_only=True, aggregate="none"):
         decoder_norm = super()._decoder_norm(
             decoder=decoder,
             keepdim=keepdim,
         )
         if not local_only:
             decoder_norm = all_reduce_tensor(
-                decoder_norm,
+                decoder_norm,
                 aggregate=aggregate,
             )
         return decoder_norm
-
+
     @overload
     def encode(
         self,
@@ -110,7 +104,7 @@ def encode(
 
         hidden_pre = all_reduce_tensor(hidden_pre, aggregate="sum")
         hidden_pre = self.hook_hidden_pre(hidden_pre)
-
+
         if self.cfg.sparsity_include_decoder_norm:
             true_feature_acts = hidden_pre * self._decoder_norm(
                 decoder=self.decoder,
@@ -127,7 +121,7 @@ def encode(
         if return_hidden_pre:
             return feature_acts, hidden_pre
         return feature_acts
-
+
     @overload
     def compute_loss(
         self,
@@ -229,4 +223,3 @@ def initialize_with_same_weight_across_layers(self):
         self.encoder.bias.data = get_tensor_from_specific_rank(self.encoder.bias.data.clone(), src=0)
         self.decoder.weight.data = get_tensor_from_specific_rank(self.decoder.weight.data.clone(), src=0)
         self.decoder.bias.data = get_tensor_from_specific_rank(self.decoder.bias.data.clone(), src=0)
-
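For reference, a standalone sketch (not repo code) of what a decoder-norm helper of this shape computes, assuming the decoder is an nn.Linear(d_sae, d_model) whose weight columns are per-feature directions; the cross-rank aggregation that all_reduce_tensor performs when local_only=False is only described in a comment here:

# Illustrative only: per-feature L2 norms of the decoder columns, with keepdim controlling
# whether the reduced dimension is kept (shape [1, d_sae]) or dropped (shape [d_sae]).
import torch

def decoder_norm_sketch(decoder: torch.nn.Linear, keepdim: bool = False) -> torch.Tensor:
    # Assumes decoder = nn.Linear(d_sae, d_model), i.e. weight shape [d_model, d_sae];
    # reducing over dim 0 gives one norm per feature.
    return torch.norm(decoder.weight, p=2, dim=0, keepdim=keepdim)

decoder = torch.nn.Linear(16, 8)                    # d_sae=16, d_model=8
feature_acts = torch.rand(4, 16)                    # [batch, d_sae]
norms = decoder_norm_sketch(decoder, keepdim=True)  # [1, d_sae]
print((feature_acts * norms).shape)                 # torch.Size([4, 16])

# In the CrossCoder override above, local_only=False would additionally aggregate these
# per-rank norms (e.g. a "sum" across ranks) before use, which is what
# all_reduce_tensor(decoder_norm, aggregate=...) does in the diff.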

src/lm_saes/sae.py

Lines changed: 10 additions & 4 deletions
@@ -92,7 +92,12 @@ def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False):
         return decoder_norm
 
     def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]: # type: ignore
-        assert cfg.act_fn.lower() in ["relu", "topk", "jumprelu", "batchtopk"], f"Not implemented activation function {cfg.act_fn}"
+        assert cfg.act_fn.lower() in [
+            "relu",
+            "topk",
+            "jumprelu",
+            "batchtopk",
+        ], f"Not implemented activation function {cfg.act_fn}"
         if cfg.act_fn.lower() == "relu":
             return lambda x: x.gt(0).to(x.dtype)
         elif cfg.act_fn.lower() == "jumprelu":
@@ -106,17 +111,18 @@ def topk_activation(x: torch.Tensor):
                 return x.ge(k_th_value).to(x.dtype)
 
             return topk_activation
-
+
         elif cfg.act_fn.lower() == "batchtopk":
+
             def topk_activation(x: torch.Tensor):
                 assert x.dim() == 2
                 batch_size = x.size(0)
-
+
                 x = torch.clamp(x, min=0.0)
                 k = x.numel() - self.current_k * batch_size + 1
                 k_th_value, _ = torch.kthvalue(x.flatten(), k=k, dim=-1)
                 return x.ge(k_th_value).to(x.dtype)
-
+
             return topk_activation
 
     def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor: # type: ignore
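The keepdim=True and dtype changes named in the commit message are not visible in the hunks rendered above (they presumably fall in the collapsed lines of the topk branch). As a rough, standalone illustration (not the repo's exact code) of why they matter: the per-row threshold from torch.kthvalue needs keepdim=True to broadcast against the [batch, d_sae] activations, and the mask is cast with .to(x.dtype) rather than a hardcoded dtype:

# Standalone sketch, not repo code: a per-sample TopK mask in the style of topk_activation.
import torch

def topk_mask(x: torch.Tensor, k: int) -> torch.Tensor:
    d_sae = x.size(-1)
    # The k-th largest entry per row is the (d_sae - k + 1)-th smallest; keepdim=True keeps
    # the threshold as [batch, 1] so the comparison broadcasts over [batch, d_sae].
    threshold, _ = torch.kthvalue(x, k=d_sae - k + 1, dim=-1, keepdim=True)
    # Cast with the input's own dtype (no hardcoded torch.float32), as in the commit message.
    return x.ge(threshold).to(x.dtype)

x = torch.randn(4, 32)
mask = topk_mask(x, k=8)
print(mask.dtype, mask.sum(dim=-1))  # same dtype as x; 8 active features per row (barring ties)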
