1 change: 1 addition & 0 deletions server/app.py
@@ -349,6 +349,7 @@ def process_sample(*, feature_acts, context_idx, dataset_name, model_name, shard
"analysis_name": analysis.name,
"interpretation": feature.interpretation,
"dictionary_name": feature.sae_name,
"logits": feature.logits,
"decoder_norms": analysis.decoder_norms,
"decoder_similarity_matrix": analysis.decoder_similarity_matrix,
"decoder_inner_product_matrix": analysis.decoder_inner_product_matrix,
6 changes: 6 additions & 0 deletions src/lm_saes/__init__.py
@@ -11,6 +11,7 @@
    CLTConfig,
    CrossCoderConfig,
    DatasetConfig,
    DirectLogitAttributorConfig,
    FeatureAnalyzerConfig,
    InitializerConfig,
    LanguageModelConfig,
@@ -29,6 +30,7 @@
    AnalyzeSAESettings,
    AutoInterpSettings,
    CheckActivationConsistencySettings,
    DirectLogitAttributeSettings,
    EvaluateCrossCoderSettings,
    EvaluateSAESettings,
    GenerateActivationsSettings,
@@ -41,6 +43,7 @@
    analyze_sae,
    auto_interp,
    check_activation_consistency,
    direct_logit_attribute,
    evaluate_crosscoder,
    evaluate_sae,
    generate_activations,
@@ -103,4 +106,7 @@
"sweep_sae",
"LLaDAConfig",
"train_crosscoder",
"DirectLogitAttributeSettings",
"direct_logit_attribute",
"DirectLogitAttributorConfig",
]
73 changes: 73 additions & 0 deletions src/lm_saes/analysis/direct_logit_attributor.py
@@ -0,0 +1,73 @@
import einops
import torch
from transformer_lens import HookedTransformer

from lm_saes.abstract_sae import AbstractSparseAutoEncoder
from lm_saes.backend import LanguageModel
from lm_saes.backend.language_model import TransformerLensLanguageModel
from lm_saes.config import DirectLogitAttributorConfig
from lm_saes.crosscoder import CrossCoder
from lm_saes.sae import SparseAutoEncoder


class DirectLogitAttributor:
    """Computes each SAE feature's direct effect on the model's output logits."""

    def __init__(self, cfg: DirectLogitAttributorConfig):
        self.cfg = cfg

    @torch.no_grad()
    def direct_logit_attribute(self, sae: AbstractSparseAutoEncoder, model: LanguageModel):
        assert isinstance(model, TransformerLensLanguageModel), (
            "DirectLogitAttributor only supports TransformerLensLanguageModel as the model backend"
        )
        model: HookedTransformer | None = model.model
        assert model is not None, "Model checkpoint must be loaded for direct logit attribution"

        # Decoder directions, shape [d_sae, d_model]. A CrossCoder carries one decoder
        # per hook point; take the last one.
        if isinstance(sae, CrossCoder):
            residual = sae.W_D[-1]
        elif isinstance(sae, SparseAutoEncoder):
            residual = sae.W_D
        else:
            raise ValueError(f"Unsupported SAE type: {type(sae)}")

        residual = einops.rearrange(residual, "batch d_model -> batch 1 d_model")  # Add a context dimension

        # Apply the final layer norm (when the model has one) so decoder directions are
        # unembedded the same way residual-stream activations would be.
        if model.cfg.normalization_type is not None:
            residual = model.ln_final(residual)  # [batch, pos, d_model]
        logits = model.unembed(residual)  # [batch, pos, d_vocab]
        logits = einops.rearrange(logits, "batch 1 d_vocab -> batch d_vocab")  # Remove the context dimension

        # Select the top k tokens
        top_k_logits, top_k_indices = torch.topk(logits, self.cfg.top_k, dim=-1)
        top_k_tokens = [model.to_str_tokens(top_k_indices[i]) for i in range(sae.cfg.d_sae)]

        assert top_k_logits.shape == top_k_indices.shape == (sae.cfg.d_sae, self.cfg.top_k), (
            f"Top k logits and indices should have shape (d_sae, top_k), but got {top_k_logits.shape} and {top_k_indices.shape}"
        )
        assert (len(top_k_tokens), len(top_k_tokens[0])) == (sae.cfg.d_sae, self.cfg.top_k), (
            f"Top k tokens should have shape (d_sae, top_k), but got {len(top_k_tokens)} and {len(top_k_tokens[0])}"
        )

        # Select the bottom k tokens
        bottom_k_logits, bottom_k_indices = torch.topk(logits, self.cfg.top_k, dim=-1, largest=False)
        bottom_k_tokens = [model.to_str_tokens(bottom_k_indices[i]) for i in range(sae.cfg.d_sae)]

        assert bottom_k_logits.shape == bottom_k_indices.shape == (sae.cfg.d_sae, self.cfg.top_k), (
            f"Bottom k logits and indices should have shape (d_sae, top_k), but got {bottom_k_logits.shape} and {bottom_k_indices.shape}"
        )
        assert (len(bottom_k_tokens), len(bottom_k_tokens[0])) == (sae.cfg.d_sae, self.cfg.top_k), (
            f"Bottom k tokens should have shape (d_sae, top_k), but got {len(bottom_k_tokens)} and {len(bottom_k_tokens[0])}"
        )

        result = [
            {
                "top_positive": [
                    {"token": token, "logit": logit} for token, logit in zip(top_k_tokens[i], top_k_logits[i].tolist())
                ],
                "top_negative": [
                    {"token": token, "logit": logit}
                    for token, logit in zip(bottom_k_tokens[i], bottom_k_logits[i].tolist())
                ],
            }
            for i in range(sae.cfg.d_sae)
        ]
        return result
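
For orientation, a minimal usage sketch of the new class (hypothetical setup: it assumes sae and model are an already-loaded SparseAutoEncoder/CrossCoder and a TransformerLensLanguageModel, e.g. obtained via load_model):

from lm_saes.analysis.direct_logit_attributor import DirectLogitAttributor
from lm_saes.config import DirectLogitAttributorConfig

# sae and model are assumed to be loaded elsewhere (see the runner below).
attributor = DirectLogitAttributor(DirectLogitAttributorConfig(top_k=10))
results = attributor.direct_logit_attribute(sae, model)
print(results[0]["top_positive"])  # the top_k most-promoted tokens for feature 0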
31 changes: 0 additions & 31 deletions src/lm_saes/analysis/features_to_logits.py

This file was deleted.

5 changes: 5 additions & 0 deletions src/lm_saes/config.py
@@ -431,6 +431,11 @@ class FeatureAnalyzerConfig(BaseConfig):
"""


class DirectLogitAttributorConfig(BaseConfig):
top_k: int = 10
""" The number of top tokens to attribute to. """


class WandbConfig(BaseConfig):
wandb_project: str = "gpt2-sae-training"
exp_name: Optional[str] = None
12 changes: 12 additions & 0 deletions src/lm_saes/database.py
@@ -54,6 +54,7 @@ class FeatureRecord(BaseModel):
    sae_series: str
    index: int
    analyses: list[FeatureAnalysis] = []
    logits: Optional[dict[str, list[dict[str, Any]]]] = None
    interpretation: Optional[dict[str, Any]] = None
    metric: Optional[dict[str, float]] = None

@@ -402,6 +403,17 @@ def update_feature(self, sae_name: str, feature_index: int, update_data: dict, s

        return result

    def update_features(self, sae_name: str, sae_series: str, update_data: list[dict], start_idx: int = 0):
        """Bulk-update feature records, applying update_data[i] to the feature at index start_idx + i."""
        operations = []
        for i, feature_update in enumerate(update_data):
            update_operation = pymongo.UpdateOne(
                {"sae_name": sae_name, "sae_series": sae_series, "index": start_idx + i},
                {"$set": feature_update},
            )
            operations.append(update_operation)
        if operations:
            self.feature_collection.bulk_write(operations)

    def add_bookmark(
        self,
        sae_name: str,
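
The per-feature payload that update_features writes (and that server/app.py above now returns as feature.logits) is the attributor's result dict, matching the logits field type on FeatureRecord. As a hedged illustration, each feature record gains a field of roughly this shape (token strings and values are made up):

# Illustrative only -- tokens and logit values are invented for this example:
{
    "top_positive": [{"token": " dog", "logit": 3.21}, ...],   # top_k entries
    "top_negative": [{"token": " not", "logit": -2.87}, ...],  # top_k entries
}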
4 changes: 4 additions & 0 deletions src/lm_saes/runners/__init__.py
@@ -3,8 +3,10 @@
from .analyze import (
    AnalyzeCrossCoderSettings,
    AnalyzeSAESettings,
    DirectLogitAttributeSettings,
    analyze_crosscoder,
    analyze_sae,
    direct_logit_attribute,
)
from .autointerp import AutoInterpSettings, auto_interp
from .eval import (
@@ -33,6 +35,8 @@
from .utils import load_config

__all__ = [
"DirectLogitAttributeSettings",
"direct_logit_attribute",
"GenerateActivationsSettings",
"generate_activations",
"CheckActivationConsistencySettings",
106 changes: 106 additions & 0 deletions src/lm_saes/runners/analyze.py
@@ -7,19 +7,26 @@
from torch.distributed.device_mesh import init_device_mesh

from lm_saes.activation.factory import ActivationFactory
from lm_saes.analysis.direct_logit_attributor import DirectLogitAttributor
from lm_saes.analysis.feature_analyzer import FeatureAnalyzer
from lm_saes.backend.language_model import TransformerLensLanguageModel
from lm_saes.config import (
    ActivationFactoryConfig,
    BaseSAEConfig,
    CrossCoderConfig,
    DirectLogitAttributorConfig,
    FeatureAnalyzerConfig,
    LanguageModelConfig,
    MongoDBConfig,
    SAEConfig,
)
from lm_saes.crosscoder import CrossCoder
from lm_saes.database import MongoClient
from lm_saes.resource_loaders import load_model
from lm_saes.runners.utils import load_config
from lm_saes.sae import SparseAutoEncoder
from lm_saes.utils.logging import get_distributed_logger, setup_logging
from lm_saes.utils.misc import is_master

logger = get_distributed_logger("runners.analyze")

@@ -199,3 +206,102 @@ def analyze_crosscoder(settings: AnalyzeCrossCoderSettings) -> None:
    )

    logger.info("CrossCoder analysis completed successfully")


class DirectLogitAttributeSettings(BaseSettings):
    """Settings for running direct logit attribution on an SAE model."""

    sae: BaseSAEConfig
    """Configuration for the SAE model architecture and parameters"""

    sae_name: str
    """Name of the SAE model. Used as an identifier for the SAE model in the database."""

    sae_series: str
    """Series of the SAE model. Used as an identifier for the SAE model in the database."""

    model: Optional[LanguageModelConfig] = None
    """Configuration for the language model."""

    model_name: str
    """Name of the language model."""

    direct_logit_attributor: DirectLogitAttributorConfig
    """Configuration for the direct logit attributor."""

    mongo: MongoDBConfig
    """Configuration for the MongoDB database."""

    device_type: str = "cuda"
    """Device type to run on ('cuda' or 'cpu')"""

    # model_parallel_size: int = 1
    # """Size of model parallel (tensor parallel) mesh"""

    # data_parallel_size: int = 1
    # """Size of data parallel mesh"""

    # head_parallel_size: int = 1
    # """Size of head parallel mesh"""


@torch.no_grad()
def direct_logit_attribute(settings: DirectLogitAttributeSettings) -> None:
    """Run direct logit attribution for an SAE model and store the results.

    Args:
        settings: Configuration settings for DirectLogitAttributor
    """
    # Set up logging
    setup_logging(level="INFO")

    # device_mesh = (
    #     init_device_mesh(
    #         device_type=settings.device_type,
    #         mesh_shape=(settings.head_parallel_size, settings.data_parallel_size, settings.model_parallel_size),
    #         mesh_dim_names=("head", "data", "model"),
    #     )
    #     if settings.head_parallel_size > 1 or settings.data_parallel_size > 1 or settings.model_parallel_size > 1
    #     else None
    # )

    mongo_client = MongoClient(settings.mongo)
    logger.info("MongoDB client initialized")

    logger.info("Loading SAE model")
    if isinstance(settings.sae, CrossCoderConfig):
        sae = CrossCoder.from_config(settings.sae)
    elif isinstance(settings.sae, SAEConfig):
        sae = SparseAutoEncoder.from_config(settings.sae)
    else:
        raise ValueError(f"Unsupported SAE config type: {type(settings.sae)}")

    # Load configurations
    model_cfg = load_config(
        config=settings.model,
        name=settings.model_name,
        mongo_client=mongo_client,
        config_type="model",
        required=True,
    )
    model_cfg.device = settings.device_type
    model_cfg.dtype = sae.cfg.dtype

    model = load_model(model_cfg)
    assert isinstance(model, TransformerLensLanguageModel), (
        "DirectLogitAttributor only supports TransformerLensLanguageModel as the model backend"
    )

    logger.info("Direct logit attribution")
    direct_logit_attributor = DirectLogitAttributor(settings.direct_logit_attributor)
    results = direct_logit_attributor.direct_logit_attribute(sae, model)

    if is_master():
        logger.info("Direct logit attribution completed, saving results to MongoDB")
        mongo_client.update_features(
            sae_name=settings.sae_name,
            sae_series=settings.sae_series,
            update_data=[{"logits": result} for result in results],
            start_idx=0,
        )
    logger.info("Direct logit attribution completed successfully")