
Commit 4eb40ba

dest1n1s authored and Frankstein73 committed
test(sae): fix test SAE fixture; specify test weight for exact-checking forward computation
1 parent 1cbcd25 · commit 4eb40ba

File tree: 2 files changed (+61 / −43 lines)

src/lm_saes/sae.py

Lines changed: 34 additions & 6 deletions
@@ -92,7 +92,7 @@ def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False):
             decoder_norm = decoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
         return decoder_norm
 
-    def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]:  # type: ignore
+    def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]:
         assert cfg.act_fn.lower() in [
             "relu",
             "topk",
@@ -127,7 +127,9 @@ def topk_activation(x: torch.Tensor):
 
             return topk_activation
 
-    def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:  # type: ignore
+        raise ValueError(f"Not implemented activation function {cfg.act_fn}")
+
+    def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:
         """Compute the normalization factor for the activation vectors.
         This should be called during forward pass.
         There are four modes for norm_activation:
@@ -157,6 +159,7 @@ def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:
             )
         if self.cfg.norm_activation == "inference":
             return torch.tensor(1.0, device=x.device, dtype=x.dtype)
+        raise ValueError(f"Not implemented norm_activation {self.cfg.norm_activation}")
 
     @torch.no_grad()
     def _set_decoder_to_fixed_norm(self, decoder: torch.nn.Linear, value: float, force_exact: bool):
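
Both new raise ValueError(...) fallbacks make these methods exhaustive: every branch now either returns a value or raises, which is what allows the # type: ignore comments to be dropped from the signatures above. A minimal sketch of the same pattern, assuming nothing about the library beyond what the diff shows (the name make_activation is illustrative, not part of lm_saes):

```python
from typing import Callable

import torch


def make_activation(act_fn: str) -> Callable[[torch.Tensor], torch.Tensor]:
    """Every code path either returns a callable or raises, so the annotated
    return type always holds and no "# type: ignore" is needed."""
    if act_fn.lower() == "relu":
        return torch.relu
    if act_fn.lower() == "tanh":
        return torch.tanh
    raise ValueError(f"Not implemented activation function {act_fn}")


make_activation("relu")       # returns torch.relu
# make_activation("unknown")  # raises ValueError, mirroring the fallbacks added above
```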
@@ -371,26 +374,51 @@ def encode(
                 Float[torch.Tensor, "batch seq_len d_sae"],
             ],
         ],
-    ]:  # should be overridden by subclasses
+    ]:
+        """Encode input tensor through the sparse autoencoder.
+
+        Args:
+            x: Input tensor of shape (batch, d_model) or (batch, seq_len, d_model)
+            return_hidden_pre: If True, also return the pre-activation hidden states
+
+        Returns:
+            If return_hidden_pre is False:
+                Feature activations tensor of shape (batch, d_sae) or (batch, seq_len, d_sae)
+            If return_hidden_pre is True:
+                Tuple of (feature_acts, hidden_pre) where both have shape (batch, d_sae) or (batch, seq_len, d_sae)
+        """
+        # Apply input normalization based on config
         input_norm_factor = self.compute_norm_factor(x, hook_point=self.cfg.hook_point_in)
         x = x * input_norm_factor
+
+        # Optionally subtract decoder bias before encoding
         if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
             # We need to convert decoder bias to a tensor before subtracting
             bias = self.decoder.bias.to_local() if isinstance(self.decoder.bias, DTensor) else self.decoder.bias
             x = x - bias
+
+        # Pass through encoder
         hidden_pre = self.encoder(x)
+        # Apply GLU if configured
         if self.cfg.use_glu_encoder:
             hidden_pre_glu = torch.sigmoid(self.encoder_glu(x))
             hidden_pre = hidden_pre * hidden_pre_glu
 
         hidden_pre = self.hook_hidden_pre(hidden_pre)
+
+        # Scale feature activations by decoder norm if configured
         if self.cfg.sparsity_include_decoder_norm:
-            true_feature_acts = hidden_pre * self._decoder_norm(decoder=self.decoder)
+            sparsity_scores = hidden_pre * self._decoder_norm(decoder=self.decoder)
         else:
-            true_feature_acts = hidden_pre
-        activation_mask = self.activation_function(true_feature_acts)
+            sparsity_scores = hidden_pre
+
+        # Apply activation function. The activation function here differs from a common activation function,
+        # since it computes a scaling of the input tensor, which is, suppose the common activation function
+        # is $f(x)$, then here it computes $f(x) / x$. For simple ReLU case, it computes a mask of 1s and 0s.
+        activation_mask = self.activation_function(sparsity_scores)
         feature_acts = hidden_pre * activation_mask
         feature_acts = self.hook_feature_acts(feature_acts)
+
         if return_hidden_pre:
             return feature_acts, hidden_pre
         return feature_acts
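
The new comment in encode documents a convention worth a concrete check: activation_function returns the ratio f(x)/x rather than f(x) itself, so multiplying its result back onto hidden_pre reproduces the ordinary activation, while the selection can be made on decoder-norm-scaled scores. A small self-contained sketch of that idea; the tensors and the k=2 choice here are illustrative, not the module's actual attributes or config:

```python
import torch

hidden_pre = torch.tensor([[-1.0, 0.5, 2.0, 3.0]])

# For ReLU, f(x)/x is a 0/1 mask over positive entries, so applying the mask
# to hidden_pre reproduces relu(hidden_pre).
relu_mask = (hidden_pre > 0).to(hidden_pre.dtype)
assert torch.allclose(hidden_pre * relu_mask, torch.relu(hidden_pre))

# With sparsity_include_decoder_norm, the top-k selection is made on
# decoder-norm-scaled scores, but the resulting mask is applied to the raw
# hidden_pre, so feature magnitudes stay on the pre-activation scale.
decoder_norm = torch.tensor([1.0, 8.0, 1.0, 1.0])  # made-up per-feature decoder column norms
sparsity_scores = hidden_pre * decoder_norm        # [[-1.0, 4.0, 2.0, 3.0]]
topk_indices = torch.topk(sparsity_scores, k=2, dim=-1).indices
topk_mask = torch.zeros_like(sparsity_scores).scatter_(-1, topk_indices, 1.0)
feature_acts = hidden_pre * topk_mask  # keeps features 1 and 3, a different pair
                                       # than a plain top-k on hidden_pre would pick
```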

tests/unit/test_sae.py

Lines changed: 27 additions & 37 deletions
@@ -32,23 +32,30 @@ def generator(sae_config: SAEConfig) -> torch.Generator:
 @pytest.fixture
 def sae(sae_config: SAEConfig, generator: torch.Generator) -> SparseAutoEncoder:
     sae = SparseAutoEncoder(sae_config)
-    sae.encoder.weight.data = torch.randn(
-        sae_config.d_sae, sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
+    sae.encoder.weight.data = torch.tensor(
+        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
+        requires_grad=True,
+        dtype=sae_config.dtype,
+        device=sae_config.device,
     )
-    sae.decoder.weight.data = torch.randn(
-        sae_config.d_model, sae_config.d_sae, generator=generator, device=sae_config.device, dtype=sae_config.dtype
+    sae.encoder.bias.data = torch.tensor(
+        [3.0, 2.0, 3.0, 4.0],
+        requires_grad=True,
+        dtype=sae_config.dtype,
+        device=sae_config.device,
+    )
+    sae.decoder.weight.data = torch.tensor(
+        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
+        requires_grad=True,
+        dtype=sae_config.dtype,
+        device=sae_config.device,
+    )
+    sae.decoder.bias.data = torch.tensor(
+        [1.0, 2.0],
+        requires_grad=True,
+        dtype=sae_config.dtype,
+        device=sae_config.device,
     )
-    if sae_config.use_decoder_bias:
-        sae.decoder.bias.data = torch.randn(
-            sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
-        )
-    if sae_config.use_glu_encoder:
-        sae.encoder_glu.weight.data = torch.randn(
-            sae_config.d_sae, sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
-        )
-        sae.encoder_glu.bias.data = torch.randn(
-            sae_config.d_sae, generator=generator, device=sae_config.device, dtype=sae_config.dtype
-        )
     return sae
 
 
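For reference, the hard-coded fixture implies d_model = 2 and d_sae = 4 under the usual nn.Linear layout (encoder.weight of shape (d_sae, d_model), decoder.weight of shape (d_model, d_sae)); the actual sae_config values are not shown in this diff, so this is an inference from the weights. Fixed integer weights are what make the forward pass exactly reproducible, allowing test_forward below to assert against literal outputs. A quick shape check under that assumption:

```python
import torch

encoder_weight = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])  # (d_sae, d_model)
decoder_weight = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])      # (d_model, d_sae)

d_sae, d_model = encoder_weight.shape
assert (d_sae, d_model) == (4, 2)
assert decoder_weight.shape == (d_model, d_sae)
```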
@@ -196,27 +203,8 @@ def test_get_full_state_dict(sae_config: SAEConfig, sae: SparseAutoEncoder):
 
 def test_standardize_parameters_of_dataset_norm(sae_config: SAEConfig, sae: SparseAutoEncoder):
     sae_config.norm_activation = "dataset-wise"
-    sae.encoder.bias.data = torch.tensor(
-        [[1.0, 2.0]],
-        requires_grad=True,
-        dtype=sae_config.dtype,
-        device=sae_config.device,
-    )
     encoder_bias_data = sae.encoder.bias.data.clone()
-    sae.decoder.weight.data = torch.tensor(
-        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
-        requires_grad=True,
-        dtype=sae_config.dtype,
-        device=sae_config.device,
-    )
     decoder_weight_data = sae.decoder.weight.data.clone()
-    if sae_config.use_decoder_bias:
-        sae.decoder.bias.data = torch.tensor(
-            [[1.0, 2.0, 3.0, 4.0]],
-            requires_grad=True,
-            dtype=sae_config.dtype,
-            device=sae_config.device,
-        )
     decoder_bias_data = sae.decoder.bias.data.clone()
     sae.standardize_parameters_of_dataset_norm({"in": 3.0, "out": 2.0})
     assert sae.cfg.norm_activation == "inference"
@@ -237,6 +225,8 @@ def test_standardize_parameters_of_dataset_norm(sae_config: SAEConfig, sae: SparseAutoEncoder):
 
 
 def test_forward(sae_config: SAEConfig, sae: SparseAutoEncoder):
-    sae.set_dataset_average_activation_norm({"in": 3.0, "out": 2.0})
-    output = sae.forward(torch.tensor([[1.0, 2.0]], device=sae_config.device, dtype=sae_config.dtype))
-    assert output.shape == (1, 2)
+    sae.set_dataset_average_activation_norm(
+        {"in": 2.0 * math.sqrt(sae_config.d_model), "out": 1.0 * math.sqrt(sae_config.d_model)}
+    )
+    output = sae.forward(torch.tensor([[4.0, 4.0]], device=sae_config.device, dtype=sae_config.dtype))
+    assert torch.allclose(output, torch.tensor([[69.0, 146.0]], device=sae_config.device, dtype=sae_config.dtype))
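
The expected [[69.0, 146.0]] can be reproduced by hand from the fixture weights, but only under assumptions about the test config that this diff does not show: dataset-wise normalization with factor sqrt(d_model) / average_norm (0.5 on the input here, 1.0 on the output), decoder bias subtracted before encoding, and a top-k activation with k = 2. With these weights the top-2 selection is the same whether or not scores are scaled by decoder column norms. A sketch of the arithmetic under those assumptions:

```python
import torch

# Fixture parameters, copied from the fixture above.
W_enc = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
b_enc = torch.tensor([3.0, 2.0, 3.0, 4.0])
W_dec = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
b_dec = torch.tensor([1.0, 2.0])

x = torch.tensor([4.0, 4.0])
x = x * 0.5                     # assumed input norm factor sqrt(2) / (2 * sqrt(2)) = 0.5 -> [2, 2]
x = x - b_dec                   # assumed pre-encoder decoder-bias subtraction -> [1, 0]
hidden_pre = W_enc @ x + b_enc  # [4, 5, 8, 11]

# Assumed top-k activation with k = 2: keep the two largest pre-activations.
mask = torch.zeros_like(hidden_pre).scatter_(0, torch.topk(hidden_pre, k=2).indices, 1.0)
feature_acts = hidden_pre * mask               # [0, 0, 8, 11]

reconstruction = W_dec @ feature_acts + b_dec  # [69, 146]
# Assumed output norm factor sqrt(2) / sqrt(2) = 1, so forward returns this unchanged.
assert torch.allclose(reconstruction, torch.tensor([69.0, 146.0]))
```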
