@@ -20,24 +20,18 @@ class CrossCoder(SparseAutoEncoder):
def __init__(self, cfg: BaseSAEConfig):
    """Initialize the CrossCoder by forwarding ``cfg`` to the parent constructor.

    Args:
        cfg: Configuration object handed unchanged to the parent
            class's ``__init__``.
    """
    super().__init__(cfg)
2222
def _decoder_norm(
    self,
    decoder: torch.nn.Linear,
    keepdim: bool = False,
    local_only=True,
    aggregate="none",
):
    """Return the decoder norm, optionally combined through ``all_reduce_tensor``.

    The norm itself is computed by the parent class's ``_decoder_norm``.

    Args:
        decoder: Decoder layer forwarded to the parent implementation.
        keepdim: Forwarded unchanged to the parent implementation.
        local_only: When true, the parent's result is returned as-is.
            When false, it is first passed through ``all_reduce_tensor``.
        aggregate: Aggregation mode forwarded to ``all_reduce_tensor``;
            unused when ``local_only`` is true.

    Returns:
        The parent-computed norm, or the output of ``all_reduce_tensor``
        applied to it when ``local_only`` is false.
    """
    local_norm = super()._decoder_norm(decoder=decoder, keepdim=keepdim)
    if local_only:
        return local_norm
    # NOTE(review): all_reduce_tensor presumably combines the per-rank norms
    # across distributed processes — confirm against its definition.
    return all_reduce_tensor(local_norm, aggregate=aggregate)
40-
34+
4135 @overload
4236 def encode (
4337 self ,
@@ -110,7 +104,7 @@ def encode(
110104
111105 hidden_pre = all_reduce_tensor (hidden_pre , aggregate = "sum" )
112106 hidden_pre = self .hook_hidden_pre (hidden_pre )
113-
107+
114108 if self .cfg .sparsity_include_decoder_norm :
115109 true_feature_acts = hidden_pre * self ._decoder_norm (
116110 decoder = self .decoder ,
@@ -127,7 +121,7 @@ def encode(
127121 if return_hidden_pre :
128122 return feature_acts , hidden_pre
129123 return feature_acts
130-
124+
131125 @overload
132126 def compute_loss (
133127 self ,
@@ -229,4 +223,3 @@ def initialize_with_same_weight_across_layers(self):
229223 self .encoder .bias .data = get_tensor_from_specific_rank (self .encoder .bias .data .clone (), src = 0 )
230224 self .decoder .weight .data = get_tensor_from_specific_rank (self .decoder .weight .data .clone (), src = 0 )
231225 self .decoder .bias .data = get_tensor_from_specific_rank (self .decoder .bias .data .clone (), src = 0 )
232-
0 commit comments