@@ -32,23 +32,30 @@ def generator(sae_config: SAEConfig) -> torch.Generator:
 @pytest.fixture
 def sae(sae_config: SAEConfig, generator: torch.Generator) -> SparseAutoEncoder:
     sae = SparseAutoEncoder(sae_config)
-    sae.encoder.weight.data = torch.randn(
-        sae_config.d_sae, sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
+    sae.encoder.weight.data = torch.tensor(
+        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
+        requires_grad=True,
+        dtype=sae_config.dtype,
+        device=sae_config.device,
     )
-    sae.decoder.weight.data = torch.randn(
-        sae_config.d_model, sae_config.d_sae, generator=generator, device=sae_config.device, dtype=sae_config.dtype
+    sae.encoder.bias.data = torch.tensor(
+        [3.0, 2.0, 3.0, 4.0],
+        requires_grad=True,
+        dtype=sae_config.dtype,
+        device=sae_config.device,
+    )
+    sae.decoder.weight.data = torch.tensor(
+        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
+        requires_grad=True,
+        dtype=sae_config.dtype,
+        device=sae_config.device,
+    )
+    sae.decoder.bias.data = torch.tensor(
+        [1.0, 2.0],
+        requires_grad=True,
+        dtype=sae_config.dtype,
+        device=sae_config.device,
     )
-    if sae_config.use_decoder_bias:
-        sae.decoder.bias.data = torch.randn(
-            sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
-        )
-    if sae_config.use_glu_encoder:
-        sae.encoder_glu.weight.data = torch.randn(
-            sae_config.d_sae, sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
-        )
-        sae.encoder_glu.bias.data = torch.randn(
-            sae_config.d_sae, generator=generator, device=sae_config.device, dtype=sae_config.dtype
-        )
     return sae
 
 
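The fixture now pins every parameter to small hand-written tensors (d_model = 2, d_sae = 4) instead of seeding torch.randn, and drops the conditional init for use_decoder_bias and use_glu_encoder, so downstream tests can assert exact values against a config with a decoder bias and no GLU encoder. For orientation, a minimal sketch of what these pinned parameters compute under a plain linear -> ReLU -> linear pass, with no activation normalization (an assumption for illustration; the real SparseAutoEncoder.forward also applies the configured norm handling, so the expected values in the tests below differ):

import torch

# Pinned fixture parameters (d_sae=4, d_model=2), copied from the diff above.
W_enc = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
b_enc = torch.tensor([3.0, 2.0, 3.0, 4.0])
W_dec = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
b_dec = torch.tensor([1.0, 2.0])

x = torch.tensor([4.0, 4.0])
feats = torch.relu(W_enc @ x + b_enc)  # tensor([15., 30., 47., 64.])
recon = W_dec @ feats + b_dec          # tensor([ 473., 1098.])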
@@ -196,27 +203,8 @@ def test_get_full_state_dict(sae_config: SAEConfig, sae: SparseAutoEncoder):
 
 def test_standardize_parameters_of_dataset_norm(sae_config: SAEConfig, sae: SparseAutoEncoder):
     sae_config.norm_activation = "dataset-wise"
-    sae.encoder.bias.data = torch.tensor(
-        [[1.0, 2.0]],
-        requires_grad=True,
-        dtype=sae_config.dtype,
-        device=sae_config.device,
-    )
     encoder_bias_data = sae.encoder.bias.data.clone()
-    sae.decoder.weight.data = torch.tensor(
-        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
-        requires_grad=True,
-        dtype=sae_config.dtype,
-        device=sae_config.device,
-    )
     decoder_weight_data = sae.decoder.weight.data.clone()
-    if sae_config.use_decoder_bias:
-        sae.decoder.bias.data = torch.tensor(
-            [[1.0, 2.0, 3.0, 4.0]],
-            requires_grad=True,
-            dtype=sae_config.dtype,
-            device=sae_config.device,
-        )
     decoder_bias_data = sae.decoder.bias.data.clone()
     sae.standardize_parameters_of_dataset_norm({"in": 3.0, "out": 2.0})
     assert sae.cfg.norm_activation == "inference"
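With the fixture pinning the parameters, the in-test setup above became redundant; the test now only snapshots the fixture's values before standardization. As a rough orientation (an assumption about the convention, not this repo's exact folding formula): dataset-wise normalization rescales activations toward an average norm of sqrt(d_model), so with d_model = 2 and norms {"in": 3.0, "out": 2.0} the input scale factor would be sqrt(2)/3 and the output scale factor 2/sqrt(2). standardize_parameters_of_dataset_norm folds such factors into the weights and biases so that inference can run on raw activations, which is why the config flips to norm_activation == "inference".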
@@ -237,6 +225,8 @@ def test_standardize_parameters_of_dataset_norm(sae_config: SAEConfig, sae: Spar
 
 
 def test_forward(sae_config: SAEConfig, sae: SparseAutoEncoder):
-    sae.set_dataset_average_activation_norm({"in": 3.0, "out": 2.0})
-    output = sae.forward(torch.tensor([[1.0, 2.0]], device=sae_config.device, dtype=sae_config.dtype))
-    assert output.shape == (1, 2)
+    sae.set_dataset_average_activation_norm(
+        {"in": 2.0 * math.sqrt(sae_config.d_model), "out": 1.0 * math.sqrt(sae_config.d_model)}
+    )
+    output = sae.forward(torch.tensor([[4.0, 4.0]], device=sae_config.device, dtype=sae_config.dtype))
+    assert torch.allclose(output, torch.tensor([[69.0, 146.0]], device=sae_config.device, dtype=sae_config.dtype))
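The norms are now written in units of math.sqrt(sae_config.d_model), making the dataset-wise convention explicit. Assuming the usual scaling (input multiplied by sqrt(d_model)/norm_in), the input [[4.0, 4.0]] is halved before encoding: sqrt(2) / (2.0 * sqrt(2)) = 1/2, giving [[2.0, 2.0]]. The hard-coded [[69.0, 146.0]] is the exact reconstruction the pinned fixture parameters produce through the full forward pass, and checking it with torch.allclose is a far stronger regression test than the old shape-only assertion.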