@@ -159,7 +159,7 @@ def get_modality_index(self) -> dict[str, tuple[int, int]]:
159159 """
160160 return self .modality_index
161161
162- def _get_modality_activation (
162+ def _get_modality_activation_mask (
163163 self ,
164164 activation : Union [
165165 Float [torch .Tensor , "batch d_model" ],
@@ -182,7 +182,7 @@ def _get_modality_activation(
-            The activation of the specified modality. The shape is the same as the input activation.
+            A boolean mask selecting the tokens of the specified modality, broadcastable against the input activation.
183183 """
184184 activation_mask = torch .isin (tokens , self .modality_indices [modality ])
185- return activation_mask .unsqueeze (1 ) * activation
185+ return activation_mask .unsqueeze (1 )
186186
187187 @overload
188188 def encode (
@@ -266,7 +266,7 @@ def encode(
             if modality == "shared":
                 # shared modality is not encoded directly but summed up during other modalities' encoding
                 continue
-            x_modality = self._get_modality_activation(x, tokens, modality)
+            activation_mask = self._get_modality_activation_mask(x, tokens, modality)
             if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
                 modality_bias = (
                     self.decoder[modality].bias.to_local()  # TODO: check if this is correct # type: ignore
@@ -278,15 +278,15 @@ def encode(
                     if isinstance(self.decoder["shared"].bias, DTensor)
                     else self.decoder["shared"].bias
                 )
-                x_modality = x_modality - modality_bias - shared_bias
+                x = x - modality_bias - shared_bias
 
-            hidden_pre_modality = self.encoder[modality](x_modality)
-            hidden_pre_shared = self.encoder["shared"](x_modality)
+            hidden_pre_modality = self.encoder[modality](x)
+            hidden_pre_shared = self.encoder["shared"](x)
 
             if self.cfg.use_glu_encoder:
-                hidden_pre_modality_glu = torch.sigmoid(self.encoder_glu[modality](x_modality))
+                hidden_pre_modality_glu = torch.sigmoid(self.encoder_glu[modality](x))
                 hidden_pre_modality = hidden_pre_modality * hidden_pre_modality_glu
-                hidden_pre_shared_glu = torch.sigmoid(self.encoder_glu["shared"](x_modality))
+                hidden_pre_shared_glu = torch.sigmoid(self.encoder_glu["shared"](x))
                 hidden_pre_shared = hidden_pre_shared * hidden_pre_shared_glu
 
             if self.cfg.sparsity_include_decoder_norm:
@@ -296,7 +296,9 @@ def encode(
                 true_feature_acts_modality = hidden_pre_modality
                 true_feature_acts_shared = hidden_pre_shared
 
-            true_feature_acts_concat = torch.cat([true_feature_acts_modality, true_feature_acts_shared], dim=1)
+            true_feature_acts_concat = (
+                torch.cat([true_feature_acts_modality, true_feature_acts_shared], dim=1) * activation_mask
+            )
             activation_mask_concat = self.activation_function(true_feature_acts_concat)
             feature_acts_concat = true_feature_acts_concat * activation_mask_concat
 
@@ -313,6 +315,7 @@ def encode(
 
         hidden_pre = self.hook_hidden_pre(hidden_pre)
         feature_acts = self.hook_feature_acts(feature_acts)
+        # assert torch.all((feature_acts > 0).sum(-1) <= self.current_k)
         if return_hidden_pre:
             return feature_acts, hidden_pre
         return feature_acts
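Taken together, the hunks move the modality masking from the encoder input to the concatenated feature activations: `_get_modality_activation_mask` now returns only the boolean token mask, both encoders see the full (optionally bias-subtracted) activation `x`, and the mask zeroes the concatenated `[modality, shared]` pre-activations before the activation function is applied. The sketch below illustrates that masking semantics only; names such as the `image` modality, `encoder_image`, `encoder_shared`, and the ReLU-style gate standing in for `self.activation_function` are assumptions, not the repository's actual API.

```python
# Minimal sketch of the new masking semantics (hypothetical names, not the real module).
import torch

batch, d_model, d_sae = 4, 8, 16
x = torch.randn(batch, d_model)              # activations for a mixed-modality batch
tokens = torch.tensor([1, 7, 7, 3])          # hypothetical token ids
image_token_ids = torch.tensor([7])          # hypothetical id set for the "image" modality

# What _get_modality_activation_mask now returns: a (batch, 1) boolean mask.
activation_mask = torch.isin(tokens, image_token_ids).unsqueeze(1)

# Stand-ins for self.encoder["image"] and self.encoder["shared"]; both see the full x.
encoder_image = torch.nn.Linear(d_model, d_sae)
encoder_shared = torch.nn.Linear(d_model, d_sae)
hidden_pre_image = encoder_image(x)
hidden_pre_shared = encoder_shared(x)

# The mask broadcasts over the concatenated feature dimension, so tokens that do not
# belong to this modality contribute zero pre-activations in this modality's pass.
true_feature_acts = torch.cat([hidden_pre_image, hidden_pre_shared], dim=1) * activation_mask

# Stand-in for self.activation_function: a gate multiplied back onto the pre-activations,
# mirroring feature_acts_concat = true_feature_acts_concat * activation_mask_concat.
activation_gate = (true_feature_acts > 0).float()
feature_acts = true_feature_acts * activation_gate

# Non-image rows end up with all-zero feature activations.
assert torch.all(feature_acts[~activation_mask.squeeze(1)] == 0)
```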