src/lm_saes/sae.py (34 additions, 6 deletions)
@@ -91,7 +91,7 @@ def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False):
decoder_norm = decoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
return decoder_norm

def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]: # type: ignore
def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]:
assert cfg.act_fn.lower() in [
"relu",
"topk",
@@ -125,7 +125,9 @@ def topk_activation(x: torch.Tensor):

return topk_activation

def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor: # type: ignore
raise ValueError(f"Not implemented activation function {cfg.act_fn}")

def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:
"""Compute the normalization factor for the activation vectors.
This should be called during the forward pass.
There are four modes for norm_activation:
@@ -155,6 +157,7 @@ def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:
)
if self.cfg.norm_activation == "inference":
return torch.tensor(1.0, device=x.device, dtype=x.dtype)
raise ValueError(f"Not implemented norm_activation {self.cfg.norm_activation}")

@torch.no_grad()
def _set_decoder_to_fixed_norm(self, decoder: torch.nn.Linear, value: float, force_exact: bool):
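Note: the hunk above adds an explicit error for unrecognized norm_activation modes. As a rough sketch of the "dataset-wise" branch, which is not shown in this diff, the helper below assumes the factor has the form sqrt(d_model) / dataset-average norm; that form is inferred from the numbers used in test_forward further down and may not match the actual implementation.

import math
import torch

def dataset_wise_norm_factor(
    x: torch.Tensor,
    d_model: int,
    dataset_average_activation_norm: dict[str, float],
    hook_point: str,
) -> torch.Tensor:
    # Hypothetical helper mirroring the assumed "dataset-wise" branch of compute_norm_factor:
    # scale activations so that their average norm becomes sqrt(d_model).
    factor = math.sqrt(d_model) / dataset_average_activation_norm[hook_point]
    return torch.tensor(factor, device=x.device, dtype=x.dtype)

# Example: with d_model = 2 and an average input norm of 2 * sqrt(2), the factor is
# sqrt(2) / (2 * sqrt(2)) = 0.5, so an input of [[4.0, 4.0]] is scaled to [[2.0, 2.0]].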
@@ -366,26 +369,51 @@ def encode(
Float[torch.Tensor, "batch seq_len d_sae"],
],
],
]: # should be overridden by subclasses
]:
"""Encode input tensor through the sparse autoencoder.

Args:
x: Input tensor of shape (batch, d_model) or (batch, seq_len, d_model)
return_hidden_pre: If True, also return the pre-activation hidden states

Returns:
If return_hidden_pre is False:
Feature activations tensor of shape (batch, d_sae) or (batch, seq_len, d_sae)
If return_hidden_pre is True:
Tuple of (feature_acts, hidden_pre) where both have shape (batch, d_sae) or (batch, seq_len, d_sae)
"""
# Apply input normalization based on config
input_norm_factor = self.compute_norm_factor(x, hook_point=self.cfg.hook_point_in)
x = x * input_norm_factor

# Optionally subtract decoder bias before encoding
if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
# Convert the decoder bias from a DTensor to a local tensor, if necessary, before subtracting
bias = self.decoder.bias.to_local() if isinstance(self.decoder.bias, DTensor) else self.decoder.bias
x = x - bias

# Pass through encoder
hidden_pre = self.encoder(x)
# Apply GLU if configured
if self.cfg.use_glu_encoder:
hidden_pre_glu = torch.sigmoid(self.encoder_glu(x))
hidden_pre = hidden_pre * hidden_pre_glu

hidden_pre = self.hook_hidden_pre(hidden_pre)

# Scale feature activations by decoder norm if configured
if self.cfg.sparsity_include_decoder_norm:
true_feature_acts = hidden_pre * self._decoder_norm(decoder=self.decoder)
sparsity_scores = hidden_pre * self._decoder_norm(decoder=self.decoder)
else:
true_feature_acts = hidden_pre
activation_mask = self.activation_function(true_feature_acts)
sparsity_scores = hidden_pre

# Apply the activation function. Note that it differs from a standard activation function:
# where a standard activation computes $f(x)$, this one computes the scaling factor $f(x) / x$.
# For the simple ReLU case, that is a mask of 1s and 0s.
activation_mask = self.activation_function(sparsity_scores)
feature_acts = hidden_pre * activation_mask
feature_acts = self.hook_feature_acts(feature_acts)

if return_hidden_pre:
return feature_acts, hidden_pre
return feature_acts
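Note: the comment above describes the activation function as returning a scaling of its input rather than the activation itself. Below is a minimal sketch of what ReLU- and TopK-style functions look like under that convention; the function names and the explicit k parameter are illustrative assumptions, not the library's actual factory output.

import torch

def relu_scaling(x: torch.Tensor) -> torch.Tensor:
    # f(x) / x for ReLU: a 0/1 mask of the positive entries.
    return (x > 0).to(x.dtype)

def topk_scaling(x: torch.Tensor, k: int) -> torch.Tensor:
    # f(x) / x for TopK: a 0/1 mask keeping the k largest entries along the feature dim.
    mask = torch.zeros_like(x)
    topk = torch.topk(x, k=k, dim=-1)
    return mask.scatter_(-1, topk.indices, 1.0)

# encode() multiplies hidden_pre by this mask, so kept features retain their
# pre-activation values while the rest are zeroed.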
tests/unit/test_sae.py (27 additions, 37 deletions)
@@ -32,23 +32,30 @@ def generator(sae_config: SAEConfig) -> torch.Generator:
@pytest.fixture
def sae(sae_config: SAEConfig, generator: torch.Generator) -> SparseAutoEncoder:
sae = SparseAutoEncoder(sae_config)
sae.encoder.weight.data = torch.randn(
sae_config.d_sae, sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
sae.encoder.weight.data = torch.tensor(
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
sae.decoder.weight.data = torch.randn(
sae_config.d_model, sae_config.d_sae, generator=generator, device=sae_config.device, dtype=sae_config.dtype
sae.encoder.bias.data = torch.tensor(
[3.0, 2.0, 3.0, 4.0],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
sae.decoder.weight.data = torch.tensor(
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
sae.decoder.bias.data = torch.tensor(
[1.0, 2.0],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
if sae_config.use_decoder_bias:
sae.decoder.bias.data = torch.randn(
sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
)
if sae_config.use_glu_encoder:
sae.encoder_glu.weight.data = torch.randn(
sae_config.d_sae, sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
)
sae.encoder_glu.bias.data = torch.randn(
sae_config.d_sae, generator=generator, device=sae_config.device, dtype=sae_config.dtype
)
return sae


@@ -196,27 +203,8 @@ def test_get_full_state_dict(sae_config: SAEConfig, sae: SparseAutoEncoder):

def test_standardize_parameters_of_dataset_norm(sae_config: SAEConfig, sae: SparseAutoEncoder):
sae_config.norm_activation = "dataset-wise"
sae.encoder.bias.data = torch.tensor(
[[1.0, 2.0]],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
encoder_bias_data = sae.encoder.bias.data.clone()
sae.decoder.weight.data = torch.tensor(
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
decoder_weight_data = sae.decoder.weight.data.clone()
if sae_config.use_decoder_bias:
sae.decoder.bias.data = torch.tensor(
[[1.0, 2.0, 3.0, 4.0]],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
decoder_bias_data = sae.decoder.bias.data.clone()
sae.standardize_parameters_of_dataset_norm({"in": 3.0, "out": 2.0})
assert sae.cfg.norm_activation == "inference"
@@ -237,6 +225,8 @@ def test_standardize_parameters_of_dataset_norm(sae_config: SAEConfig, sae: SparseAutoEncoder):


def test_forward(sae_config: SAEConfig, sae: SparseAutoEncoder):
sae.set_dataset_average_activation_norm({"in": 3.0, "out": 2.0})
output = sae.forward(torch.tensor([[1.0, 2.0]], device=sae_config.device, dtype=sae_config.dtype))
assert output.shape == (1, 2)
sae.set_dataset_average_activation_norm(
{"in": 2.0 * math.sqrt(sae_config.d_model), "out": 1.0 * math.sqrt(sae_config.d_model)}
)
output = sae.forward(torch.tensor([[4.0, 4.0]], device=sae_config.device, dtype=sae_config.dtype))
assert torch.allclose(output, torch.tensor([[69.0, 146.0]], device=sae_config.device, dtype=sae_config.dtype))
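For reference, the expected value [[69.0, 146.0]] is consistent with the fixture weights above under one plausible configuration (top-k activation with k = 2, decoder bias subtracted before encoding, sqrt(d_model) / average-norm scaling); the real settings live in the sae_config fixture, so treat the walkthrough below as a sketch rather than the test's definitive derivation.

# Assumed walkthrough of test_forward's second case (d_model = 2, d_sae = 4):
# norm factor "in"  = sqrt(2) / (2 * sqrt(2)) = 0.5      -> x = [[4, 4]] * 0.5 = [[2, 2]]
# subtract decoder bias [1, 2] before encoding            -> [[1, 0]]
# hidden_pre = W_enc @ [1, 0] + b_enc = [1, 3, 5, 7] + [3, 2, 3, 4] = [4, 5, 8, 11]
# top-2 mask keeps indices 2 and 3                        -> feature_acts = [0, 0, 8, 11]
# W_dec @ feature_acts + b_dec = [3*8 + 4*11, 7*8 + 8*11] + [1, 2] = [69, 146]
# norm factor "out" = sqrt(2) / (1 * sqrt(2)) = 1, leaving the output unchanged.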