| 1 | +"""Utility classes and functions for auto-interpretation of SAE features. |
| 2 | +
|
| 3 | +This module contains shared utilities used across the auto-interpretation system, |
| 4 | +including configuration, data structures, and helper functions. |
| 5 | +""" |
| 6 | + |
| 7 | +from dataclasses import dataclass |
| 8 | +from enum import Enum |
| 9 | +from typing import Any, Optional |
| 10 | + |
| 11 | +import torch |
| 12 | +from pydantic import Field |
| 13 | + |
| 14 | +from lm_saes.config import BaseConfig |
| 15 | +from lm_saes.utils.logging import get_logger |
| 16 | + |
| 17 | +logger = get_logger("analysis.autointerp_utils") |
| 18 | + |
| 19 | + |
def process_token(token: str) -> str:
    """Process a token string by replacing special characters.

    Args:
        token: The token string to process

    Returns:
        Processed token string with special characters replaced
    """
    return token.replace("\n", "⏎").replace("\t", "→").replace("\r", "↵")

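# A minimal usage sketch (hypothetical token value, not from a real tokenizer):
#
#     process_token("def f():\n\treturn")  # -> "def f():⏎→return"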

class ExplainerType(str, Enum):
    """Types of LLM explainers supported."""

    OPENAI = "openai"
    NEURONPEDIA = "neuronpedia"


class ScorerType(str, Enum):
    """Types of explanation scoring methods."""

    DETECTION = "detection"
    FUZZING = "fuzzing"
    GENERATION = "generation"
    SIMULATION = "simulation"


class AutoInterpConfig(BaseConfig):
    """Configuration for automatic interpretation of SAE features."""

    # LLM settings
    explainer_type: ExplainerType = ExplainerType.OPENAI
    openai_api_key: Optional[str] = None
    openai_model: str = "gpt-3.5-turbo"
    openai_base_url: Optional[str] = None
    openai_proxy: Optional[str] = None

    # Activation retrieval settings
    n_activating_examples: int = 7
    n_non_activating_examples: int = 20
    activation_threshold: float = 0.7  # Threshold relative to max activation for highlighting tokens
    max_length: int = 50

    # Scoring settings
    scorer_type: list[ScorerType] = Field(default_factory=lambda: [ScorerType.DETECTION, ScorerType.FUZZING])

    # Detection settings
    detection_n_examples: int = 5  # Number of examples to show for detection

    # Fuzzing settings
    fuzzing_n_examples: int = 5  # Number of examples to use for fuzzing
    fuzzing_decile_correct: int = 5  # Number of correctly marked examples per decile
    fuzzing_decile_incorrect: int = 2  # Number of incorrectly marked examples per decile

    # Prompting settings
    include_cot: bool = True  # Whether to use chain-of-thought prompting
    overwrite_existing: bool = False  # Whether to overwrite existing interpretations

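# A minimal construction sketch, assuming BaseConfig follows the usual pydantic
# keyword-argument initialization (the values below are placeholders):
#
#     cfg = AutoInterpConfig(
#         openai_api_key="sk-...",
#         scorer_type=[ScorerType.DETECTION],
#     )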

@dataclass
class Segment:
    """A segment of text with its activation value."""

    text: str
    """The text of the segment."""

    activation: float
    """The activation value of the segment."""

    def display(self, abs_threshold: float) -> str:
        """Display the segment, wrapped in << >> delimiters if its activation exceeds the threshold."""
        if self.activation > abs_threshold:
            return f"<<{self.text}>>"
        else:
            return self.text

    def display_max(self, abs_threshold: float) -> str:
        """Display the segment text (followed by a newline) if its activation exceeds the threshold."""
        if self.activation > abs_threshold:
            return f"{self.text}\n"
        else:
            return ""

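# Rendering examples with illustrative values:
#
#     Segment("cat", 0.9).display(abs_threshold=0.5)  # -> "<<cat>>"
#     Segment("sat", 0.2).display(abs_threshold=0.5)  # -> "sat"
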

@dataclass
class ZPatternSegment:
    """Data for the z pattern of a single token."""

    contributing_indices: list[int]
    """The indices of the contributing tokens in the sequence."""

    contributions: list[float]
    """The contributions of the contributing tokens to the activation of the token."""

    max_contribution: float
    """The maximum contribution among the contributing tokens."""


@dataclass
class TokenizedSample:
    """A tokenized sample with its activation pattern organized into segments."""

    segments: list[Segment]
    """List of segments, each containing a text span and its activation value."""

    max_activation: float
    """Global maximum activation value."""

    z_pattern_data: dict[int, ZPatternSegment] | None = None
    """Optional z-pattern data for activating tokens, keyed by token index."""

    def display_highlighted(self, threshold: float = 0.7) -> str:
        """Get the text with activating segments highlighted with << >> delimiters.

        Args:
            threshold: Threshold relative to max activation for highlighting

        Returns:
            Text with activating segments highlighted
        """
        return "".join(seg.display(threshold * self.max_activation) for seg in self.segments)
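    # For instance, with max_activation = 1.0 and threshold = 0.7, only segments whose
    # activation exceeds 0.7 are wrapped: "The <<cat>> sat" (illustrative values).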

    def display_plain(self) -> str:
        """Get the text with all segments displayed."""
        return "".join(seg.text for seg in self.segments)

    def display_max(self, threshold: float = 0.7) -> str:
        """Get the max activating tokens, each with its preceding context."""
        max_activation_text = ""
        seen: set[str] = set()
        for i, seg in enumerate(self.segments):
            if seg.activation <= threshold * self.max_activation:
                continue
            text = seg.text
            if text == "" or text in seen:
                continue
            seen.add(text)
            # Up to three preceding segments provide local context for the token.
            prev_text = "".join(self.segments[idx].text for idx in range(max(0, i - 3), i))
            if self.z_pattern_data is not None and i in self.z_pattern_data:
                # Prepend the tokens contributing most to this token's activation,
                # each rendered with its own three-segment context.
                z_pattern_segment = self.z_pattern_data[i]
                k_prev_tokens = [
                    f"({process_token(''.join(self.segments[idx].text for idx in range(max(0, j - 3), j)))}) "
                    f"{process_token(self.segments[j].text)}"
                    for j, contribution in zip(z_pattern_segment.contributing_indices, z_pattern_segment.contributions)
                    if contribution > threshold * z_pattern_segment.max_contribution
                ]
                max_activation_text += f"[{'; '.join(k_prev_tokens)}] => "
            max_activation_text += f"({process_token(prev_text)}) {process_token(text)}\n"
        return max_activation_text
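    # Illustrative output shape (hypothetical tokens): one line per max activating
    # token, "(preceding context) token", optionally prefixed by its z-pattern
    # contributors:
    #
    #     [(The big) cat; (a small) dog] => (sat on the) mat
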

    def display_next(self, threshold: float = 0.7) -> str:
        """Get the tokens that immediately follow max activating tokens."""
        next_activation_text = ""
        seen: set[str] = set()
        prev_was_max = False
        for seg in self.segments:
            if prev_was_max:
                text = seg.text
                if text != "" and text not in seen:
                    seen.add(text)
                    next_activation_text += process_token(text) + "\n"
            prev_was_max = seg.activation > threshold * self.max_activation
        return next_activation_text

    def add_z_pattern_data(
        self,
        z_pattern_indices: torch.Tensor,
        z_pattern_values: torch.Tensor,
        origins: list[dict[str, Any]],
    ):
        """Attach z-pattern data mapping each activating token to its contributing tokens.

        Args:
            z_pattern_indices: Index tensor of shape (2, k): row 0 holds activating
                token positions, row 1 the positions of their contributing tokens.
            z_pattern_values: Tensor of shape (k,) with the contribution values.
            origins: List of origin dictionaries; tokens whose origin is None are skipped.
        """
        self.z_pattern_data = {}
        activating_indices = z_pattern_indices[0].unique_consecutive()
        for i in activating_indices:
            if origins[i] is not None:
                contributing_indices_mask = z_pattern_indices[0] == i
                self.z_pattern_data[i.item()] = ZPatternSegment(
                    contributing_indices=z_pattern_indices[1, contributing_indices_mask].tolist(),
                    contributions=z_pattern_values[contributing_indices_mask].tolist(),
                    max_contribution=z_pattern_values[contributing_indices_mask].max().item(),
                )

    def has_z_pattern_data(self) -> bool:
        """Whether z-pattern data has been attached to this sample."""
        return self.z_pattern_data is not None

    @staticmethod
    def construct(
        text: str,
        activations: torch.Tensor,
        origins: list[dict[str, Any]],
        max_activation: float,
    ) -> "TokenizedSample":
        """Construct a TokenizedSample from text, activations, and origins.

        Args:
            text: The full text string
            activations: Tensor of activation values
            origins: List of origin dictionaries with position information
            max_activation: Maximum activation value

        Returns:
            A TokenizedSample instance
        """
        # Collect the boundaries of every token range that refers to the text field.
        positions: set[int] = set()
        for origin in origins:
            if origin and origin["key"] == "text":
                assert "range" in origin, f"Origin {origin} does not have a range"
                positions.add(origin["range"][0])
                positions.add(origin["range"][1])

        sorted_positions = sorted(positions)

        # Adjacent boundaries delimit a segment; its activation is the maximum over
        # all tokens fully contained in that span.
        segments = []
        for i in range(len(sorted_positions) - 1):
            start, end = sorted_positions[i], sorted_positions[i + 1]
            try:
                segment_activation = max(
                    act
                    for origin, act in zip(origins, activations)
                    if origin and origin["key"] == "text" and origin["range"][0] >= start and origin["range"][1] <= end
                )
            except Exception as e:
                # max() raises ValueError if no token falls inside the span.
                logger.error(f"Error processing segment: start={start}, end={end}, segment={text[start:end]!r}. Error: {e}")
                continue
            segments.append(Segment(text[start:end], segment_activation.item()))

        return TokenizedSample(segments, max_activation)
