We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a3ea723 commit faec68bCopy full SHA for faec68b
tests/unit/test_sae.py
@@ -229,4 +229,4 @@ def test_forward(sae_config: SAEConfig, sae: SparseAutoEncoder):
229
{"in": 2.0 * math.sqrt(sae_config.d_model), "out": 1.0 * math.sqrt(sae_config.d_model)}
230
)
231
output = sae.forward(torch.tensor([[4.0, 4.0]], device=sae_config.device, dtype=sae_config.dtype))
232
- assert torch.allclose(output, torch.tensor([[69.0, 146.0]], device=sae_config.device, dtype=sae_config.dtype))
+ assert torch.allclose(output, torch.tensor([[212.0, 449.0]], device=sae_config.device, dtype=sae_config.dtype))
0 commit comments