From 55ea55ebc0da0dc440e148faeb353eb21336e1b2 Mon Sep 17 00:00:00 2001
From: Xuyang Ge
Date: Tue, 8 Jul 2025 17:47:47 +0800
Subject: [PATCH] feat(analysis): reimplement DirectLogitAttributor and related configurations

---
 server/app.py                               |   1 +
 src/lm_saes/__init__.py                     |   6 +
 .../analysis/direct_logit_attributor.py     |  73 ++++++++++++
 src/lm_saes/analysis/features_to_logits.py  |  31 -----
 src/lm_saes/config.py                       |   5 +
 src/lm_saes/database.py                     |  12 ++
 src/lm_saes/runners/__init__.py             |   4 +
 src/lm_saes/runners/analyze.py              | 106 ++++++++++++++++++
 8 files changed, 207 insertions(+), 31 deletions(-)
 create mode 100644 src/lm_saes/analysis/direct_logit_attributor.py
 delete mode 100644 src/lm_saes/analysis/features_to_logits.py

diff --git a/server/app.py b/server/app.py
index 3a8f9384..c4d66330 100644
--- a/server/app.py
+++ b/server/app.py
@@ -349,6 +349,7 @@ def process_sample(*, feature_acts, context_idx, dataset_name, model_name, shard
         "analysis_name": analysis.name,
         "interpretation": feature.interpretation,
         "dictionary_name": feature.sae_name,
+        "logits": feature.logits,
         "decoder_norms": analysis.decoder_norms,
         "decoder_similarity_matrix": analysis.decoder_similarity_matrix,
         "decoder_inner_product_matrix": analysis.decoder_inner_product_matrix,
diff --git a/src/lm_saes/__init__.py b/src/lm_saes/__init__.py
index 38d98b86..d4e9c289 100644
--- a/src/lm_saes/__init__.py
+++ b/src/lm_saes/__init__.py
@@ -11,6 +11,7 @@
     CLTConfig,
     CrossCoderConfig,
     DatasetConfig,
+    DirectLogitAttributorConfig,
     FeatureAnalyzerConfig,
     InitializerConfig,
     LanguageModelConfig,
@@ -29,6 +30,7 @@
     AnalyzeSAESettings,
     AutoInterpSettings,
     CheckActivationConsistencySettings,
+    DirectLogitAttributeSettings,
     EvaluateCrossCoderSettings,
     EvaluateSAESettings,
     GenerateActivationsSettings,
@@ -41,6 +43,7 @@
     analyze_sae,
     auto_interp,
     check_activation_consistency,
+    direct_logit_attribute,
     evaluate_crosscoder,
     evaluate_sae,
     generate_activations,
@@ -103,4 +106,7 @@
     "sweep_sae",
     "LLaDAConfig",
     "train_crosscoder",
+    "DirectLogitAttributeSettings",
+    "direct_logit_attribute",
+    "DirectLogitAttributorConfig",
 ]
diff --git a/src/lm_saes/analysis/direct_logit_attributor.py b/src/lm_saes/analysis/direct_logit_attributor.py
new file mode 100644
index 00000000..787f896e
--- /dev/null
+++ b/src/lm_saes/analysis/direct_logit_attributor.py
@@ -0,0 +1,73 @@
+import einops
+import torch
+from transformer_lens import HookedTransformer
+
+from lm_saes.abstract_sae import AbstractSparseAutoEncoder
+from lm_saes.backend import LanguageModel
+from lm_saes.backend.language_model import TransformerLensLanguageModel
+from lm_saes.config import DirectLogitAttributorConfig
+from lm_saes.crosscoder import CrossCoder
+from lm_saes.sae import SparseAutoEncoder
+
+
+class DirectLogitAttributor:
+    def __init__(self, cfg: DirectLogitAttributorConfig):
+        self.cfg = cfg
+
+    @torch.no_grad()
+    def direct_logit_attribute(self, sae: AbstractSparseAutoEncoder, model: LanguageModel):
+        assert isinstance(model, TransformerLensLanguageModel), (
+            "DirectLogitAttributor only supports TransformerLensLanguageModel as the model backend"
+        )
+        model: HookedTransformer | None = model.model
+        assert model is not None, "Model ckpt must be loaded for direct logit attribution"
+
+        if isinstance(sae, CrossCoder):
+            residual = sae.W_D[-1]
+        elif isinstance(sae, SparseAutoEncoder):
+            residual = sae.W_D
+        else:
+            raise ValueError(f"Unsupported SAE type: {type(sae)}")
+
+        residual = einops.rearrange(residual, "batch d_model -> batch 1 d_model")  # Add a context dimension
+
+        if model.cfg.normalization_type is not None:
+            residual = model.ln_final(residual)  # [batch, pos, d_model]
+        logits = model.unembed(residual)  # [batch, pos, d_vocab]
+        logits = einops.rearrange(logits, "batch 1 d_vocab -> batch d_vocab")  # Remove the context dimension
+
+        # Select the top k tokens
+        top_k_logits, top_k_indices = torch.topk(logits, self.cfg.top_k, dim=-1)
+        top_k_tokens = [model.to_str_tokens(top_k_indices[i]) for i in range(sae.cfg.d_sae)]
+
+        assert top_k_logits.shape == top_k_indices.shape == (sae.cfg.d_sae, self.cfg.top_k), (
+            f"Top k logits and indices should have shape (d_sae, top_k), but got {top_k_logits.shape} and {top_k_indices.shape}"
+        )
+        assert (len(top_k_tokens), len(top_k_tokens[0])) == (sae.cfg.d_sae, self.cfg.top_k), (
+            f"Top k tokens should have shape (d_sae, top_k), but got {len(top_k_tokens)} and {len(top_k_tokens[0])}"
+        )
+
+        # Select the bottom k tokens
+        bottom_k_logits, bottom_k_indices = torch.topk(logits, self.cfg.top_k, dim=-1, largest=False)
+        bottom_k_tokens = [model.to_str_tokens(bottom_k_indices[i]) for i in range(sae.cfg.d_sae)]
+
+        assert bottom_k_logits.shape == bottom_k_indices.shape == (sae.cfg.d_sae, self.cfg.top_k), (
+            f"Bottom k logits and indices should have shape (d_sae, top_k), but got {bottom_k_logits.shape} and {bottom_k_indices.shape}"
+        )
+        assert (len(bottom_k_tokens), len(bottom_k_tokens[0])) == (sae.cfg.d_sae, self.cfg.top_k), (
+            f"Bottom k tokens should have shape (d_sae, top_k), but got {len(bottom_k_tokens)} and {len(bottom_k_tokens[0])}"
+        )
+
+        result = [
+            {
+                "top_positive": [
+                    {"token": token, "logit": logit} for token, logit in zip(top_k_tokens[i], top_k_logits[i].tolist())
+                ],
+                "top_negative": [
+                    {"token": token, "logit": logit}
+                    for token, logit in zip(bottom_k_tokens[i], bottom_k_logits[i].tolist())
+                ],
+            }
+            for i in range(sae.cfg.d_sae)
+        ]
+        return result
diff --git a/src/lm_saes/analysis/features_to_logits.py b/src/lm_saes/analysis/features_to_logits.py
deleted file mode 100644
index 33789518..00000000
--- a/src/lm_saes/analysis/features_to_logits.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# import torch
-# from transformer_lens import HookedTransformer
-
-# from ..config import FeaturesDecoderConfig
-# from ..sae import SparseAutoEncoder
-
-
-# @torch.no_grad()
-# def features_to_logits(sae: SparseAutoEncoder, model: HookedTransformer, cfg: FeaturesDecoderConfig):
-#     num_ones = int(torch.sum(sae.feature_act_mask).item())
-
-#     feature_acts = torch.zeros(num_ones, cfg.sae.d_sae).to(cfg.sae.device)
-
-#     index = 0
-#     for i in range(len(sae.feature_act_mask)):
-#         if sae.feature_act_mask[i] == 1:
-#             feature_acts[index, i] = 1
-#             index += 1
-
-#     feature_acts = torch.unsqueeze(feature_acts, dim=1)
-
-#     residual = sae.decode(feature_acts)
-
-#     if model.cfg.normalization_type is not None:
-#         residual = model.ln_final(residual)  # [batch, pos, d_model]
-#     logits = model.unembed(residual)  # [batch, pos, d_vocab]
-
-#     active_indices = [i for i, val in enumerate(sae.feature_act_mask) if val == 1]
-#     result_dict = {str(feature_index): logits[idx][0] for idx, feature_index in enumerate(active_indices)}
-
-#     return result_dict
diff --git a/src/lm_saes/config.py b/src/lm_saes/config.py
index 5afc3a77..466ffbd8 100644
--- a/src/lm_saes/config.py
+++ b/src/lm_saes/config.py
@@ -431,6 +431,11 @@ class FeatureAnalyzerConfig(BaseConfig):
     """
 
 
+class DirectLogitAttributorConfig(BaseConfig):
+    top_k: int = 10
+    """The number of top positive and top negative tokens to record for each feature."""
""" + + class WandbConfig(BaseConfig): wandb_project: str = "gpt2-sae-training" exp_name: Optional[str] = None diff --git a/src/lm_saes/database.py b/src/lm_saes/database.py index 04aee63a..3b752dee 100644 --- a/src/lm_saes/database.py +++ b/src/lm_saes/database.py @@ -54,6 +54,7 @@ class FeatureRecord(BaseModel): sae_series: str index: int analyses: list[FeatureAnalysis] = [] + logits: Optional[dict[str, list[dict[str, Any]]]] = None interpretation: Optional[dict[str, Any]] = None metric: Optional[dict[str, float]] = None @@ -402,6 +403,17 @@ def update_feature(self, sae_name: str, feature_index: int, update_data: dict, s return result + def update_features(self, sae_name: str, sae_series: str, update_data: list[dict], start_idx: int = 0): + operations = [] + for i, feature_update in enumerate(update_data): + update_operation = pymongo.UpdateOne( + {"sae_name": sae_name, "sae_series": sae_series, "index": start_idx + i}, + {"$set": feature_update}, + ) + operations.append(update_operation) + if operations: + self.feature_collection.bulk_write(operations) + def add_bookmark( self, sae_name: str, diff --git a/src/lm_saes/runners/__init__.py b/src/lm_saes/runners/__init__.py index f12fa85a..87fd83b0 100644 --- a/src/lm_saes/runners/__init__.py +++ b/src/lm_saes/runners/__init__.py @@ -3,8 +3,10 @@ from .analyze import ( AnalyzeCrossCoderSettings, AnalyzeSAESettings, + DirectLogitAttributeSettings, analyze_crosscoder, analyze_sae, + direct_logit_attribute, ) from .autointerp import AutoInterpSettings, auto_interp from .eval import ( @@ -33,6 +35,8 @@ from .utils import load_config __all__ = [ + "DirectLogitAttributeSettings", + "direct_logit_attribute", "GenerateActivationsSettings", "generate_activations", "CheckActivationConsistencySettings", diff --git a/src/lm_saes/runners/analyze.py b/src/lm_saes/runners/analyze.py index 6ff0a531..e31583b7 100644 --- a/src/lm_saes/runners/analyze.py +++ b/src/lm_saes/runners/analyze.py @@ -7,19 +7,26 @@ from torch.distributed.device_mesh import init_device_mesh from lm_saes.activation.factory import ActivationFactory +from lm_saes.analysis.direct_logit_attributor import DirectLogitAttributor from lm_saes.analysis.feature_analyzer import FeatureAnalyzer +from lm_saes.backend.language_model import TransformerLensLanguageModel from lm_saes.config import ( ActivationFactoryConfig, BaseSAEConfig, CrossCoderConfig, + DirectLogitAttributorConfig, FeatureAnalyzerConfig, LanguageModelConfig, MongoDBConfig, + SAEConfig, ) from lm_saes.crosscoder import CrossCoder from lm_saes.database import MongoClient +from lm_saes.resource_loaders import load_model +from lm_saes.runners.utils import load_config from lm_saes.sae import SparseAutoEncoder from lm_saes.utils.logging import get_distributed_logger, setup_logging +from lm_saes.utils.misc import is_master logger = get_distributed_logger("runners.analyze") @@ -199,3 +206,102 @@ def analyze_crosscoder(settings: AnalyzeCrossCoderSettings) -> None: ) logger.info("CrossCoder analysis completed successfully") + + +class DirectLogitAttributeSettings(BaseSettings): + """Settings for analyzing a CrossCoder model.""" + + sae: BaseSAEConfig + """Configuration for the SAE model architecture and parameters""" + + sae_name: str + """Name of the SAE model. Use as identifier for the SAE model in the database.""" + + sae_series: str + """Series of the SAE model. 
+
+    model: Optional[LanguageModelConfig] = None
+    """Configuration for the language model."""
+
+    model_name: str
+    """Name of the language model."""
+
+    direct_logit_attributor: DirectLogitAttributorConfig
+    """Configuration for the direct logit attributor."""
+
+    mongo: MongoDBConfig
+    """Configuration for the MongoDB database."""
+
+    device_type: str = "cuda"
+    """Device type to run attribution on ('cuda' or 'cpu')."""
+
+    # model_parallel_size: int = 1
+    # """Size of model parallel (tensor parallel) mesh"""
+
+    # data_parallel_size: int = 1
+    # """Size of data parallel mesh"""
+
+    # head_parallel_size: int = 1
+    # """Size of head parallel mesh"""
+
+
+@torch.no_grad()
+def direct_logit_attribute(settings: DirectLogitAttributeSettings) -> None:
+    """Run direct logit attribution for an SAE model.
+
+    Args:
+        settings: Configuration settings for DirectLogitAttributor
+    """
+    # Set up logging
+    setup_logging(level="INFO")
+
+    # device_mesh = (
+    #     init_device_mesh(
+    #         device_type=settings.device_type,
+    #         mesh_shape=(settings.head_parallel_size, settings.data_parallel_size, settings.model_parallel_size),
+    #         mesh_dim_names=("head", "data", "model"),
+    #     )
+    #     if settings.head_parallel_size > 1 or settings.data_parallel_size > 1 or settings.model_parallel_size > 1
+    #     else None
+    # )
+
+    mongo_client = MongoClient(settings.mongo)
+    logger.info("MongoDB client initialized")
+
+    logger.info("Loading SAE model")
+    if isinstance(settings.sae, CrossCoderConfig):
+        sae = CrossCoder.from_config(settings.sae)
+    elif isinstance(settings.sae, SAEConfig):
+        sae = SparseAutoEncoder.from_config(settings.sae)
+    else:
+        raise ValueError(f"Unsupported SAE config type: {type(settings.sae)}")
+
+    # Load configurations
+    model_cfg = load_config(
+        config=settings.model,
+        name=settings.model_name,
+        mongo_client=mongo_client,
+        config_type="model",
+        required=True,
+    )
+    model_cfg.device = settings.device_type
+    model_cfg.dtype = sae.cfg.dtype
+
+    model = load_model(model_cfg)
+    assert isinstance(model, TransformerLensLanguageModel), (
+        "DirectLogitAttributor only supports TransformerLensLanguageModel as the model backend"
+    )
+
+    logger.info("Direct logit attribution")
+    direct_logit_attributor = DirectLogitAttributor(settings.direct_logit_attributor)
+    results = direct_logit_attributor.direct_logit_attribute(sae, model)
+
+    if is_master():
+        logger.info("Direct logit attribution completed, saving results to MongoDB")
+        mongo_client.update_features(
+            sae_name=settings.sae_name,
+            sae_series=settings.sae_series,
+            update_data=[{"logits": result} for result in results],
+            start_idx=0,
+        )
+    logger.info("Direct logit attribution completed successfully")
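
Below is a minimal usage sketch of the runner added by this patch. The SAE and MongoDB configuration objects are placeholders whose exact fields depend on the deployment, and the dictionary name and series are hypothetical; only the DirectLogitAttributeSettings fields and the direct_logit_attribute call come from the patch itself.

    from lm_saes import (
        DirectLogitAttributeSettings,
        DirectLogitAttributorConfig,
        direct_logit_attribute,
    )

    # Placeholder configs -- substitute the SAE and MongoDB configs used in your deployment.
    sae_config = ...    # a BaseSAEConfig (e.g. SAEConfig or CrossCoderConfig) for the trained dictionary
    mongo_config = ...  # a MongoDBConfig pointing at the feature database

    settings = DirectLogitAttributeSettings(
        sae=sae_config,
        sae_name="gpt2-sae-l6",  # hypothetical identifier of the SAE in MongoDB
        sae_series="default",    # hypothetical series name
        model_name="gpt2",       # when `model` is omitted, the model config is loaded from MongoDB by this name
        direct_logit_attributor=DirectLogitAttributorConfig(top_k=10),
        mongo=mongo_config,
        device_type="cuda",
    )

    # Writes a `logits` entry (top_k positive and negative tokens) for every feature of the SAE.
    direct_logit_attribute(settings)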