Commit 55ea55e

feat(analysis): reimplement DirectLogitAttributor and related configurations

1 parent 989f7de

File tree: 8 files changed (+207, -31 lines)

server/app.py

Lines changed: 1 addition & 0 deletions

@@ -349,6 +349,7 @@ def process_sample(*, feature_acts, context_idx, dataset_name, model_name, shard
         "analysis_name": analysis.name,
         "interpretation": feature.interpretation,
         "dictionary_name": feature.sae_name,
+        "logits": feature.logits,
         "decoder_norms": analysis.decoder_norms,
         "decoder_similarity_matrix": analysis.decoder_similarity_matrix,
         "decoder_inner_product_matrix": analysis.decoder_inner_product_matrix,

src/lm_saes/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -11,6 +11,7 @@
     CLTConfig,
     CrossCoderConfig,
     DatasetConfig,
+    DirectLogitAttributorConfig,
     FeatureAnalyzerConfig,
     InitializerConfig,
     LanguageModelConfig,
@@ -29,6 +30,7 @@
     AnalyzeSAESettings,
     AutoInterpSettings,
     CheckActivationConsistencySettings,
+    DirectLogitAttributeSettings,
     EvaluateCrossCoderSettings,
     EvaluateSAESettings,
     GenerateActivationsSettings,
@@ -41,6 +43,7 @@
     analyze_sae,
     auto_interp,
     check_activation_consistency,
+    direct_logit_attribute,
     evaluate_crosscoder,
     evaluate_sae,
     generate_activations,
@@ -103,4 +106,7 @@
     "sweep_sae",
     "LLaDAConfig",
     "train_crosscoder",
+    "DirectLogitAttributeSettings",
+    "direct_logit_attribute",
+    "DirectLogitAttributorConfig",
 ]

src/lm_saes/analysis/direct_logit_attributor.py

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
+import einops
+import torch
+from transformer_lens import HookedTransformer
+
+from lm_saes.abstract_sae import AbstractSparseAutoEncoder
+from lm_saes.backend import LanguageModel
+from lm_saes.backend.language_model import TransformerLensLanguageModel
+from lm_saes.config import DirectLogitAttributorConfig
+from lm_saes.crosscoder import CrossCoder
+from lm_saes.sae import SparseAutoEncoder
+
+
+class DirectLogitAttributor:
+    def __init__(self, cfg: DirectLogitAttributorConfig):
+        self.cfg = cfg
+
+    @torch.no_grad()
+    def direct_logit_attribute(self, sae: AbstractSparseAutoEncoder, model: LanguageModel):
+        assert isinstance(model, TransformerLensLanguageModel), (
+            "DirectLogitAttributor only supports TransformerLensLanguageModel as the model backend"
+        )
+        model: HookedTransformer | None = model.model
+        assert model is not None, "Model ckpt must be loaded for direct logit attribution"
+
+        if isinstance(sae, CrossCoder):
+            residual = sae.W_D[-1]
+        elif isinstance(sae, SparseAutoEncoder):
+            residual = sae.W_D
+        else:
+            raise ValueError(f"Unsupported SAE type: {type(sae)}")
+
+        residual = einops.rearrange(residual, "batch d_model -> batch 1 d_model")  # Add a context dimension
+
+        if model.cfg.normalization_type is not None:
+            residual = model.ln_final(residual)  # [batch, pos, d_model]
+        logits = model.unembed(residual)  # [batch, pos, d_vocab]
+        logits = einops.rearrange(logits, "batch 1 d_vocab -> batch d_vocab")  # Remove the context dimension
+
+        # Select the top k tokens
+        top_k_logits, top_k_indices = torch.topk(logits, self.cfg.top_k, dim=-1)
+        top_k_tokens = [model.to_str_tokens(top_k_indices[i]) for i in range(sae.cfg.d_sae)]
+
+        assert top_k_logits.shape == top_k_indices.shape == (sae.cfg.d_sae, self.cfg.top_k), (
+            f"Top k logits and indices should have shape (d_sae, top_k), but got {top_k_logits.shape} and {top_k_indices.shape}"
+        )
+        assert (len(top_k_tokens), len(top_k_tokens[0])) == (sae.cfg.d_sae, self.cfg.top_k), (
+            f"Top k tokens should have shape (d_sae, top_k), but got {len(top_k_tokens)} and {len(top_k_tokens[0])}"
+        )
+
+        # Select the bottom k tokens
+        bottom_k_logits, bottom_k_indices = torch.topk(logits, self.cfg.top_k, dim=-1, largest=False)
+        bottom_k_tokens = [model.to_str_tokens(bottom_k_indices[i]) for i in range(sae.cfg.d_sae)]
+
+        assert bottom_k_logits.shape == bottom_k_indices.shape == (sae.cfg.d_sae, self.cfg.top_k), (
+            f"Bottom k logits and indices should have shape (d_sae, top_k), but got {bottom_k_logits.shape} and {bottom_k_indices.shape}"
+        )
+        assert (len(bottom_k_tokens), len(bottom_k_tokens[0])) == (sae.cfg.d_sae, self.cfg.top_k), (
+            f"Bottom k tokens should have shape (d_sae, top_k), but got {len(bottom_k_tokens)} and {len(bottom_k_tokens[0])}"
+        )
+
+        result = [
+            {
+                "top_positive": [
+                    {"token": token, "logit": logit} for token, logit in zip(top_k_tokens[i], top_k_logits[i].tolist())
+                ],
+                "top_negative": [
+                    {"token": token, "logit": logit}
+                    for token, logit in zip(bottom_k_tokens[i], bottom_k_logits[i].tolist())
+                ],
+            }
+            for i in range(sae.cfg.d_sae)
+        ]
+        return result
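
Note: the attributor projects each feature's decoder direction through the model's final LayerNorm and unembedding, so results[i] lists the tokens whose logits feature i most promotes and suppresses. Below is a minimal usage sketch, assuming sae and model are already loaded (the runner in src/lm_saes/runners/analyze.py below shows the full loading path); the helper name top_tokens_for_feature is hypothetical, not part of this commit:

from lm_saes import DirectLogitAttributorConfig
from lm_saes.abstract_sae import AbstractSparseAutoEncoder
from lm_saes.analysis.direct_logit_attributor import DirectLogitAttributor
from lm_saes.backend.language_model import TransformerLensLanguageModel


def top_tokens_for_feature(
    sae: AbstractSparseAutoEncoder,
    model: TransformerLensLanguageModel,
    feature_idx: int,
    k: int = 10,
) -> list[dict]:
    """Return the k tokens whose logits feature `feature_idx` pushes up the most."""
    attributor = DirectLogitAttributor(DirectLogitAttributorConfig(top_k=k))
    results = attributor.direct_logit_attribute(sae, model)
    # Each entry looks like {"token": " the", "logit": 3.21}.
    return results[feature_idx]["top_positive"]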

src/lm_saes/analysis/features_to_logits.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

src/lm_saes/config.py

Lines changed: 5 additions & 0 deletions
@@ -431,6 +431,11 @@ class FeatureAnalyzerConfig(BaseConfig):
     """
 
 
+class DirectLogitAttributorConfig(BaseConfig):
+    top_k: int = 10
+    """ The number of top tokens to attribute to. """
+
+
 class WandbConfig(BaseConfig):
     wandb_project: str = "gpt2-sae-training"
     exp_name: Optional[str] = None
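
Note: top_k bounds both ends of the attribution; DirectLogitAttributor takes the top_k largest and the top_k smallest logits per feature. A minimal sketch:

from lm_saes import DirectLogitAttributorConfig

# top_k=20 yields 20 "top_positive" and 20 "top_negative" entries per feature.
cfg = DirectLogitAttributorConfig(top_k=20)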

src/lm_saes/database.py

Lines changed: 12 additions & 0 deletions
@@ -54,6 +54,7 @@ class FeatureRecord(BaseModel):
     sae_series: str
     index: int
     analyses: list[FeatureAnalysis] = []
+    logits: Optional[dict[str, list[dict[str, Any]]]] = None
     interpretation: Optional[dict[str, Any]] = None
     metric: Optional[dict[str, float]] = None
 
@@ -402,6 +403,17 @@ def update_feature(self, sae_name: str, feature_index: int, update_data: dict, s
 
         return result
 
+    def update_features(self, sae_name: str, sae_series: str, update_data: list[dict], start_idx: int = 0):
+        operations = []
+        for i, feature_update in enumerate(update_data):
+            update_operation = pymongo.UpdateOne(
+                {"sae_name": sae_name, "sae_series": sae_series, "index": start_idx + i},
+                {"$set": feature_update},
+            )
+            operations.append(update_operation)
+        if operations:
+            self.feature_collection.bulk_write(operations)
+
     def add_bookmark(
         self,
         sae_name: str,
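
Note: update_features pairs update_data[i] with feature index start_idx + i and flushes all updates in a single bulk_write. A hedged sketch of persisting attribution results; the identifiers are placeholders and the MongoDBConfig connection fields depend on your deployment:

from lm_saes.config import MongoDBConfig
from lm_saes.database import MongoClient

client = MongoClient(MongoDBConfig())  # placeholder config; set real connection fields

updates = [
    {
        "logits": {
            "top_positive": [{"token": " the", "logit": 3.2}],
            "top_negative": [{"token": " not", "logit": -2.9}],
        }
    },
]
# updates[i] is $set on the document matching
# {"sae_name": "my-sae", "sae_series": "default", "index": start_idx + i}.
client.update_features(sae_name="my-sae", sae_series="default", update_data=updates, start_idx=0)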

src/lm_saes/runners/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -3,8 +3,10 @@
 from .analyze import (
     AnalyzeCrossCoderSettings,
     AnalyzeSAESettings,
+    DirectLogitAttributeSettings,
     analyze_crosscoder,
     analyze_sae,
+    direct_logit_attribute,
 )
 from .autointerp import AutoInterpSettings, auto_interp
 from .eval import (
@@ -33,6 +35,8 @@
 from .utils import load_config
 
 __all__ = [
+    "DirectLogitAttributeSettings",
+    "direct_logit_attribute",
     "GenerateActivationsSettings",
     "generate_activations",
     "CheckActivationConsistencySettings",

src/lm_saes/runners/analyze.py

Lines changed: 106 additions & 0 deletions
@@ -7,19 +7,26 @@
 from torch.distributed.device_mesh import init_device_mesh
 
 from lm_saes.activation.factory import ActivationFactory
+from lm_saes.analysis.direct_logit_attributor import DirectLogitAttributor
 from lm_saes.analysis.feature_analyzer import FeatureAnalyzer
+from lm_saes.backend.language_model import TransformerLensLanguageModel
 from lm_saes.config import (
     ActivationFactoryConfig,
     BaseSAEConfig,
     CrossCoderConfig,
+    DirectLogitAttributorConfig,
     FeatureAnalyzerConfig,
     LanguageModelConfig,
     MongoDBConfig,
+    SAEConfig,
 )
 from lm_saes.crosscoder import CrossCoder
 from lm_saes.database import MongoClient
+from lm_saes.resource_loaders import load_model
+from lm_saes.runners.utils import load_config
 from lm_saes.sae import SparseAutoEncoder
 from lm_saes.utils.logging import get_distributed_logger, setup_logging
+from lm_saes.utils.misc import is_master
 
 logger = get_distributed_logger("runners.analyze")
 
@@ -199,3 +206,102 @@ def analyze_crosscoder(settings: AnalyzeCrossCoderSettings) -> None:
     )
 
     logger.info("CrossCoder analysis completed successfully")
+
+
+class DirectLogitAttributeSettings(BaseSettings):
+    """Settings for running direct logit attribution on an SAE model."""
+
+    sae: BaseSAEConfig
+    """Configuration for the SAE model architecture and parameters"""
+
+    sae_name: str
+    """Name of the SAE model. Used as an identifier for the SAE model in the database."""
+
+    sae_series: str
+    """Series of the SAE model. Used as an identifier for the SAE model in the database."""
+
+    model: Optional[LanguageModelConfig] = None
+    """Configuration for the language model."""
+
+    model_name: str
+    """Name of the language model."""
+
+    direct_logit_attributor: DirectLogitAttributorConfig
+    """Configuration for the direct logit attributor."""
+
+    mongo: MongoDBConfig
+    """Configuration for the MongoDB database."""
+
+    device_type: str = "cuda"
+    """Device type to use ('cuda' or 'cpu')"""
+
+    # model_parallel_size: int = 1
+    # """Size of model parallel (tensor parallel) mesh"""
+
+    # data_parallel_size: int = 1
+    # """Size of data parallel mesh"""
+
+    # head_parallel_size: int = 1
+    # """Size of head parallel mesh"""
+
+
+@torch.no_grad()
+def direct_logit_attribute(settings: DirectLogitAttributeSettings) -> None:
+    """Run direct logit attribution on an SAE model.
+
+    Args:
+        settings: Configuration settings for DirectLogitAttributor
+    """
+    # Set up logging
+    setup_logging(level="INFO")
+
+    # device_mesh = (
+    #     init_device_mesh(
+    #         device_type=settings.device_type,
+    #         mesh_shape=(settings.head_parallel_size, settings.data_parallel_size, settings.model_parallel_size),
+    #         mesh_dim_names=("head", "data", "model"),
+    #     )
+    #     if settings.head_parallel_size > 1 or settings.data_parallel_size > 1 or settings.model_parallel_size > 1
+    #     else None
+    # )
+
+    mongo_client = MongoClient(settings.mongo)
+    logger.info("MongoDB client initialized")
+
+    logger.info("Loading SAE model")
+    if isinstance(settings.sae, CrossCoderConfig):
+        sae = CrossCoder.from_config(settings.sae)
+    elif isinstance(settings.sae, SAEConfig):
+        sae = SparseAutoEncoder.from_config(settings.sae)
+    else:
+        raise ValueError(f"Unsupported SAE config type: {type(settings.sae)}")
+
+    # Load configurations
+    model_cfg = load_config(
+        config=settings.model,
+        name=settings.model_name,
+        mongo_client=mongo_client,
+        config_type="model",
+        required=True,
+    )
+    model_cfg.device = settings.device_type
+    model_cfg.dtype = sae.cfg.dtype
+
+    model = load_model(model_cfg)
+    assert isinstance(model, TransformerLensLanguageModel), (
+        "DirectLogitAttributor only supports TransformerLensLanguageModel as the model backend"
+    )
+
+    logger.info("Direct logit attribution")
+    direct_logit_attributor = DirectLogitAttributor(settings.direct_logit_attributor)
+    results = direct_logit_attributor.direct_logit_attribute(sae, model)
+
+    if is_master():
+        logger.info("Direct logit attribution completed, saving results to MongoDB")
+        mongo_client.update_features(
+            sae_name=settings.sae_name,
+            sae_series=settings.sae_series,
+            update_data=[{"logits": result} for result in results],
+            start_idx=0,
+        )
+    logger.info("Direct logit attribution completed successfully")
