|
10 | 10 | from wandb.sdk.wandb_run import Run |
11 | 11 |
|
12 | 12 | from lm_saes.config import TrainerConfig |
| 13 | +from lm_saes.mixcoder import MixCoder |
13 | 14 | from lm_saes.optim import get_scheduler |
14 | 15 | from lm_saes.sae import SparseAutoEncoder |
15 | 16 | from lm_saes.utils.misc import all_reduce_tensor |
@@ -134,13 +135,23 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor]) |
134 | 135 | "sparsity/below_1e-5": (feature_sparsity < 1e-5).sum().item(), |
135 | 136 | "sparsity/below_1e-6": (feature_sparsity < 1e-6).sum().item(), |
136 | 137 | } |
137 | | - if sae.cfg.sae_type == 'crosscoder': |
138 | | - wandb_log_dict.update({ |
139 | | - "sparsity/overall_above_1e-1": (all_reduce_tensor(feature_sparsity, aggregate='max') > 1e-1).sum().item(), |
140 | | - "sparsity/overall_above_1e-2": (all_reduce_tensor(feature_sparsity, aggregate='max') > 1e-2).sum().item(), |
141 | | - "sparsity/overall_below_1e-5": (all_reduce_tensor(feature_sparsity, aggregate='max') < 1e-5).sum().item(), |
142 | | - "sparsity/overall_below_1e-6": (all_reduce_tensor(feature_sparsity, aggregate='max') < 1e-6).sum().item(), |
143 | | - }) |
| 138 | + if sae.cfg.sae_type == "crosscoder": |
| 139 | + wandb_log_dict.update( |
| 140 | + { |
| 141 | + "sparsity/overall_above_1e-1": (all_reduce_tensor(feature_sparsity, aggregate="max") > 1e-1) |
| 142 | + .sum() |
| 143 | + .item(), |
| 144 | + "sparsity/overall_above_1e-2": (all_reduce_tensor(feature_sparsity, aggregate="max") > 1e-2) |
| 145 | + .sum() |
| 146 | + .item(), |
| 147 | + "sparsity/overall_below_1e-5": (all_reduce_tensor(feature_sparsity, aggregate="max") < 1e-5) |
| 148 | + .sum() |
| 149 | + .item(), |
| 150 | + "sparsity/overall_below_1e-6": (all_reduce_tensor(feature_sparsity, aggregate="max") < 1e-6) |
| 151 | + .sum() |
| 152 | + .item(), |
| 153 | + } |
| 154 | + ) |
144 | 155 |
|
145 | 156 | self.wandb_logger.log(wandb_log_dict, step=self.cur_step + 1) |
146 | 157 | log_info["act_freq_scores"] = torch.zeros_like(log_info["act_freq_scores"]) |
@@ -173,6 +184,19 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor]) |
173 | 184 | "details/n_training_tokens": self.cur_tokens, |
174 | 185 | } |
175 | 186 | wandb_log_dict.update(sae.log_statistics()) |
| 187 | + if sae.cfg.sae_type == "mixcoder": |
| 188 | + assert isinstance(sae, MixCoder) |
| 189 | + for modality, (start, end) in sae.modality_index.items(): |
| 190 | + wandb_log_dict.update( |
| 191 | + { |
| 192 | + f"metrics/{modality}_l0": (log_info["feature_acts"][:, start:end] > 0) |
| 193 | + .float() |
| 194 | + .sum(-1) |
| 195 | + .mean() |
| 196 | + .item(), |
| 197 | + } |
| 198 | + ) |
| 199 | + |
176 | 200 | self.wandb_logger.log(wandb_log_dict, step=self.cur_step + 1) |
177 | 201 |
|
178 | 202 | def _save_checkpoint(self, sae: SparseAutoEncoder): |
|
0 commit comments