Skip to content

Commit 8d57964

Browse files
Frankstein73dest1n1s
authored andcommitted
fix(mixcoder): fix topk activation func
1 parent daeb92b commit 8d57964

File tree

3 files changed

+29
-18
lines changed

3 files changed

+29
-18
lines changed

src/lm_saes/mixcoder.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def get_modality_index(self) -> dict[str, tuple[int, int]]:
159159
"""
160160
return self.modality_index
161161

162-
def _get_modality_activation(
162+
def _get_modality_activation_mask(
163163
self,
164164
activation: Union[
165165
Float[torch.Tensor, "batch d_model"],
@@ -182,7 +182,7 @@ def _get_modality_activation(
182182
The activation of the specified modality. The shape is the same as the input activation.
183183
"""
184184
activation_mask = torch.isin(tokens, self.modality_indices[modality])
185-
return activation_mask.unsqueeze(1) * activation
185+
return activation_mask.unsqueeze(1)
186186

187187
@overload
188188
def encode(
@@ -266,7 +266,7 @@ def encode(
266266
if modality == "shared":
267267
# shared modality is not encoded directly but summed up during other modalities' encoding
268268
continue
269-
x_modality = self._get_modality_activation(x, tokens, modality)
269+
activation_mask = self._get_modality_activation_mask(x, tokens, modality)
270270
if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
271271
modality_bias = (
272272
self.decoder[modality].bias.to_local() # TODO: check if this is correct # type: ignore
@@ -278,15 +278,15 @@ def encode(
278278
if isinstance(self.decoder["shared"].bias, DTensor)
279279
else self.decoder["shared"].bias
280280
)
281-
x_modality = x_modality - modality_bias - shared_bias
281+
x = x - modality_bias - shared_bias
282282

283-
hidden_pre_modality = self.encoder[modality](x_modality)
284-
hidden_pre_shared = self.encoder["shared"](x_modality)
283+
hidden_pre_modality = self.encoder[modality](x)
284+
hidden_pre_shared = self.encoder["shared"](x)
285285

286286
if self.cfg.use_glu_encoder:
287-
hidden_pre_modality_glu = torch.sigmoid(self.encoder_glu[modality](x_modality))
287+
hidden_pre_modality_glu = torch.sigmoid(self.encoder_glu[modality](x))
288288
hidden_pre_modality = hidden_pre_modality * hidden_pre_modality_glu
289-
hidden_pre_shared_glu = torch.sigmoid(self.encoder_glu["shared"](x_modality))
289+
hidden_pre_shared_glu = torch.sigmoid(self.encoder_glu["shared"](x))
290290
hidden_pre_shared = hidden_pre_shared * hidden_pre_shared_glu
291291

292292
if self.cfg.sparsity_include_decoder_norm:
@@ -296,7 +296,9 @@ def encode(
296296
true_feature_acts_modality = hidden_pre_modality
297297
true_feature_acts_shared = hidden_pre_shared
298298

299-
true_feature_acts_concat = torch.cat([true_feature_acts_modality, true_feature_acts_shared], dim=1)
299+
true_feature_acts_concat = (
300+
torch.cat([true_feature_acts_modality, true_feature_acts_shared], dim=1) * activation_mask
301+
)
300302
activation_mask_concat = self.activation_function(true_feature_acts_concat)
301303
feature_acts_concat = true_feature_acts_concat * activation_mask_concat
302304

@@ -313,6 +315,7 @@ def encode(
313315

314316
hidden_pre = self.hook_hidden_pre(hidden_pre)
315317
feature_acts = self.hook_feature_acts(feature_acts)
318+
# assert torch.all((feature_acts > 0).sum(-1) <= self.current_k)
316319
if return_hidden_pre:
317320
return feature_acts, hidden_pre
318321
return feature_acts

src/lm_saes/sae.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ def topk_activation(x: torch.Tensor):
109109
k = x.shape[-1] - self.current_k + 1
110110
k_th_value, _ = torch.kthvalue(x, k=k, dim=-1)
111111
k_th_value = k_th_value.unsqueeze(dim=1)
112-
print()
113112
return x.ge(k_th_value)
114113

115114
return topk_activation

tests/unit/test_mixcoder.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@ def config():
1212
modalities={"text": 2, "image": 3, "shared": 4},
1313
device="cpu",
1414
dtype=torch.float32,
15-
use_glu_encoder=False,
15+
use_glu_encoder=True,
1616
use_decoder_bias=True,
1717
hook_point_in="hook_point_in",
1818
hook_point_out="hook_point_out",
1919
expansion_factor=1.0,
20+
top_k=2,
21+
act_fn="topk",
2022
)
2123

2224

@@ -32,6 +34,12 @@ def modality_indices():
3234
def mixcoder(config, modality_indices):
3335
model = MixCoder(config)
3436
model.init_parameters(modality_indices=modality_indices)
37+
model.decoder["text"].bias.data = torch.rand_like(model.decoder["text"].bias.data)
38+
model.decoder["image"].bias.data = torch.rand_like(model.decoder["image"].bias.data)
39+
model.decoder["shared"].bias.data = torch.rand_like(model.decoder["shared"].bias.data)
40+
model.encoder["text"].bias.data = torch.rand_like(model.encoder["text"].bias.data)
41+
model.encoder["image"].bias.data = torch.rand_like(model.encoder["image"].bias.data)
42+
model.encoder["shared"].bias.data = torch.rand_like(model.encoder["shared"].bias.data)
3543
return model
3644

3745

@@ -80,6 +88,7 @@ def test_encode_decode(mixcoder, config):
8088
),
8189
feature_acts[:, slice(*modality_index["shared"])],
8290
)
91+
print(feature_acts)
8392

8493
# Test decode
8594
reconstructed = mixcoder.decode(feature_acts)
@@ -95,18 +104,18 @@ def test_encode_decode(mixcoder, config):
95104
assert torch.allclose(reconstructed_image[4:, :], reconstructed[4:, :])
96105

97106

98-
def test_get_modality_activation(mixcoder, config):
107+
def test_get_modality_activation_mask(mixcoder, config):
99108
"""Test the _get_modality_activation method."""
100109
batch_size = 8
101110
x = torch.ones(batch_size, config.d_model)
102111
tokens = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
103112

104113
# Test text modality
105-
text_activation = mixcoder._get_modality_activation(x, tokens, "text")
106-
assert torch.all(text_activation[0, :4] == 1) # First 4 positions should be 1
107-
assert torch.all(text_activation[0, 4:] == 0) # Last 4 positions should be 0
114+
text_activation_mask = mixcoder._get_modality_activation_mask(x, tokens, "text")
115+
assert torch.all(text_activation_mask[0, :4] == 1) # First 4 positions should be 1
116+
assert torch.all(text_activation_mask[0, 4:] == 0) # Last 4 positions should be 0
108117

109118
# Test image modality
110-
image_activation = mixcoder._get_modality_activation(x, tokens, "image")
111-
assert torch.all(image_activation[1, :4] == 0) # First 4 positions should be 0
112-
assert torch.all(image_activation[1, 4:] == 1) # Last 4 positions should be 1
119+
image_activation_mask = mixcoder._get_modality_activation_mask(x, tokens, "image")
120+
assert torch.all(image_activation_mask[1, :4] == 0) # First 4 positions should be 0
121+
assert torch.all(image_activation_mask[1, 4:] == 1) # Last 4 positions should be 1

0 commit comments

Comments
 (0)