Commit cd0f71e

feat(autointerp): refactor to async & support lorsa
* feat(autointerp): better parallelization with async
* feat(database): show progress for database operations (add analysis & update feature)
* misc(ruff): fix ruff & typecheck errors
* feat(autointerp): update UI to support autointerp without verification
* fix(format): fix pyright issues
* fix(misc): remove try-except logic for progress measurement in autointerp
* feat(autointerp): support max suppressing logits in autointerp
* feat(autointerp): improve autointerp prompts and support lorsa autointerp with z pattern
* fix(misc): ruff fixes for autointerp
1 parent 76956f7 commit cd0f71e
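The headline change is that explanation and scoring now run through asyncio, so many features can be interpreted concurrently instead of one LLM round trip at a time. The event-loop code itself is not in the diffs shown below; as a loose sketch of the general pattern only (the names `interpret_feature`, `interpret_all`, and `MAX_CONCURRENCY` are hypothetical, not from this commit), bounded async fan-out looks like:

    import asyncio

    MAX_CONCURRENCY = 8  # hypothetical cap on in-flight LLM requests

    async def interpret_feature(feature_id: int) -> str:
        # Placeholder for one explain-then-score round trip to the LLM.
        await asyncio.sleep(0.1)
        return f"explanation for feature {feature_id}"

    async def interpret_all(feature_ids: list[int]) -> list[str]:
        semaphore = asyncio.Semaphore(MAX_CONCURRENCY)

        async def bounded(fid: int) -> str:
            async with semaphore:
                return await interpret_feature(fid)

        # gather preserves input order while letting requests overlap
        return await asyncio.gather(*(bounded(fid) for fid in feature_ids))

    print(asyncio.run(interpret_all(list(range(4)))))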

15 files changed: 1477 additions & 1371 deletions

src/lm_saes/analysis/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -1,13 +1,14 @@
-from .direct_logit_attributor import DirectLogitAttributor
-from .feature_analyzer import FeatureAnalyzer
-from .feature_interpreter import (
+from lm_saes.analysis.autointerp import (
     AutoInterpConfig,
     ExplainerType,
     FeatureInterpreter,
     ScorerType,
     TokenizedSample,
 )

+from .direct_logit_attributor import DirectLogitAttributor
+from .feature_analyzer import FeatureAnalyzer
+
 __all__ = [
     "FeatureAnalyzer",
     "FeatureInterpreter",
src/lm_saes/analysis/autointerp/__init__.py (new file)

Lines changed: 43 additions & 0 deletions

"""Prompt builders for auto-interpretation of SAE features.

This package contains modules for generating prompts used in the auto-interpretation
process, organized by purpose:
- explanation_prompts: Prompts for generating feature explanations
- evaluation_prompts: Prompts for evaluating feature explanations
"""

from .autointerp_base import (
    AutoInterpConfig,
    ExplainerType,
    ScorerType,
    Segment,
    TokenizedSample,
    process_token,
)
from .evaluation_prompts import (
    generate_detection_prompt,
    generate_fuzzing_prompt,
)
from .explanation_prompts import (
    generate_explanation_prompt,
    generate_explanation_prompt_neuronpedia,
)
from .feature_interpreter import (
    FeatureInterpreter,
)

__all__ = [
    "generate_explanation_prompt",
    "generate_explanation_prompt_neuronpedia",
    "generate_detection_prompt",
    "generate_fuzzing_prompt",
    "FeatureInterpreter",
    "AutoInterpConfig",
    "ExplainerType",
    "ScorerType",
    "Segment",
    "TokenizedSample",
    "process_token",
]
src/lm_saes/analysis/autointerp/autointerp_base.py (new file)

Lines changed: 242 additions & 0 deletions

"""Utility classes and functions for auto-interpretation of SAE features.

This module contains shared utilities used across the auto-interpretation system,
including configuration, data structures, and helper functions.
"""

from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional

import torch
from pydantic import Field

from lm_saes.config import BaseConfig
from lm_saes.utils.logging import get_logger

logger = get_logger("analysis.autointerp_utils")


def process_token(token: str) -> str:
    """Process a token string by replacing special characters.

    Args:
        token: The token string to process

    Returns:
        Processed token string with special characters replaced
    """
    return token.replace("\n", "⏎").replace("\t", "→").replace("\r", "↵")
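`process_token` swaps whitespace control characters for visible glyphs so that a highlighted token never breaks the one-example-per-line layout of the prompts below; for instance:

    >>> process_token("foo\nbar\tbaz")
    'foo⏎bar→baz'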
class ExplainerType(str, Enum):
    """Types of LLM explainers supported."""

    OPENAI = "openai"
    NEURONPEDIA = "neuronpedia"


class ScorerType(str, Enum):
    """Types of explanation scoring methods."""

    DETECTION = "detection"
    FUZZING = "fuzzing"
    GENERATION = "generation"
    SIMULATION = "simulation"


class AutoInterpConfig(BaseConfig):
    """Configuration for automatic interpretation of SAE features."""

    # LLM settings
    explainer_type: ExplainerType = ExplainerType.OPENAI
    openai_api_key: Optional[str] = None
    openai_model: str = "gpt-3.5-turbo"
    openai_base_url: Optional[str] = None
    openai_proxy: Optional[str] = None

    # Activation retrieval settings
    n_activating_examples: int = 7
    n_non_activating_examples: int = 20
    activation_threshold: float = 0.7  # Threshold relative to max activation for highlighting tokens
    max_length: int = 50

    # Scoring settings
    scorer_type: list[ScorerType] = Field(default_factory=lambda: [ScorerType.DETECTION, ScorerType.FUZZING])

    # Detection settings
    detection_n_examples: int = 5  # Number of examples to show for detection

    # Fuzzing settings
    fuzzing_n_examples: int = 5  # Number of examples to use for fuzzing
    fuzzing_decile_correct: int = 5  # Number of correctly marked examples per decile
    fuzzing_decile_incorrect: int = 2  # Number of incorrectly marked examples per decile

    # Prompting settings
    include_cot: bool = True  # Whether to use chain-of-thought prompting
    overwrite_existing: bool = False  # Whether to overwrite existing interpretations
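`AutoInterpConfig` extends the project's pydantic-based `BaseConfig` (note the `Field` default factory above), so a run can be configured with plain keyword overrides. A minimal sketch; the field values here are illustrative, not recommendations:

    cfg = AutoInterpConfig(
        explainer_type=ExplainerType.OPENAI,
        openai_model="gpt-4o-mini",  # illustrative; any OpenAI-compatible model name
        scorer_type=[ScorerType.DETECTION],  # run detection scoring only
        activation_threshold=0.6,  # highlight tokens above 60% of the max activation
    )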
@dataclass
class Segment:
    """A segment of text with its activation value."""

    text: str
    """The text of the segment."""

    activation: float
    """The activation value of the segment."""

    def display(self, abs_threshold: float) -> str:
        """Display the segment, wrapped in << >> delimiters if it activates above the threshold."""
        if self.activation > abs_threshold:
            return f"<<{self.text}>>"
        else:
            return self.text

    def display_max(self, abs_threshold: float) -> str:
        """Display the segment text (with a trailing newline) if it exceeds the threshold."""
        if self.activation > abs_threshold:
            return f"{self.text}\n"
        else:
            return ""
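`display` takes an absolute threshold (callers below pass `threshold * max_activation`) and produces the << >> markers that the evaluation prompts later refer to; for instance:

    >>> Segment("cat", 3.0).display(abs_threshold=2.0)
    '<<cat>>'
    >>> Segment("dog", 1.0).display(abs_threshold=2.0)
    'dog'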
@dataclass
class ZPatternSegment:
    """Data for the z pattern of a single token."""

    contributing_indices: list[int]
    """The indices of the contributing tokens in the sequence."""

    contributions: list[float]
    """The contributions of the contributing tokens to the activation of the token."""

    max_contribution: float
    """The maximum contribution among the contributing tokens."""


@dataclass
class TokenizedSample:
    """A tokenized sample with its activation pattern organized into segments."""

    segments: list[Segment]
    """List of segments, each containing start/end positions and activation values."""

    max_activation: float
    """Global maximum activation value."""

    z_pattern_data: dict[int, ZPatternSegment] | None = None
    """Optional per-token z pattern data, keyed by token index."""

    def display_highlighted(self, threshold: float = 0.7) -> str:
        """Get the text with activating segments highlighted with << >> delimiters.

        Args:
            threshold: Threshold relative to max activation for highlighting

        Returns:
            Text with activating segments highlighted
        """
        highlighted_text = "".join([seg.display(threshold * self.max_activation) for seg in self.segments])
        return highlighted_text

    def display_plain(self) -> str:
        """Get the text with all segments displayed."""
        return "".join([seg.text for seg in self.segments])

    def display_max(self, threshold: float = 0.7) -> str:
        """Get the text with max activating tokens and their context."""
        max_activation_text = ""
        hash_ = {}
        for i, seg in enumerate(self.segments):
            if seg.activation > threshold * self.max_activation:
                text = seg.text
                if text != "" and hash_.get(text, None) is None:
                    hash_[text] = 1
                    prev_text = "".join([self.segments[idx].text for idx in range(max(0, i - 3), i)])
                    if self.z_pattern_data is not None and i in self.z_pattern_data:
                        z_pattern_segment = self.z_pattern_data[i]
                        k_prev_tokens = [
                            f"({process_token(''.join([self.segments[idx].text for idx in range(max(0, j - 3), j)]))}) "
                            f"{process_token(self.segments[j].text)}"
                            for j, contribution in zip(
                                z_pattern_segment.contributing_indices, z_pattern_segment.contributions
                            )
                            if contribution > threshold * z_pattern_segment.max_contribution
                        ]
                        contributing_text = f"[{'; '.join(k_prev_tokens)}] => "
                        max_activation_text += contributing_text
                    max_activation_text += f"({process_token(prev_text)}) {process_token(text)}\n"
        return max_activation_text

    def display_next(self, threshold: float = 0.7) -> str:
        """Get the token immediately after the max activating token."""
        next_activation_text = ""
        hash_ = {}
        flag = False
        for seg in self.segments:
            if flag:
                text = seg.text
                if text != "" and hash_.get(text, None) is None:
                    hash_[text] = 1
                    next_activation_text = process_token(text) + "\n"
            flag = seg.activation > threshold * self.max_activation
        return next_activation_text

    def add_z_pattern_data(
        self,
        z_pattern_indices: torch.Tensor,
        z_pattern_values: torch.Tensor,
        origins: list[dict[str, Any]],
    ):
        """Attach per-token z pattern data from sparse (activating, contributing) index pairs."""
        self.z_pattern_data = {}
        activating_indices = z_pattern_indices[0].unique_consecutive()
        for i in activating_indices:
            if origins[i] is not None:
                contributing_indices_mask = z_pattern_indices[0] == i
                self.z_pattern_data[i.item()] = ZPatternSegment(
                    contributing_indices=z_pattern_indices[1, contributing_indices_mask].tolist(),
                    contributions=z_pattern_values[contributing_indices_mask].tolist(),
                    max_contribution=z_pattern_values[contributing_indices_mask].max().item(),
                )

    def has_z_pattern_data(self):
        return self.z_pattern_data is not None

    @staticmethod
    def construct(
        text: str,
        activations: torch.Tensor,
        origins: list[dict[str, Any]],
        max_activation: float,
    ) -> "TokenizedSample":
        """Construct a TokenizedSample from text, activations, and origins.

        Args:
            text: The full text string
            activations: Tensor of activation values
            origins: List of origin dictionaries with position information
            max_activation: Maximum activation value

        Returns:
            A TokenizedSample instance
        """
        positions: set[int] = set()
        for origin in origins:
            if origin and origin["key"] == "text":
                assert "range" in origin, f"Origin {origin} does not have a range"
                positions.add(origin["range"][0])
                positions.add(origin["range"][1])

        sorted_positions = sorted(positions)

        segments = []
        for i in range(len(sorted_positions) - 1):
            start, end = sorted_positions[i], sorted_positions[i + 1]
            try:
                segment_activation = max(
                    act
                    for origin, act in zip(origins, activations)
                    if origin and origin["key"] == "text" and origin["range"][0] >= start and origin["range"][1] <= end
                )
            except Exception as e:
                logger.error(f"Error processing segment: start={start}, end={end}, segment={text[start:end]}. Error: {e}")
                continue
            segments.append(Segment(text[start:end], segment_activation.item()))

        return TokenizedSample(segments, max_activation)
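Putting the pieces together, a minimal hand-built sample (bypassing `construct`, whose `origins` format comes from the upstream tokenization pipeline) shows the main views:

    sample = TokenizedSample(
        segments=[Segment("The ", 0.0), Segment("cat", 5.0), Segment(" sat", 1.0)],
        max_activation=5.0,
    )

    sample.display_plain()        # 'The cat sat'
    sample.display_highlighted()  # 'The <<cat>> sat'  (only 5.0 exceeds 0.7 * 5.0)
    sample.display_max()          # '(The ) cat\n'     (max token with up to 3 segments of left context)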
src/lm_saes/analysis/autointerp/evaluation_prompts.py (new file)

Lines changed: 65 additions & 0 deletions

"""Prompt builders for evaluating feature explanations.

This module contains functions for generating prompts used to evaluate SAE feature
explanations, including detection and fuzzing evaluation methods.
"""

from typing import Any

from lm_saes.analysis.autointerp.autointerp_base import AutoInterpConfig, TokenizedSample


def generate_detection_prompt(
    cfg: AutoInterpConfig,
    explanation: dict[str, Any],
    examples: list[TokenizedSample],
) -> tuple[str, str]:
    """Generate a prompt for detection evaluation.

    Args:
        cfg: Auto-interpretation configuration
        explanation: The explanation to evaluate
        examples: List of examples (mix of activating and non-activating)

    Returns:
        Tuple of (system_prompt, user_prompt) strings
    """
    system_prompt = f"""We're studying features in a neural network. Each feature activates on some particular word/words/substring/concept in a short document. You will be given a short explanation of what this feature activates for, and then be shown {len(examples)} example sequences in random order. Return a boolean for each example indicating whether you think the feature should activate at least once, on ANY of the words or substrings in the document: true if it does, false if it doesn't. Try not to be overly specific in your interpretation of the explanation."""
    system_prompt += """
Your output should be a JSON object with the following fields: `steps` and `evaluation_results`. `steps` should be an array of strings, each representing a step in the chain-of-thought process within 50 words. `evaluation_results` should be an array of booleans, each representing whether the feature should activate on the corresponding example.
"""
    user_prompt = f"Here is the explanation:\n\n{explanation['final_explanation']}\n\nHere are the examples:\n\n"

    for i, example in enumerate(examples, 1):
        user_prompt += f"Example {i}: {example.display_plain()}\n"

    return system_prompt, user_prompt


def generate_fuzzing_prompt(
    cfg: AutoInterpConfig,
    explanation: dict[str, Any],
    examples: list[tuple[TokenizedSample, bool]],  # (sample, is_correctly_marked)
) -> tuple[str, str]:
    """Generate a prompt for fuzzing evaluation.

    Args:
        cfg: Auto-interpretation configuration
        explanation: The explanation to evaluate
        examples: List of tuples (example, is_correctly_marked)

    Returns:
        Tuple of (system_prompt, user_prompt) strings
    """
    system_prompt = f"""We're studying features in a neural network. Each feature activates on some particular word/words/substring/concept in a short document. You will be given a short explanation of what this feature activates for, and then be shown {len(examples)} example sequences in random order. In each example, text segments highlighted with << >> are presented as activating the feature as described in the explanation. Return a boolean for each example indicating whether you think the highlighted parts CORRECTLY correspond to the explanation: true if they do, false if they don't. Try not to be overly specific in your interpretation of the explanation."""
    system_prompt += """
Your output should be a JSON object with the following fields: `steps` and `evaluation_results`. `steps` should be an array of strings, each representing a step in the chain-of-thought process within 50 words. `evaluation_results` should be an array of booleans, each representing whether the feature should activate on the corresponding example.
"""
    user_prompt = f"Here is the explanation:\n\n{explanation['final_explanation']}\n\nHere are the examples:\n\n"

    for i, (example, _) in enumerate(examples, 1):
        highlighted = example.display_highlighted(cfg.activation_threshold)
        user_prompt += f"Example {i}: {highlighted}\n"

    return system_prompt, user_prompt
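A minimal sketch of driving the detection builder, reusing the `cfg` and `sample` objects sketched above (`final_explanation` is the key both builders read from the explanation dict):

    explanation = {"final_explanation": "Fires on the animal noun 'cat'."}
    examples = [sample]  # in practice a shuffled mix of activating and non-activating samples

    system_prompt, user_prompt = generate_detection_prompt(cfg, explanation, examples)
    # The system prompt asks the judge for a JSON object:
    #   {"steps": [...], "evaluation_results": [true, ...]}  # one boolean per example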
