diff --git a/src/lm_saes/sae.py b/src/lm_saes/sae.py
index 7a8e667a..1a1b2a0d 100644
--- a/src/lm_saes/sae.py
+++ b/src/lm_saes/sae.py
@@ -91,7 +91,7 @@ def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False):
             decoder_norm = decoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
         return decoder_norm
 
-    def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]:  # type: ignore
+    def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]:
         assert cfg.act_fn.lower() in [
             "relu",
             "topk",
@@ -125,7 +125,9 @@ def topk_activation(x: torch.Tensor):
 
             return topk_activation
 
-    def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:  # type: ignore
+        raise ValueError(f"Not implemented activation function {cfg.act_fn}")
+
+    def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:
         """Compute the normalization factor for the activation vectors. This should be called during forward pass.
 
         There are four modes for norm_activation:
@@ -155,6 +157,7 @@ def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:
             )
         if self.cfg.norm_activation == "inference":
             return torch.tensor(1.0, device=x.device, dtype=x.dtype)
+        raise ValueError(f"Not implemented norm_activation {self.cfg.norm_activation}")
 
     @torch.no_grad()
     def _set_decoder_to_fixed_norm(self, decoder: torch.nn.Linear, value: float, force_exact: bool):
@@ -366,26 +369,51 @@ def encode(
                 Float[torch.Tensor, "batch seq_len d_sae"],
             ],
         ],
-    ]:  # should be overridden by subclasses
+    ]:
+        """Encode input tensor through the sparse autoencoder.
+
+        Args:
+            x: Input tensor of shape (batch, d_model) or (batch, seq_len, d_model)
+            return_hidden_pre: If True, also return the pre-activation hidden states
+
+        Returns:
+            If return_hidden_pre is False:
+                Feature activations tensor of shape (batch, d_sae) or (batch, seq_len, d_sae)
+            If return_hidden_pre is True:
+                Tuple of (feature_acts, hidden_pre) where both have shape (batch, d_sae) or (batch, seq_len, d_sae)
+        """
+        # Apply input normalization based on config
         input_norm_factor = self.compute_norm_factor(x, hook_point=self.cfg.hook_point_in)
         x = x * input_norm_factor
+
+        # Optionally subtract decoder bias before encoding
         if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
             # We need to convert decoder bias to a tensor before subtracting
             bias = self.decoder.bias.to_local() if isinstance(self.decoder.bias, DTensor) else self.decoder.bias
             x = x - bias
+
+        # Pass through encoder
         hidden_pre = self.encoder(x)
 
+        # Apply GLU if configured
         if self.cfg.use_glu_encoder:
             hidden_pre_glu = torch.sigmoid(self.encoder_glu(x))
             hidden_pre = hidden_pre * hidden_pre_glu
         hidden_pre = self.hook_hidden_pre(hidden_pre)
+
+        # Scale feature activations by decoder norm if configured
         if self.cfg.sparsity_include_decoder_norm:
-            true_feature_acts = hidden_pre * self._decoder_norm(decoder=self.decoder)
+            sparsity_scores = hidden_pre * self._decoder_norm(decoder=self.decoder)
         else:
-            true_feature_acts = hidden_pre
-        activation_mask = self.activation_function(true_feature_acts)
+            sparsity_scores = hidden_pre
+
+        # Apply the activation function. Note that this is not a standard activation function: given a
+        # standard activation $f(x)$, it returns the multiplicative scaling $f(x) / x$. In the simple
+        # ReLU case, this is a mask of 1s and 0s.
+        activation_mask = self.activation_function(sparsity_scores)
         feature_acts = hidden_pre * activation_mask
         feature_acts = self.hook_feature_acts(feature_acts)
+
         if return_hidden_pre:
             return feature_acts, hidden_pre
         return feature_acts
diff --git a/tests/unit/test_sae.py b/tests/unit/test_sae.py
index 8721e41b..593a90cc 100644
--- a/tests/unit/test_sae.py
+++ b/tests/unit/test_sae.py
@@ -32,23 +32,30 @@ def generator(sae_config: SAEConfig) -> torch.Generator:
 
 
 @pytest.fixture
 def sae(sae_config: SAEConfig, generator: torch.Generator) -> SparseAutoEncoder:
     sae = SparseAutoEncoder(sae_config)
-    sae.encoder.weight.data = torch.randn(
-        sae_config.d_sae, sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
+    sae.encoder.weight.data = torch.tensor(
+        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
+        requires_grad=True,
+        dtype=sae_config.dtype,
+        device=sae_config.device,
     )
-    sae.decoder.weight.data = torch.randn(
-        sae_config.d_model, sae_config.d_sae, generator=generator, device=sae_config.device, dtype=sae_config.dtype
+    sae.encoder.bias.data = torch.tensor(
+        [3.0, 2.0, 3.0, 4.0],
+        requires_grad=True,
+        dtype=sae_config.dtype,
+        device=sae_config.device,
+    )
+    sae.decoder.weight.data = torch.tensor(
+        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
+        requires_grad=True,
+        dtype=sae_config.dtype,
+        device=sae_config.device,
+    )
+    sae.decoder.bias.data = torch.tensor(
+        [1.0, 2.0],
+        requires_grad=True,
+        dtype=sae_config.dtype,
+        device=sae_config.device,
     )
-    if sae_config.use_decoder_bias:
-        sae.decoder.bias.data = torch.randn(
-            sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
-        )
-    if sae_config.use_glu_encoder:
-        sae.encoder_glu.weight.data = torch.randn(
-            sae_config.d_sae, sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
-        )
-        sae.encoder_glu.bias.data = torch.randn(
-            sae_config.d_sae, generator=generator, device=sae_config.device, dtype=sae_config.dtype
-        )
     return sae
@@ -196,27 +203,8 @@ def test_get_full_state_dict(sae_config: SAEConfig, sae: SparseAutoEncoder):
 
 def test_standardize_parameters_of_dataset_norm(sae_config: SAEConfig, sae: SparseAutoEncoder):
     sae_config.norm_activation = "dataset-wise"
-    sae.encoder.bias.data = torch.tensor(
-        [[1.0, 2.0]],
-        requires_grad=True,
-        dtype=sae_config.dtype,
-        device=sae_config.device,
-    )
     encoder_bias_data = sae.encoder.bias.data.clone()
-    sae.decoder.weight.data = torch.tensor(
-        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
-        requires_grad=True,
-        dtype=sae_config.dtype,
-        device=sae_config.device,
-    )
     decoder_weight_data = sae.decoder.weight.data.clone()
-    if sae_config.use_decoder_bias:
-        sae.decoder.bias.data = torch.tensor(
-            [[1.0, 2.0, 3.0, 4.0]],
-            requires_grad=True,
-            dtype=sae_config.dtype,
-            device=sae_config.device,
-        )
     decoder_bias_data = sae.decoder.bias.data.clone()
     sae.standardize_parameters_of_dataset_norm({"in": 3.0, "out": 2.0})
     assert sae.cfg.norm_activation == "inference"
@@ -237,6 +225,8 @@ def test_standardize_parameters_of_dataset_norm(sae_config: SAEConfig, sae: Spar
 
 
 def test_forward(sae_config: SAEConfig, sae: SparseAutoEncoder):
-    sae.set_dataset_average_activation_norm({"in": 3.0, "out": 2.0})
-    output = sae.forward(torch.tensor([[1.0, 2.0]], device=sae_config.device, dtype=sae_config.dtype))
-    assert output.shape == (1, 2)
+    sae.set_dataset_average_activation_norm(
+        {"in": 2.0 * math.sqrt(sae_config.d_model), "out": 1.0 * math.sqrt(sae_config.d_model)}
+    )
+    output = sae.forward(torch.tensor([[4.0, 4.0]], device=sae_config.device, dtype=sae_config.dtype))
+    assert torch.allclose(output, torch.tensor([[69.0, 146.0]], device=sae_config.device, dtype=sae_config.dtype))
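Reviewer note (encode): per the new comment, `self.activation_function` returns a multiplicative scale rather than the activation value itself, i.e. f(x) / x computed on the sparsity scores, and encode() then multiplies `hidden_pre` by that mask. Below is a minimal sketch of that contract for the "relu" and "topk" cases in plain PyTorch; it is illustrative only and is not the callable built by `activation_function_factory`.

import torch

# Sketch of the activation-function contract used in encode(): the callable maps
# sparsity scores to a multiplicative mask/scale (f(x) / x), not to f(x) itself.
# Illustrative only -- not the code produced by activation_function_factory.

def relu_mask(scores: torch.Tensor) -> torch.Tensor:
    # For ReLU, f(x) / x is just the indicator of x > 0.
    return (scores > 0).to(scores.dtype)

def topk_mask(scores: torch.Tensor, k: int) -> torch.Tensor:
    # Keep the k largest scores along the feature dimension, zero out the rest.
    mask = torch.zeros_like(scores)
    mask.scatter_(-1, torch.topk(scores, k=k, dim=-1).indices, 1.0)
    return mask

hidden_pre = torch.tensor([[4.0, 5.0, 8.0, 11.0]])
scores = hidden_pre  # the sparsity_include_decoder_norm=False case
print(hidden_pre * relu_mask(scores))       # tensor([[ 4.,  5.,  8., 11.]])
print(hidden_pre * topk_mask(scores, k=2))  # tensor([[ 0.,  0.,  8., 11.]])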
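Reviewer note (test_forward): the new assertion checks exact values, which only come out to [[69.0, 146.0]] under the sae_config fixture's settings, which are not visible in this diff. The sketch below re-derives the expectation assuming d_model=2, d_sae=4, a top-k activation with k=2, use_decoder_bias=True with apply_decoder_bias_to_pre_encoder=True, and dataset-wise normalization that scales the input by sqrt(d_model) / norm["in"]; the output rescale is the identity here because norm["out"] = sqrt(d_model). The test also relies on `import math`, which is not visible in the hunks shown.

import math
import torch

# Hypothetical re-derivation of the expected output. The weights and biases are copied
# from the sae fixture in this diff; the config flags above are assumptions.
W_enc = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
b_enc = torch.tensor([3.0, 2.0, 3.0, 4.0])
W_dec = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
b_dec = torch.tensor([1.0, 2.0])

d_model = 2
norm_in, norm_out = 2.0 * math.sqrt(d_model), 1.0 * math.sqrt(d_model)

x = torch.tensor([4.0, 4.0])
x = x * (math.sqrt(d_model) / norm_in)             # [2, 2]  input normalization (factor 0.5)
x = x - b_dec                                      # [1, 0]  decoder bias subtracted pre-encoder
hidden_pre = W_enc @ x + b_enc                     # [4, 5, 8, 11]
mask = torch.zeros_like(hidden_pre)
mask[torch.topk(hidden_pre, k=2).indices] = 1.0    # keep the two largest features
feature_acts = hidden_pre * mask                   # [0, 0, 8, 11]
recon = W_dec @ feature_acts + b_dec               # [69, 146]
recon = recon * (norm_out / math.sqrt(d_model))    # output de-normalization (identity here)
print(recon)                                       # tensor([ 69., 146.])

If the fixture config differs from these assumptions, the expected tensor should be re-derived the same way rather than loosened back to a shape-only check.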