Commit dccfb80

Frankstein73 authored and dest1n1s committed
chore: update TransformerLens submodule and remove MixCoder references
- Updated the TransformerLens submodule to the latest commit.
- Removed all references to `MixCoder` from the codebase, including imports and related configurations.
- Adjusted `BaseSAEConfig` and `SAEConfig` to eliminate `mixcoder` from the `sae_type` options.
- Cleaned up the `FeatureAnalyzer` class by removing modality-specific metrics related to `MixCoder`.
1 parent 087a850 commit dccfb80

7 files changed (+5 −243 lines)

src/lm_saes/__init__.py

Lines changed: 0 additions & 4 deletions

@@ -13,7 +13,6 @@
     InitializerConfig,
     LanguageModelConfig,
     LLaDAConfig,
-    MixCoderConfig,
     MongoDBConfig,
     SAEConfig,
     TrainerConfig,
@@ -22,7 +21,6 @@
 from .crosscoder import CrossCoder
 from .database import MongoClient
 from .evaluator import EvalConfig, Evaluator
-from .mixcoder import MixCoder
 from .resource_loaders import load_dataset, load_model
 from .runners import (
     AnalyzeCrossCoderSettings,
@@ -83,8 +81,6 @@
     "FeatureAnalyzerConfig",
     "MongoDBConfig",
     "MongoClient",
-    "MixCoderConfig",
-    "MixCoder",
     "AnalyzeCrossCoderSettings",
     "analyze_crosscoder",
     "AutoInterpSettings",

src/lm_saes/analysis/feature_analyzer.py

Lines changed: 0 additions & 38 deletions

@@ -10,7 +10,6 @@
 from lm_saes.abstract_sae import AbstractSparseAutoEncoder
 from lm_saes.config import FeatureAnalyzerConfig
 from lm_saes.crosscoder import CrossCoder
-from lm_saes.mixcoder import MixCoder
 from lm_saes.utils.discrete import KeyedDiscreteMapper
 from lm_saes.utils.distributed import DimMap
 from lm_saes.utils.misc import is_primary_rank
@@ -245,18 +244,6 @@ def analyze_chunk(
         max_feature_acts = torch.zeros((d_sae_local,), dtype=sae.cfg.dtype, device=sae.cfg.device)
         mapper = KeyedDiscreteMapper()

-        if isinstance(sae, MixCoder):
-            act_times_modalities = {
-                k: torch.zeros((d_sae_local,), dtype=torch.long, device=sae.cfg.device) for k in sae.cfg.modality_names
-            }
-            max_feature_acts_modalities = {
-                k: torch.zeros((d_sae_local,), dtype=sae.cfg.dtype, device=sae.cfg.device)
-                for k in sae.cfg.modality_names
-            }
-        else:
-            act_times_modalities = None
-            max_feature_acts_modalities = None
-
         # Process activation batches
         for batch in activation_stream:
             # Reshape meta to zip outer dimensions to inner
@@ -295,15 +282,6 @@ def analyze_chunk(
             act_times += feature_acts.gt(0.0).sum(dim=[0, 1])
             max_feature_acts = torch.max(max_feature_acts, feature_acts.max(dim=0).values.max(dim=0).values)

-            if isinstance(sae, MixCoder):
-                assert act_times_modalities is not None and max_feature_acts_modalities is not None
-                for i, k in enumerate(sae.cfg.modality_names):
-                    feature_acts_modality = feature_acts * (batch["modalities"] == i).long().unsqueeze(-1)
-                    act_times_modalities[k] += feature_acts_modality.gt(0.0).sum(dim=[0, 1])
-                    max_feature_acts_modalities[k] = torch.max(
-                        max_feature_acts_modalities[k], feature_acts_modality.max(dim=0).values.max(dim=0).values
-                    )
-
             # TODO: Filter out meta that is not string
             discrete_meta = {
                 k: torch.tensor(mapper.encode(k, v), device=sae.cfg.device, dtype=torch.int32) for k, v in meta.items()
@@ -339,8 +317,6 @@ def analyze_chunk(
             max_feature_acts=max_feature_acts,
             sample_result=sample_result,
             mapper=mapper,
-            act_times_modalities=act_times_modalities,
-            max_feature_acts_modalities=max_feature_acts_modalities,
             device_mesh=device_mesh,
         )
@@ -352,8 +328,6 @@ def _format_analysis_results(
         max_feature_acts: torch.Tensor,
         sample_result: dict[str, dict[str, torch.Tensor]],
         mapper: KeyedDiscreteMapper,
-        act_times_modalities: dict[str, torch.Tensor] | None = None,
-        max_feature_acts_modalities: dict[str, torch.Tensor] | None = None,
         device_mesh: DeviceMesh | None = None,
     ) -> list[dict[str, Any]]:
         """Format the analysis results into the final per-feature format.
@@ -365,8 +339,6 @@ def _format_analysis_results(
             max_feature_acts: Tensor of maximum activation values for each feature
             sample_result: Dictionary of sampling results
             mapper: MetaMapper for encoding/decoding metadata
-            act_times_modalities: Optional dictionary of activation times per modality (for MixCoder)
-            max_feature_acts_modalities: Optional dictionary of maximum activation values per modality (for MixCoder)

         Returns:
             List of dictionaries containing per-feature analysis results
@@ -466,16 +438,6 @@ def _format_analysis_results(
         feature_result["decoder_similarity_matrix"] = decoder_similarity_matrices[i, :, :].tolist()
         feature_result["decoder_inner_product_matrix"] = decoder_inner_product_matrices[i, :, :].tolist()

-        # Add modality-specific metrics for MixCoder
-        if (
-            isinstance(sae, MixCoder)
-            and act_times_modalities is not None
-            and max_feature_acts_modalities is not None
-        ):
-            feature_result["act_times_modalities"] = {k: v[i].item() for k, v in act_times_modalities.items()}
-            feature_result["max_feature_acts_modalities"] = {
-                k: v[i].item() for k, v in max_feature_acts_modalities.items()
-            }

         results.append(feature_result)
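
What remains after the removal is the modality-agnostic accumulation over the activation stream. A minimal standalone sketch of that pattern (hypothetical shapes and data; the real analyze_chunk shards d_sae across ranks and runs on sae.cfg.device):

    import torch

    d_sae = 16
    act_times = torch.zeros(d_sae, dtype=torch.long)
    max_feature_acts = torch.zeros(d_sae)

    # Stand-in for the activation stream: batches of shape (batch, seq, d_sae).
    for feature_acts in (torch.rand(2, 8, d_sae), torch.rand(4, 8, d_sae)):
        # Count the (sample, position) pairs on which each feature fired.
        act_times += feature_acts.gt(0.0).sum(dim=[0, 1])
        # Running per-feature maximum activation across all batches.
        max_feature_acts = torch.max(
            max_feature_acts, feature_acts.max(dim=0).values.max(dim=0).values
        )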

src/lm_saes/config.py

Lines changed: 4 additions & 4 deletions

@@ -48,10 +48,10 @@ class BaseSAEConfig(BaseModelConfig, ABC):
     """
     Base class for SAE configs.
     Initializer will initialize SAE based on config type.
-    So this class should not be used directly but only as a base config class for other SAE variants like SAEConfig, MixCoderConfig, CrossCoderConfig, etc.
+    So this class should not be used directly but only as a base config class for other SAE variants like SAEConfig, CrossCoderConfig, etc.
     """

-    sae_type: Literal["sae", "crosscoder", "mixcoder"]
+    sae_type: Literal["sae", "crosscoder"]
     d_model: int
     expansion_factor: int
     use_decoder_bias: bool = True
@@ -114,7 +114,7 @@ def associated_hook_points(self) -> list[str]:

 class SAEConfig(BaseSAEConfig):
-    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "sae"
+    sae_type: Literal["sae", "crosscoder"] = "sae"
     hook_point_in: str
     hook_point_out: str = Field(default_factory=lambda validated_model: validated_model["hook_point_in"])
     use_glu_encoder: bool = False
@@ -125,7 +125,7 @@ def associated_hook_points(self) -> list[str]:

 class CrossCoderConfig(BaseSAEConfig):
-    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "crosscoder"
+    sae_type: Literal["sae", "crosscoder"] = "crosscoder"
     hook_points: list[str]

     @property
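
Since sae_type is a Literal field, the narrowing is enforced at validation time. A hedged sketch of the effect (SAEConfigSketch is a hypothetical stand-in for SAEConfig, assuming the configs are pydantic models, as the Field usage suggests):

    from typing import Literal

    from pydantic import BaseModel, ValidationError

    class SAEConfigSketch(BaseModel):
        # "mixcoder" is no longer an accepted discriminator value.
        sae_type: Literal["sae", "crosscoder"] = "sae"
        d_model: int
        expansion_factor: int

    SAEConfigSketch(d_model=768, expansion_factor=8)  # validates fine
    try:
        SAEConfigSketch(sae_type="mixcoder", d_model=768, expansion_factor=8)
    except ValidationError as err:
        print(err)  # pydantic rejects the removed option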

src/lm_saes/mixcoder.py

Lines changed: 0 additions & 190 deletions
This file was deleted.

src/lm_saes/runners/analyze.py

Lines changed: 0 additions & 3 deletions

@@ -42,9 +42,6 @@ class AnalyzeSAESettings(BaseSettings):
     model: Optional[LanguageModelConfig] = None
     """Configuration for the language model. Required if using dataset sources."""

-    model_name: Optional[str] = None
-    """Name of the tokenizer to load. Mixcoder requires a tokenizer to get the modality indices."""
-
     analyzer: FeatureAnalyzerConfig
     """Configuration for feature analysis"""

src/lm_saes/runners/eval.py

Lines changed: 0 additions & 3 deletions

@@ -43,9 +43,6 @@ class EvaluateSAESettings(BaseSettings):
     model: Optional[LanguageModelConfig] = None
     """Configuration for the language model. Required if using dataset sources."""

-    model_name: Optional[str] = None
-    """Name of the tokenizer to load. Mixcoder requires a tokenizer to get the modality indices."""
-
     eval: EvalConfig
     """Configuration for evaluation"""
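
Both runner settings drop model_name for the same reason: only the MixCoder path needed a standalone tokenizer to derive modality indices. A minimal sketch of the slimmed-down shape (hypothetical reduced class, assuming pydantic-settings provides the BaseSettings base, as the imports suggest):

    from typing import Optional

    from pydantic_settings import BaseSettings

    class EvaluateSAESettingsSketch(BaseSettings):
        # In the real class this field is Optional[LanguageModelConfig];
        # a plain string keeps the sketch self-contained.
        model: Optional[str] = None
        # model_name: Optional[str] is gone entirely; no separate tokenizer
        # is loaded now that modality indices are no longer computed.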
