|
10 | 10 | from transformer_lens.hook_points import HookPoint |
11 | 11 | from typing_extensions import override |
12 | 12 |
|
| 13 | +from lm_saes.activation_functions import JumpReLU |
13 | 14 | from lm_saes.utils.distributed import DimMap |
14 | 15 | from lm_saes.utils.logging import get_distributed_logger |
15 | 16 |
|
@@ -448,6 +449,7 @@ def from_saelens(cls, sae_saelens): |
448 | 449 | d_model = sae_saelens.cfg.d_in |
449 | 450 | d_sae = sae_saelens.cfg.d_sae |
450 | 451 | hook_name = sae_saelens.cfg.metadata.hook_name |
| 452 | + assert isinstance(hook_name, str) |
451 | 453 | dtype = sae_saelens.W_enc.dtype |
452 | 454 |
|
453 | 455 | rescale_acts_by_decoder_norm = False |
@@ -487,6 +489,7 @@ def from_saelens(cls, sae_saelens): |
487 | 489 | model.b_E.copy_(sae_saelens.b_enc) |
488 | 490 |
|
489 | 491 | if isinstance(sae_saelens, JumpReLUSAE): |
| 492 | + assert isinstance(model.activation_function, JumpReLU) |
490 | 493 | model.activation_function.log_jumprelu_threshold.copy_(torch.log(sae_saelens.threshold.clone().detach())) |
491 | 494 |
|
492 | 495 | return model |
@@ -564,6 +567,7 @@ def to_saelens(self, model_name:str='unknown'): |
564 | 567 | model.b_enc.copy_(self.b_E) |
565 | 568 |
|
566 | 569 | if isinstance(model, JumpReLUSAE): |
| 570 | + assert isinstance(self.activation_function, JumpReLU) |
567 | 571 | model.threshold.copy_(self.activation_function.log_jumprelu_threshold.exp()) |
568 | 572 |
|
569 | 573 | return model |
0 commit comments