Skip to content

Commit 94242f6

Browse files
committed
Fix: type check
1 parent a39072e commit 94242f6

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

src/lm_saes/sae.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from transformer_lens.hook_points import HookPoint
1111
from typing_extensions import override
1212

13+
from lm_saes.activation_functions import JumpReLU
1314
from lm_saes.utils.distributed import DimMap
1415
from lm_saes.utils.logging import get_distributed_logger
1516

@@ -448,6 +449,7 @@ def from_saelens(cls, sae_saelens):
448449
d_model = sae_saelens.cfg.d_in
449450
d_sae = sae_saelens.cfg.d_sae
450451
hook_name = sae_saelens.cfg.metadata.hook_name
452+
assert isinstance(hook_name, str)
451453
dtype = sae_saelens.W_enc.dtype
452454

453455
rescale_acts_by_decoder_norm = False
@@ -487,6 +489,7 @@ def from_saelens(cls, sae_saelens):
487489
model.b_E.copy_(sae_saelens.b_enc)
488490

489491
if isinstance(sae_saelens, JumpReLUSAE):
492+
assert isinstance(model.activation_function, JumpReLU)
490493
model.activation_function.log_jumprelu_threshold.copy_(torch.log(sae_saelens.threshold.clone().detach()))
491494

492495
return model
@@ -564,6 +567,7 @@ def to_saelens(self, model_name:str='unknown'):
564567
model.b_enc.copy_(self.b_E)
565568

566569
if isinstance(model, JumpReLUSAE):
570+
assert isinstance(self.activation_function, JumpReLU)
567571
model.threshold.copy_(self.activation_function.log_jumprelu_threshold.exp())
568572

569573
return model

0 commit comments

Comments
 (0)