Commit c66fa47

Frankstein73 authored and dest1n1s committed
feat(trainer): add extra log info for mixcoder
1 parent 8d57964 commit c66fa47

File tree

1 file changed (+31, -7 lines)

src/lm_saes/trainer.py

Lines changed: 31 additions & 7 deletions
@@ -10,6 +10,7 @@
 from wandb.sdk.wandb_run import Run

 from lm_saes.config import TrainerConfig
+from lm_saes.mixcoder import MixCoder
 from lm_saes.optim import get_scheduler
 from lm_saes.sae import SparseAutoEncoder
 from lm_saes.utils.misc import all_reduce_tensor
@@ -134,13 +135,23 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor])
                 "sparsity/below_1e-5": (feature_sparsity < 1e-5).sum().item(),
                 "sparsity/below_1e-6": (feature_sparsity < 1e-6).sum().item(),
             }
-            if sae.cfg.sae_type == 'crosscoder':
-                wandb_log_dict.update({
-                    "sparsity/overall_above_1e-1": (all_reduce_tensor(feature_sparsity, aggregate='max') > 1e-1).sum().item(),
-                    "sparsity/overall_above_1e-2": (all_reduce_tensor(feature_sparsity, aggregate='max') > 1e-2).sum().item(),
-                    "sparsity/overall_below_1e-5": (all_reduce_tensor(feature_sparsity, aggregate='max') < 1e-5).sum().item(),
-                    "sparsity/overall_below_1e-6": (all_reduce_tensor(feature_sparsity, aggregate='max') < 1e-6).sum().item(),
-                })
+            if sae.cfg.sae_type == "crosscoder":
+                wandb_log_dict.update(
+                    {
+                        "sparsity/overall_above_1e-1": (all_reduce_tensor(feature_sparsity, aggregate="max") > 1e-1)
+                        .sum()
+                        .item(),
+                        "sparsity/overall_above_1e-2": (all_reduce_tensor(feature_sparsity, aggregate="max") > 1e-2)
+                        .sum()
+                        .item(),
+                        "sparsity/overall_below_1e-5": (all_reduce_tensor(feature_sparsity, aggregate="max") < 1e-5)
+                        .sum()
+                        .item(),
+                        "sparsity/overall_below_1e-6": (all_reduce_tensor(feature_sparsity, aggregate="max") < 1e-6)
+                        .sum()
+                        .item(),
+                    }
+                )

             self.wandb_logger.log(wandb_log_dict, step=self.cur_step + 1)
             log_info["act_freq_scores"] = torch.zeros_like(log_info["act_freq_scores"])
@@ -173,6 +184,19 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor])
             "details/n_training_tokens": self.cur_tokens,
         }
         wandb_log_dict.update(sae.log_statistics())
+        if sae.cfg.sae_type == "mixcoder":
+            assert isinstance(sae, MixCoder)
+            for modality, (start, end) in sae.modality_index.items():
+                wandb_log_dict.update(
+                    {
+                        f"metrics/{modality}_l0": (log_info["feature_acts"][:, start:end] > 0)
+                        .float()
+                        .sum(-1)
+                        .mean()
+                        .item(),
+                    }
+                )
+
         self.wandb_logger.log(wandb_log_dict, step=self.cur_step + 1)

     def _save_checkpoint(self, sae: SparseAutoEncoder):
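
For context, the sketch below isolates the per-modality L0 statistic that this commit starts logging. It is a minimal, self-contained approximation, not repository code: the dummy feature_acts tensor, the example modality_index ranges, and all shapes are illustrative assumptions. In the trainer itself, modality_index comes from the MixCoder instance and feature_acts from log_info.

# Minimal sketch (assumed shapes and ranges) of the per-modality L0 metric
# added in this commit: the mean count of active features per sample,
# restricted to each modality's slice of the feature dictionary.
import torch

batch_size, d_sae = 4, 12
# Dummy sparse activations standing in for log_info["feature_acts"].
feature_acts = torch.rand(batch_size, d_sae) * (torch.rand(batch_size, d_sae) > 0.7)

# Stand-in for MixCoder.modality_index: modality name -> (start, end) feature range.
modality_index = {"text": (0, 8), "image": (8, 12)}

wandb_log_dict = {}
for modality, (start, end) in modality_index.items():
    # L0 = average number of non-zero features per sample within this slice.
    wandb_log_dict[f"metrics/{modality}_l0"] = (
        (feature_acts[:, start:end] > 0).float().sum(-1).mean().item()
    )

print(wandb_log_dict)  # e.g. {"metrics/text_l0": 2.5, "metrics/image_l0": 1.25}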
