diff --git a/src/lm_saes/analysis/__init__.py b/src/lm_saes/analysis/__init__.py
index 1b9bb2fc..eba3936a 100644
--- a/src/lm_saes/analysis/__init__.py
+++ b/src/lm_saes/analysis/__init__.py
@@ -1,6 +1,4 @@
-from .direct_logit_attributor import DirectLogitAttributor
-from .feature_analyzer import FeatureAnalyzer
-from .feature_interpreter import (
+from lm_saes.analysis.autointerp import (
     AutoInterpConfig,
     ExplainerType,
     FeatureInterpreter,
@@ -8,6 +6,9 @@
     TokenizedSample,
 )
+from .direct_logit_attributor import DirectLogitAttributor
+from .feature_analyzer import FeatureAnalyzer
+
 
 __all__ = [
     "FeatureAnalyzer",
     "FeatureInterpreter",
diff --git a/src/lm_saes/analysis/autointerp/__init__.py b/src/lm_saes/analysis/autointerp/__init__.py
new file mode 100644
index 00000000..e5dab36a
--- /dev/null
+++ b/src/lm_saes/analysis/autointerp/__init__.py
@@ -0,0 +1,43 @@
+"""Auto-interpretation of SAE features: prompt builders, shared utilities, and the interpreter.
+
+Modules in this package, organized by purpose:
+- autointerp_base: Configuration, data structures, and helper functions
+- explanation_prompts: Prompts for generating feature explanations
+- evaluation_prompts: Prompts for evaluating feature explanations
+- feature_interpreter: Generates and scores explanations for SAE features
+"""
+
+from .autointerp_base import (
+    AutoInterpConfig,
+    ExplainerType,
+    ScorerType,
+    Segment,
+    TokenizedSample,
+    process_token,
+)
+from .evaluation_prompts import (
+    generate_detection_prompt,
+    generate_fuzzing_prompt,
+)
+from .explanation_prompts import (
+    generate_explanation_prompt,
+    generate_explanation_prompt_neuronpedia,
+)
+from .feature_interpreter import (
+    FeatureInterpreter,
+)
+
+__all__ = [
+    "generate_explanation_prompt",
+    "generate_explanation_prompt_neuronpedia",
+    "generate_detection_prompt",
+    "generate_fuzzing_prompt",
+    "FeatureInterpreter",
+    "AutoInterpConfig",
+    "ExplainerType",
+    "ScorerType",
+    "Segment",
+    "TokenizedSample",
+    "process_token",
+]
+
diff --git a/src/lm_saes/analysis/autointerp/autointerp_base.py b/src/lm_saes/analysis/autointerp/autointerp_base.py
new file mode 100644
index 00000000..6cb72624
--- /dev/null
+++ b/src/lm_saes/analysis/autointerp/autointerp_base.py
@@ -0,0 +1,242 @@
+"""Utility classes and functions for auto-interpretation of SAE features.
+
+This module contains shared utilities used across the auto-interpretation system,
+including configuration, data structures, and helper functions.
+"""
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Optional
+
+import torch
+from pydantic import Field
+
+from lm_saes.config import BaseConfig
+from lm_saes.utils.logging import get_logger
+
+logger = get_logger("analysis.autointerp_base")
+
+
+def process_token(token: str) -> str:
+    """Process a token string by replacing special characters.
+ + Args: + token: The token string to process + + Returns: + Processed token string with special characters replaced + """ + return token.replace("\n", "⏎").replace("\t", "→").replace("\r", "↵") + + +class ExplainerType(str, Enum): + """Types of LLM explainers supported.""" + + OPENAI = "openai" + NEURONPEDIA = "neuronpedia" + + +class ScorerType(str, Enum): + """Types of explanation scoring methods.""" + + DETECTION = "detection" + FUZZING = "fuzzing" + GENERATION = "generation" + SIMULATION = "simulation" + + +class AutoInterpConfig(BaseConfig): + """Configuration for automatic interpretation of SAE features.""" + + # LLM settings + explainer_type: ExplainerType = ExplainerType.OPENAI + openai_api_key: Optional[str] = None + openai_model: str = "gpt-3.5-turbo" + openai_base_url: Optional[str] = None + openai_proxy: Optional[str] = None + + # Activation retrieval settings + n_activating_examples: int = 7 + n_non_activating_examples: int = 20 + activation_threshold: float = 0.7 # Threshold relative to max activation for highlighting tokens + max_length: int = 50 + + # Scoring settings + scorer_type: list[ScorerType] = Field(default_factory=lambda: [ScorerType.DETECTION, ScorerType.FUZZING]) + + # Detection settings + detection_n_examples: int = 5 # Number of examples to show for detection + + # Fuzzing settings + fuzzing_n_examples: int = 5 # Number of examples to use for fuzzing + fuzzing_decile_correct: int = 5 # Number of correctly marked examples per decile + fuzzing_decile_incorrect: int = 2 # Number of incorrectly marked examples per decile + + # Prompting settings + include_cot: bool = True # Whether to use chain-of-thought prompting + overwrite_existing: bool = False # Whether to overwrite existing interpretations + + +@dataclass +class Segment: + """A segment of text with its activation value.""" + + text: str + """The text of the segment.""" + + activation: float + """The activation value of the segment.""" + + def display(self, abs_threshold: float) -> str: + """Display the segment as a string with whether it's highlighted.""" + if self.activation > abs_threshold: + return f"<<{self.text}>>" + else: + return self.text + + def display_max(self, abs_threshold: float) -> str: + """Display the segment text if it exceeds the threshold.""" + if self.activation > abs_threshold: + return f"{self.text}\n" + else: + return "" + +@dataclass +class ZPatternSegment: + """Data for a z pattern of a single token.""" + + contributing_indices: list[int] + """The indices of the contributing tokens in the sequence.""" + contributions: list[float] + """The contributions of the contributing tokens to the activation of the token.""" + max_contribution: float + """The maximum contribution of the contributing tokens to the activation of the token.""" + +@dataclass +class TokenizedSample: + """A tokenized sample with its activation pattern organized into segments.""" + + segments: list[Segment] + """List of segments, each containing start/end positions and activation values.""" + + max_activation: float + """Global maximum activation value.""" + + z_pattern_data: dict[int, ZPatternSegment] | None = None + + def display_highlighted(self, threshold: float = 0.7) -> str: + """Get the text with activating segments highlighted with << >> delimiters. 
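For reference, a minimal sketch (toy segments and activation values, not taken from any real feature) of how the helpers above behave:

from lm_saes.analysis.autointerp import Segment, TokenizedSample, process_token

# Whitespace control characters become visible glyphs so prompts stay on one line.
print(process_token("a\tb\nc"))  # -> "a→b⏎c"

sample = TokenizedSample(
    segments=[Segment("The ", 0.0), Segment("cat", 9.0), Segment(" sat", 1.0)],
    max_activation=10.0,
)
# Segments whose activation exceeds threshold * max_activation get << >> markers.
print(sample.display_highlighted(threshold=0.7))  # -> "The <<cat>> sat"
print(sample.display_plain())                     # -> "The cat sat"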
+ + Args: + threshold: Threshold relative to max activation for highlighting + + Returns: + Text with activating segments highlighted + """ + highlighted_text = "".join([seg.display(threshold * self.max_activation) for seg in self.segments]) + return highlighted_text + + def display_plain(self) -> str: + """Get the text with all segments displayed.""" + return "".join([seg.text for seg in self.segments]) + + def display_max(self, threshold: float = 0.7) -> str: + """Get the text with max activating tokens and their context.""" + max_activation_text = "" + hash_ = {} + for i, seg in enumerate(self.segments): + if seg.activation > threshold * self.max_activation: + text = seg.text + if text != "" and hash_.get(text, None) is None: + hash_[text] = 1 + prev_text = "".join([self.segments[idx].text for idx in range(max(0, i - 3), i)]) + if self.z_pattern_data is not None and i in self.z_pattern_data: + z_pattern_segment = self.z_pattern_data[i] + k_prev_tokens = [f"({process_token(''.join([self.segments[idx].text for idx in range(max(0, j - 3), j)]))}) {process_token(self.segments[j].text)}" + for j, contribution in zip(z_pattern_segment.contributing_indices, z_pattern_segment.contributions) + if contribution > threshold * z_pattern_segment.max_contribution] + contributing_text = f"[{'; '.join(k_prev_tokens)}] => " + max_activation_text += contributing_text + max_activation_text += f"({process_token(prev_text)}) {process_token(text)}\n" + return max_activation_text + + def display_next(self, threshold: float = 0.7) -> str: + """Get the token immediately after the max activating token.""" + next_activation_text = "" + hash_ = {} + Flag = False + for seg in self.segments: + if Flag: + text = seg.text + if text != "" and hash_.get(text, None) is None: + hash_[text] = 1 + next_activation_text = process_token(text) + "\n" + if seg.activation > threshold * self.max_activation: + Flag = True + else: + Flag = False + return next_activation_text + + def add_z_pattern_data( + self, + z_pattern_indices: torch.Tensor, + z_pattern_values: torch.Tensor, + origins: list[dict[str, Any]] + ): + self.z_pattern_data = {} + activating_indices = z_pattern_indices[0].unique_consecutive() + for i in activating_indices: + if origins[i] is not None: + contributing_indices_mask = z_pattern_indices[0] == i + self.z_pattern_data[i.item()] = ZPatternSegment( + contributing_indices=z_pattern_indices[1, contributing_indices_mask].tolist(), + contributions=z_pattern_values[contributing_indices_mask].tolist(), + max_contribution=z_pattern_values[contributing_indices_mask].max().item(), + ) + + def has_z_pattern_data(self): + return self.z_pattern_data is not None + + @staticmethod + def construct( + text: str, + activations: torch.Tensor, + origins: list[dict[str, Any]], + max_activation: float, + ) -> "TokenizedSample": + """Construct a TokenizedSample from text, activations, and origins. 
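The origins returned by the model's trace are expected to carry character ranges into the text; a small sketch of construct with hand-written origins (all values illustrative):

import torch

from lm_saes.analysis.autointerp import TokenizedSample

text = "The cat sat"
# One origin per model token, each pointing at its character span in `text`.
origins = [
    {"key": "text", "range": [0, 4]},   # "The "
    {"key": "text", "range": [4, 7]},   # "cat"
    {"key": "text", "range": [7, 11]},  # " sat"
]
activations = torch.tensor([0.0, 9.0, 1.0])
sample = TokenizedSample.construct(text, activations, origins, max_activation=10.0)
print(sample.display_highlighted(0.7))  # -> "The <<cat>> sat"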
+ + Args: + text: The full text string + activations: Tensor of activation values + origins: List of origin dictionaries with position information + max_activation: Maximum activation value + + Returns: + A TokenizedSample instance + """ + positions: set[int] = set() + for origin in origins: + if origin and origin["key"] == "text": + assert "range" in origin, f"Origin {origin} does not have a range" + positions.add(origin["range"][0]) + positions.add(origin["range"][1]) + + sorted_positions = sorted(positions) + + segments = [] + for i in range(len(sorted_positions) - 1): + start, end = sorted_positions[i], sorted_positions[i + 1] + try: + segment_activation = max( + act + for origin, act in zip(origins, activations) + if origin and origin["key"] == "text" and origin["range"][0] >= start and origin["range"][1] <= end + ) + except Exception as e: + logger.error(f"Error processing segment:\nstart={start}, end={end}, segment={text[start:end]}\n\n. Error: {e}") + continue + segments.append(Segment(text[start:end], segment_activation.item())) + + return TokenizedSample(segments, max_activation) + diff --git a/src/lm_saes/analysis/autointerp/evaluation_prompts.py b/src/lm_saes/analysis/autointerp/evaluation_prompts.py new file mode 100644 index 00000000..e715d840 --- /dev/null +++ b/src/lm_saes/analysis/autointerp/evaluation_prompts.py @@ -0,0 +1,65 @@ +"""Prompt builders for evaluating feature explanations. + +This module contains functions for generating prompts used to evaluate SAE feature +explanations, including detection and fuzzing evaluation methods. +""" + +from typing import Any + +from lm_saes.analysis.autointerp.autointerp_base import AutoInterpConfig, TokenizedSample + + +def generate_detection_prompt( + cfg: AutoInterpConfig, + explanation: dict[str, Any], + examples: list[TokenizedSample], +) -> tuple[str, str]: + """Generate a prompt for detection evaluation. + + Args: + cfg: Auto-interpretation configuration + explanation: The explanation to evaluate + examples: List of examples (mix of activating and non-activating) + + Returns: + Tuple of (system_prompt, user_prompt) strings + """ + system_prompt = f"""We're studying features in a neural network. Each feature activates on some particular word/words/substring/concept in a short document. You will be given a short explanation of what this feature activates for, and then be shown {len(examples)} example sequences in random order. You will have to return a boolean list of the examples where you think the feature should activate at least once, on ANY of the words or substrings in the document, true if it does, false if it doesn't. Try not to be overly specific in your interpretation of the explanation.""" + system_prompt += """ +Your output should be a JSON object that has the following fields: `steps`, `evaluation_results`. `steps` should be an array of strings, each representing a step in the chain-of-thought process within 50 words. `evaluation_results` should be an array of booleans, each representing whether the feature should activate on the corresponding example. 
+""" + user_prompt = f"Here is the explanation:\n\n{explanation['final_explanation']}\n\nHere are the examples:\n\n" + + for i, example in enumerate(examples, 1): + user_prompt += f"Example {i}: {example.display_plain()}\n" + + return system_prompt, user_prompt + + +def generate_fuzzing_prompt( + cfg: AutoInterpConfig, + explanation: dict[str, Any], + examples: list[tuple[TokenizedSample, bool]], # (sample, is_correctly_marked) +) -> tuple[str, str]: + """Generate a prompt for fuzzing evaluation. + + Args: + cfg: Auto-interpretation configuration + explanation: The explanation to evaluate + examples: List of tuples (example, is_correctly_marked) + + Returns: + Tuple of (system_prompt, user_prompt) strings + """ + system_prompt = f"""We're studying features in a neural network. Each feature activates on some particular word/words/substring/concept in a short document. You will be given a short explanation of what this feature activates for, and then be shown {len(examples)} example sequences in random order. In each example, text segments highlighted with << >> are presented as activating the feature as described in the explanation. You will have to return a boolean list of the examples where you think the highlighted parts CORRECTLY correspond to the explanation, true if they do, false if they don't. Try not to be overly specific in your interpretation of the explanation.""" + system_prompt += """ +Your output should be a JSON object that has the following fields: `steps`, `evaluation_results`. `steps` should be an array of strings, each representing a step in the chain-of-thought process within 50 words. `evaluation_results` should be an array of booleans, each representing whether the feature should activate on the corresponding example. +""" + user_prompt = f"Here is the explanation:\n\n{explanation['final_explanation']}\n\nHere are the examples:\n\n" + + for i, (example, _) in enumerate(examples, 1): + highlighted = example.display_highlighted(cfg.activation_threshold) + user_prompt += f"Example {i}: {highlighted}\n" + + return system_prompt, user_prompt + diff --git a/src/lm_saes/analysis/autointerp/explanation_prompts.py b/src/lm_saes/analysis/autointerp/explanation_prompts.py new file mode 100644 index 00000000..19b233e8 --- /dev/null +++ b/src/lm_saes/analysis/autointerp/explanation_prompts.py @@ -0,0 +1,233 @@ +"""Prompt builders for generating feature explanations. + +This module contains functions for generating prompts used to explain SAE features, +including both neuronpedia-style and OpenAI-style explanation prompts. +""" + +from typing import Any + +from lm_saes.analysis.autointerp.autointerp_base import AutoInterpConfig, TokenizedSample, process_token + +NEURONPEDIA_SYSTEM_PROMPT_VANILLA = """You are explaining the behavior of a neuron in a neural network. Your final response should be a very concise explanation (1-6 words) that captures what the neuron detects or predicts by finding patterns in lists.\n\n +To determine the explanation, you are given four lists:\n\n +- MAX_ACTIVATING_TOKENS, which are the top activating tokens in the top activating texts. 
Each max activating token is shown with the previous 3 tokens in parentheses for context, e.g., "(Who am) I".\n +- TOKENS_AFTER_MAX_ACTIVATING_TOKEN, which are the tokens immediately after the max activating token.\n +- TOP_POSITIVE_LOGITS, which are the most likely words or tokens associated with this neuron.\n +- TOP_ACTIVATING_TEXTS, which are top activating texts.\n\n +You should look for a pattern by trying the following methods in order. You may go through each method even if you find a pattern with some method. You may sometimes need to combine different methods to give a better explanation.\n +Method 1: Look at MAX_ACTIVATING_TOKENS. If they share something specific in common, or are all the same token or a variation of the same token (like different cases or conjugations), respond with that token. + - Note that MAX_ACTIVATING_TOKENS are preceded by 3 tokens in parentheses as a short context. For example, "(Who am) I" is a max activating token on the word "I" with the context "(Who am)". + - These preceding tokens can be informative. If all MAX_ACTIVATING_TOKENS have the same or similar preceding tokens, respond with that preceding tokens, e.g. previous is X. +Method 2: Look at TOKENS_AFTER_MAX_ACTIVATING_TOKEN. Try to find a specific pattern or similarity in all the tokens. A common pattern is that they all start with the same letter. If you find a pattern (like \'s word\', \'the ending -ing\', \'number 8\'), respond with \'say [the pattern]\'. You can ignore uppercase/lowercase differences for this.\n +Method 3: Look at TOP_POSITIVE_LOGITS for similarities and describe it very briefly (1-3 words). These tokens are the most likely to be predicted with this neuron.\n +Method 4: Look at TOP_ACTIVATING_TEXTS and make a best guess by describing the broad theme or context, ignoring the max activating tokens.\n\n +Method 5: Look at TOP_NEGATIVE_LOGITS for similarities and describe it very briefly (1-3 words). These tokens are the most suppressed by this neuron. Use this method sparingly. Especially when the neuron is suppressing some rare tokens (e.g. very long tokens starting with newlines and from non-English-or-Chinese alphabets).\n +Rules:\n +- You can think carefully in your internal thinking process, but keep your returned explanation extremely concise (1-6 words, mostly 1-3 words).\n +- Do not add unnecessary phrases like "words related to", "concepts related to", or "variations of the word".\n +- Do not mention "tokens" or "patterns" in your explanation.\n +- The explanation should be specific. For example, "unique words" is not a specific enough pattern, nor is "foreign words".\n +- Remember to use the \'say [the pattern]\' when using Method 2 & 3 above (pattern found in TOKENS_AFTER_MAX_ACTIVATING_TOKEN and TOP_POSITIVE_LOGITS respectively).\n +- Remember to use the \'do not say [the pattern]\' when using Method 5 above (pattern found in TOP_NEGATIVE_LOGITS).\n +- If you absolutely cannot make any guesses, respond with "N/A".\n\n +Think carefully by going through each method number until you find one that helps you find an explanation for what this neuron is detecting or predicting. If a method does not help you find an explanation, briefly explain why it does not, then go on to the next method. 
Finally, end your thinking process with the method number you used, the reason for your explanation, and return the explanation in a brief manner.\n + +Examples: +{ +\n\nwas\nwatching\n\n\n\n\n\n(Who am) I\n(I really) enjoy\n\n\n\n\n\n\nwalking\nWA\nwaiting\nwas\nwe\nWHAM\nwish\nwin\nwake\nwhisper\n\n\n\n\n\ndoes\napple\n\\n\nused\nsay\nvitamins\nneus\nautumn\nsun\nanation\n\n\n\n\nShe was taking a nap when her phone started ringing.\nI enjoy watching movies with my family.\n\n\n\n\nExplanation of neuron behavior: \n +Method 1 fails: MAX_ACTIVATING_TOKENS (I, enjoy) are not similar tokens.\nMethod 2 succeeds: All TOKENS_AFTER_MAX_ACTIVATING_TOKEN have a pattern in common: they all start with "w".\nMethod 3 confirms: TOP_POSITIVE_LOGITS also show many words starting with "w" (walking, waiting, was, we, wish, win, wake, whisper), reinforcing the pattern found in Method 2.\nMethod 4: TOP_ACTIVATING_TEXTS don't provide additional clarity beyond what Methods 2 and 3 revealed.\nCombining Methods 2 and 3: The neuron detects tokens that start with "w" and predicts words starting with "w".\nExplanation: say "w" words\n\nsay "w" words +} + +{ +\n\nwarm\nthe\n\n\n\n\n\n\n(including you) and\n(matters .) And\n\n\n\n\n\n\nelephant\nguitar\nmountain\nbicycle\nocean\ntelescope\ncandle\numbrella\ntornado\nbutterfly\n\n\n\n\n\ndoes\napple\n\\n\nused\nsay\nvitamins\nneus\nautumn\nsun\nanation\n\n\n\n\nIt was a beautiful day outside with clear skies and warm sunshine.\nAnd the garden has roses and tulips and daisies and sunflowers blooming together.\n\n\n\n\nExplanation of neuron behavior: \n +Method 1 succeeds: All MAX_ACTIVATING_TOKENS are the word "and".\nMethod 2: TOKENS_AFTER_MAX_ACTIVATING_TOKEN (warm, the) don't show a clear pattern related to "and".\nMethod 3: TOP_POSITIVE_LOGITS show diverse unrelated words (elephant, guitar, mountain, etc.), not reinforcing the "and" pattern.\nMethod 4: TOP_ACTIVATING_TEXTS show sentences with "and" but don't add information beyond Method 1.\nMethod 1 provides the clearest explanation: the neuron activates on the token "and".\nExplanation: and\n\nand +} + +{ +\n\nare\n,\n\n\n\n\n\n\n(from the) banana\n(from the) blueberries\n\n\n\n\n\n\napple\norange\npineapple\nwatermelon\nkiwi\npeach\npear\ngrape\ncherry\nplum\n\n\n\n\n\ndoes\napple\n\\n\nused\nsay\nvitamins\nneus\nautumn\nsun\nanation\n\n\n\n\nThe apple and banana are delicious foods that provide essential vitamins and nutrients.\nI enjoy eating fresh strawberries, blueberries, and mangoes during the summer months.\n\n\n\n\nExplanation of neuron behavior: \n +Method 1 succeeds: All MAX_ACTIVATING_TOKENS (banana, blueberries) are fruits.\nMethod 2: TOKENS_AFTER_MAX_ACTIVATING_TOKEN (are, ,) don't show a clear pattern.\nMethod 3 confirms: TOP_POSITIVE_LOGITS show many fruits (apple, orange, pineapple, watermelon, kiwi, peach, pear, grape, cherry, plum), strongly reinforcing the pattern found in Method 1.\nMethod 4: TOP_ACTIVATING_TEXTS mention fruits but don't add information beyond Methods 1 and 3.\nCombining Methods 1 and 3: The neuron activates on fruit tokens and predicts fruit-related words.\nExplanation: fruits\n\n\nfruits +} + +{ +\n\nwas\nplaces\n\n\n\n\n\n\n(during the) war\n(in some) places\n\n\n\n\n\n\n4\nfour\nfourth\n4th\nIV\nFour\nFOUR\n~4\n4.0\nquartet\n\n\n\n\n\ndoes\napple\n\\n\nused\nsay\nvitamins\nneus\nautumn\nsun\nanation\n\n\n\n\nthe civil war was a major topic in history class .\n seasons of the year are winter , spring , summer , and fall or autumn in some places .\n\n\n\n\nExplanation of 
neuron behavior: \n +Method 1 fails: MAX_ACTIVATING_TOKENS (war, places) are not all the same token and don't share a clear pattern.\nMethod 2 fails: TOKENS_AFTER_MAX_ACTIVATING_TOKEN (was, places) are not all similar tokens and don't have a text pattern in common.\nMethod 3 succeeds: All TOP_POSITIVE_LOGITS are the number 4 (4, four, fourth, 4th, IV, Four, FOUR, ~4, 4.0, quartet).\nMethod 4: TOP_ACTIVATING_TEXTS mention "war" and "places" but don't clearly relate to the number 4 pattern found in Method 3.\nMethod 5: TOP_NEGATIVE_LOGITS don't show a clear pattern that would help explain the feature.\nMethod 3 provides the clearest explanation: the neuron predicts the number 4.\nExplanation: 4\n\n4 +} +""" + +NEURONPEDIA_SYSTEM_PROMPT_Z_PATTERN = """You are explaining the behavior of a neuron in a neural network. Your final response should be a very concise explanation (1-6 words) that captures what the neuron detects or predicts by finding patterns in lists.\n\n +To determine the explanation, you are given four lists:\n\n +- MAX_ACTIVATING_TOKENS, which are the top activating tokens in the top activating texts. Each max activating token is shown with the previous 3 tokens in parentheses for context, e.g., "(Who am) I".\n +- TOKENS_AFTER_MAX_ACTIVATING_TOKEN, which are the tokens immediately after the max activating token.\n +- TOP_POSITIVE_LOGITS, which are the most likely words or tokens associated with this neuron.\n +- TOP_ACTIVATING_TEXTS, which are top activating texts.\n\n +You should look for a pattern by trying the following methods in order. You may go through each method even if you find a pattern with some method. You may sometimes need to combine different methods to give a better explanation.\n +Method 1: Look at MAX_ACTIVATING_TOKENS. + - These neurons are likely to be activated by specific patterns in the text, such as the presence of certain words or phrases. This is much like attention heads attending to a certain concept in the text. + - The activating pattern is showed in the following format: [(previous tokens) token1] => [(previous tokens) token2]. For example, "[(from Dr.) Sam] => (is from) Dr." is a typical induction head pattern, which moves the attention from "Sam" to "Dr." in the next token. + - Source tokens and target tokens are preceded by 3 tokens in parentheses as a short context. For example, "(Who am) I" is a max activating token on the word "I" with the context "(Who am)". + - These preceding tokens can be informative. If all MAX_ACTIVATING_TOKENS have the same or similar preceding tokens, respond with that preceding tokens, e.g. previous is X. + - If source and target tokens are the same token, this typically means that the neuron is attending to its own token. In this case source token is often not informative. Try to find a pattern in the target token or try other methods. + - If this method succeeds, try to respond in the format of [source token] => [target token]. For instance, "[position] => [name]". +Method 2: Look at TOKENS_AFTER_MAX_ACTIVATING_TOKEN. Try to find a specific pattern or similarity in all the tokens. A common pattern is that they all start with the same letter. If you find a pattern (like \'s word\', \'the ending -ing\', \'number 8\'), respond with \'say [the pattern]\'. You can ignore uppercase/lowercase differences for this.\n +Method 3: Look at TOP_POSITIVE_LOGITS for similarities and describe it very briefly (1-3 words). 
These tokens are the most likely to be predicted with this neuron.\n +Method 4: Look at TOP_ACTIVATING_TEXTS and make a best guess by describing the broad theme or context, ignoring the max activating tokens.\n\n +Method 5: Look at TOP_NEGATIVE_LOGITS for similarities and describe it very briefly (1-3 words). These tokens are the most suppressed by this neuron. Use this method sparingly. Especially when the neuron is suppressing some rare tokens (e.g. very long tokens starting with newlines and from non-English-or-Chinese alphabets).\n +Rules:\n +- You can think carefully in your internal thinking process, but keep your returned explanation extremely concise (1-6 words, mostly 1-3 words).\n +- Do not add unnecessary phrases like "words related to", "concepts related to", or "variations of the word".\n +- Do not mention "tokens" or "patterns" in your explanation.\n +- The explanation should be specific. For example, "unique words" is not a specific enough pattern, nor is "foreign words".\n +- Remember to use the \'say [the pattern]\' when using Method 2 & 3 above (pattern found in TOKENS_AFTER_MAX_ACTIVATING_TOKEN and TOP_POSITIVE_LOGITS respectively).\n +- Remember to use the \'do not say [the pattern]\' when using Method 5 above (pattern found in TOP_NEGATIVE_LOGITS).\n +- If you absolutely cannot make any guesses, respond with "N/A".\n\n +Think carefully by going through each method number until you find one that helps you find an explanation for what this neuron is detecting or predicting. If a method does not help you find an explanation, briefly explain why it does not, then go on to the next method. Finally, end your thinking process with the method number you used, the reason for your explanation, and return the explanation in a brief manner.\n + +Examples: +{ +\n\nwas\nwatching\n\n\n\n\n\n[(Who am) I] => (Who am) I\n[(I really) enjoy] => (I really) enjoy\n\n\n\n\n\n\nwalking\nWA\nwaiting\nwas\nwe\nWHAM\nwish\nwin\nwake\nwhisper\n\n\n\n\n\ndoes\napple\n\\n\nused\nsay\nvitamins\nneus\nautumn\nsun\nanation\n\n\n\n\nShe was taking a nap when her phone started ringing.\nI enjoy watching movies with my family.\n\n\n\n\nExplanation of neuron behavior: \n +Method 1 fails: MAX_ACTIVATING_TOKENS show self-attention patterns [(Who am) I] => (Who am) I and [(I really) enjoy] => (I really) enjoy, but the source tokens (I, enjoy) are not similar.\nMethod 2 succeeds: All TOKENS_AFTER_MAX_ACTIVATING_TOKEN have a pattern in common: they all start with "w".\nMethod 3 confirms: TOP_POSITIVE_LOGITS also show many words starting with "w" (walking, waiting, was, we, wish, win, wake, whisper), reinforcing the pattern found in Method 2.\nMethod 4: TOP_ACTIVATING_TEXTS don't provide additional clarity beyond what Methods 2 and 3 revealed.\nCombining Methods 2 and 3: The neuron detects tokens that start with "w" and predicts words starting with "w".\nExplanation: say "w" words\n\nsay "w" words +} + +{ +\n\nManning\nChris\n\n\n\n\n\n\n[(from Dr.) Sam] => (is from) Dr.]\n[(is Prof.) 
Chris] => (he is) Prof.\n\n\n\n\n\n\nelephant\nguitar\nmountain\nbicycle\nocean\ntelescope\ncandle\numbrella\ntornado\nbutterfly\n\n\n\n\n\ndoes\napple\n\\n\nused\nsay\nvitamins\nneus\nautumn\nsun\nanation\n\n\n\n\nIt was a beautiful day outside with clear skies and warm sunshine.\nAnd the garden has roses and tulips and daisies and sunflowers blooming together.\n\n\n\n\nExplanation of neuron behavior: \n +Method 1 succeeds: Looking at the MAX_ACTIVATING_TOKENS, we can see that this neuron is attending to the position and name of the person. The pattern is [(from Dr.) Sam] => (is from) Dr. and [(is Prof.) Chris] => (he is) Prof., showing attention from name to position.\nMethod 2: TOKENS_AFTER_MAX_ACTIVATING_TOKEN (Manning, Chris) are names, which aligns with the source tokens in Method 1.\nMethod 3: TOP_POSITIVE_LOGITS show diverse unrelated words, not reinforcing the position-name pattern.\nMethod 4: TOP_ACTIVATING_TEXTS don't provide additional clarity beyond Method 1.\nMethod 1 provides the clearest explanation: the neuron attends from name to position.\nExplanation: [name] => [position]\n\n[name] => [position] +} + +{ +\n\nare\n,\n\n\n\n\n\n\n[(from the) banana] => (from the) banana\n[(from the) blueberries] => (from the) blueberries\n\n\n\n\n\n\napple\norange\npineapple\nwatermelon\nkiwi\npeach\npear\ngrape\ncherry\nplum\n\n\n\n\n\ndoes\napple\n\\n\nused\nsay\nvitamins\nneus\nautumn\nsun\nanation\n\n\n\n\nThe apple and banana are delicious foods that provide essential vitamins and nutrients.\nI enjoy eating fresh strawberries, blueberries, and mangoes during the summer months.\n\n\n\n\nExplanation of neuron behavior: \n +Method 1 succeeds: All MAX_ACTIVATING_TOKENS show self-attention on fruit tokens [(from the) banana] => (from the) banana and [(from the) blueberries] => (from the) blueberries. 
The tokens (banana, blueberries) are fruits.\nMethod 2: TOKENS_AFTER_MAX_ACTIVATING_TOKEN (are, ,) don't show a clear pattern.\nMethod 3 confirms: TOP_POSITIVE_LOGITS show many fruits (apple, orange, pineapple, watermelon, kiwi, peach, pear, grape, cherry, plum), strongly reinforcing the pattern found in Method 1.\nMethod 4: TOP_ACTIVATING_TEXTS mention fruits but don't add information beyond Methods 1 and 3.\nCombining Methods 1 and 3: The neuron activates on fruit tokens and predicts fruit-related words.\nExplanation: fruits\n\n\nfruits +} + +{ +\n\nwas\nplaces\n\n\n\n\n\n\n[(during the) war] => (during the) war\n[(in some) places] => (in some) places\n\n\n\n\n\n\n4\nfour\nfourth\n4th\nIV\nFour\nFOUR\n~4\n4.0\nquartet\n\n\n\n\n\ndoes\napple\n\\n\nused\nsay\nvitamins\nneus\nautumn\nsun\nanation\n\n\n\n\nthe civil war was a major topic in history class .\n seasons of the year are winter , spring , summer , and fall or autumn in some places .\n\n\n\n\nExplanation of neuron behavior: \n +Method 1 fails: MAX_ACTIVATING_TOKENS show self-attention patterns [(during the) war] => (during the) war and [(in some) places] => (in some) places, but the tokens (war, places) are not all the same token and don't share a clear pattern.\nMethod 2 fails: TOKENS_AFTER_MAX_ACTIVATING_TOKEN (was, places) are not all similar tokens and don't have a text pattern in common.\nMethod 3 succeeds: All TOP_POSITIVE_LOGITS are the number 4 (4, four, fourth, 4th, IV, Four, FOUR, ~4, 4.0, quartet).\nMethod 4: TOP_ACTIVATING_TEXTS mention "war" and "places" but don't clearly relate to the number 4 pattern found in Method 3.\nMethod 5: TOP_NEGATIVE_LOGITS don't show a clear pattern that would help explain the feature.\nMethod 3 provides the clearest explanation: the neuron predicts the number 4.\nExplanation: 4\n\n4 +} +""" + +def generate_explanation_prompt_neuronpedia( + cfg: AutoInterpConfig, + activating_examples: list[TokenizedSample], + top_logits: dict[str, list[dict[str, Any]]] | None = None, +) -> tuple[str, str]: + """Generate a prompt for explanation generation with neuronpedia. 
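A minimal sketch of calling this builder, assuming AutoInterpConfig can be constructed with its defaults and using invented tokens for top_logits:

from lm_saes.analysis.autointerp import (
    AutoInterpConfig,
    Segment,
    TokenizedSample,
    generate_explanation_prompt_neuronpedia,
)

cfg = AutoInterpConfig()
example = TokenizedSample(
    segments=[Segment("I ", 0.0), Segment("enjoy", 9.0), Segment(" watching", 1.0)],
    max_activation=10.0,
)
# top_logits is read as two lists of {"token": ...} records.
top_logits = {
    "top_positive": [{"token": "walking"}, {"token": "waiting"}],
    "top_negative": [{"token": "neus"}],
}
system_prompt, user_prompt = generate_explanation_prompt_neuronpedia(cfg, [example], top_logits)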
+ + Args: + cfg: Auto-interpretation configuration + activating_examples: List of activating examples + top_logits: Optional top logits dictionary with 'top_positive' and 'top_negative' keys + + Returns: + Tuple of (system_prompt, user_prompt) strings + """ + system_prompt = NEURONPEDIA_SYSTEM_PROMPT_Z_PATTERN if activating_examples[0].has_z_pattern_data() else NEURONPEDIA_SYSTEM_PROMPT_VANILLA + examples_to_show = activating_examples[: cfg.n_activating_examples] + next_activating_tokens = "" + max_activating_tokens = "" + plain_activating_tokens = "" + logit_activating_tokens = "" + logit_suppressing_tokens = "" + + for i, example in enumerate(examples_to_show, 1): + next_activating_tokens = next_activating_tokens + example.display_next(cfg.activation_threshold) + max_activating_tokens = max_activating_tokens + example.display_max(cfg.activation_threshold) + plain_activating_tokens = plain_activating_tokens + process_token(example.display_plain()) + "\n" + + if top_logits is not None: + for text in top_logits['top_positive']: + logit_activating_tokens = logit_activating_tokens + process_token(text["token"]) + "\n" + for text in top_logits['top_negative']: + logit_suppressing_tokens = logit_suppressing_tokens + process_token(text["token"]) + "\n" + else: + logit_activating_tokens = next_activating_tokens + logit_suppressing_tokens = "none" + + user_prompt: str = f""" +\n\n{next_activating_tokens}\n\n\n\n\n\n{max_activating_tokens}\n\n\n\n\n\n{logit_activating_tokens}\n\n\n\n\n\n{logit_suppressing_tokens}\n\n\n\n\n\n{plain_activating_tokens}\n<\\TOP_ACTIVATING_TEXTS>\n\n\nExplanation of neuron behavior: \n +""" + # print('system_prompt', system_prompt) + # print('user_prompt', user_prompt) + return system_prompt, user_prompt + + +def generate_explanation_prompt( + cfg: AutoInterpConfig, + activating_examples: list[TokenizedSample], +) -> tuple[str, str]: + """Generate a prompt for explanation generation. + + Args: + cfg: Auto-interpretation configuration + activating_examples: List of activating examples + + Returns: + Tuple of (system_prompt, user_prompt) strings + """ + cot_prompt = "" + if cfg.include_cot: + cot_prompt += "\n\nTo explain this feature, please follow these steps:\n" + cot_prompt += "Step 1: List a couple activating and contextual tokens you find interesting. " + cot_prompt += "Search for patterns in these tokens, if there are any. Don't list more than 5 tokens.\n" + cot_prompt += "Step 2: Write down general shared features of the text examples.\n" + cot_prompt += "Step 3: Write a concise explanation of what this feature detects.\n" + + examples_prompt = """Some examples: + +{ + "steps": ["Activating token: <>. Contextual tokens: Who, ?. Pattern: <> is consistently activated, often found in sentences starting with interrogative words like 'Who' and ending with a question mark.", "Shared features include consistent activation on the word 'knows'. The surrounding text always forms a question. The questions do not seem to expect a literal answer, suggesting they are rhetorical.", "This feature activates on the word knows in rhetorical questions"], + "final_explanation": "The feature activates on the word 'knows' in rhetorical questions.", + "activation_consistency": 5, + "complexity": 4 +} + +{ + "steps": ["Activating tokens: <>, <>, <>, <>, <>. Pattern: All activating instances are on words that begin with the specific substring 'Ent'. 
The activation is on the 'Ent' portion itself.", "The shared feature across all examples is the presence of words starting with the capitalized substring 'Ent'. The feature appears to be case-sensitive and position-specific (start of the word). No other contextual or semantic patterns are observed."], + "final_explanation": "The feature activates on the substring 'Ent' at the start of words", + "activation_consistency": 5, + "complexity": 1 +} + +{ + "steps": ["Activating tokens: <>, <>, <>, <>, <>. Pattern: Activations highlight phrases and concepts central to economic discussions and government actions.","The examples consistently involve discussions of economic indicators, government spending, financial regulation, or international trade agreements. While most activations clearly relate to economic policies enacted or debated by governmental bodies, some activations might be on broader economic news or expert commentary where the direct link to a specific government policy is less explicit, or on related but not identical topics like corporate financial health in response to policy."], + "final_explanation": "The feature activates on text about government economic policy", + "activation_consistency": 3, + "complexity": 5 +} + +""" + system_prompt: str = f"""We're studying features in a neural network. Each feature activates on some particular word/words/substring/concept in a short document. The activating words in each document are indicated with << ... >>. We will give you a list of documents on which the feature activates, in order from most strongly activating to least strongly activating. + +Your task is to: + +First, Summarize the Activation: Look at the parts of the document the feature activates for and summarize in a single sentence what the feature is activating on. Try not to be overly specific in your explanation. Note that some features will activate only on specific words or substrings, but others will activate on most/all words in a sentence provided that sentence contains some particular concept. Your explanation should cover most or all activating words (for example, don't give an explanation which is specific to a single word if all words in a sentence cause the feature to activate). Pay attention to things like the capitalization and punctuation of the activating words or concepts, if that seems relevant. Keep the explanation as short and simple as possible, limited to 20 words or less. Omit punctuation and formatting. You should avoid giving long lists of words.{cot_prompt} + +Second, Assess Activation Consistency: Based on your summary and the provided examples, evaluate the consistency of the feature's activation. Return your assessment as a single integer from the following scale: + +5: Clear pattern with no deviating examples +4: Clear pattern with one or two deviating examples +3: Clear overall pattern but quite a few examples not fitting that pattern +2: Broad consistent theme but lacking structure +1: No discernible pattern + +Third, Assess Feature Complexity: Based on your summary and the nature of the activation, evaluate the complexity of the feature. 
Return your assessment as a single integer from the following scale: + +5: Rich feature firing on diverse contexts with an interesting unifying theme, e.g., "feelings of togetherness" +4: Feature relating to high-level semantic structure, e.g., "return statements in code" +3: Moderate complexity, such as a phrase, category, or tracking sentence structure, e.g., "website URLs" +2: Single word or token feature but including multiple languages or spelling, e.g., "mentions of dog" +1: Single token feature, e.g., "the token '('" + +Your output should be a JSON object that has the following fields: `steps`, `final_explanation`, `activation_consistency`, `complexity`. `steps` should be an array of strings with a length not exceeding 3, each representing a step in the chain-of-thought process. `final_explanation` should be a string in the form of 'This feature activates on... '. `activation_consistency` should be an integer between 1 and 5, representing the consistency of the feature. `complexity` should be an integer between 1 and 5, representing the complexity of the feature. + +{examples_prompt} +""" + + user_prompt = "The activating documents are given below:\n\n" + # Select a subset of examples to show + examples_to_show = activating_examples[: cfg.n_activating_examples] + + for i, example in enumerate(examples_to_show, 1): + highlighted = example.display_highlighted(cfg.activation_threshold) + user_prompt += f"Example {i}: {highlighted}\n\n" + + return system_prompt, user_prompt + diff --git a/src/lm_saes/analysis/autointerp/feature_interpreter.py b/src/lm_saes/analysis/autointerp/feature_interpreter.py new file mode 100644 index 00000000..441d3d81 --- /dev/null +++ b/src/lm_saes/analysis/autointerp/feature_interpreter.py @@ -0,0 +1,828 @@ +"""Auto-interpretation functionality for SAE features. + +This module provides tools for automatically interpreting and evaluating sparse autoencoder features +based on the EleutherAI auto-interp approach (https://blog.eleuther.ai/autointerp/). + +It includes: +1. Methods for prompting LLMs to generate explanations for features +2. 
Methods for evaluating explanations via different techniques: + - Detection: Having LLMs identify if examples contain a feature + - Fuzzing: Having LLMs identify correctly marked activating tokens +""" + +import asyncio +import random +import time +import traceback +from typing import Any, AsyncGenerator, Callable, Literal, Optional + +import json_repair +import numpy as np +import torch +from datasets import Dataset +from pydantic import BaseModel + +from lm_saes.analysis.autointerp import ( + AutoInterpConfig, + ExplainerType, + ScorerType, + Segment, + TokenizedSample, + generate_detection_prompt, + generate_explanation_prompt, + generate_explanation_prompt_neuronpedia, + generate_fuzzing_prompt, +) +from lm_saes.backend.language_model import LanguageModel +from lm_saes.database import FeatureAnalysis, FeatureRecord, MongoClient +from lm_saes.utils.logging import get_logger + +logger = get_logger("analysis.feature_interpreter") + + +class Step(BaseModel): + """A step in the chain-of-thought process.""" + + thought: str + """The thought of the step.""" + + # output: str + # """The output of the step.""" + + +Step_Schema = { + "type": "object", + "properties": { + "thought": {"type": "string", "description": "The thought of the step."}, + }, +} + + +class AutoInterpExplanation(BaseModel): + """The result of an auto-interpretation of a SAE feature.""" + + steps: list[Step] + """The steps of the chain-of-thought process.""" + + final_explanation: str + """The explanation of the feature.""" + + activation_consistency: Literal[1, 2, 3, 4, 5] + """The consistency of the feature.""" + + complexity: Literal[1, 2, 3, 4, 5] + """The complexity of the feature.""" + + +AutoInterpExplanation_Schema = { + "type": "object", + "properties": { + "steps": {"type": "array", "items": Step_Schema}, + "final_explanation": { + "type": "string", + "description": "The explanation of the feature, in the form of 'This feature activates on... '", + }, + "activation_consistency": { + "type": "integer", + "description": "The consistency of the feature, on a scale of 1 to 5.", + }, + "complexity": {"type": "integer", "description": "The complexity of the feature, on a scale of 1 to 5."}, + }, +} + + +class AutoInterpEvaluation(BaseModel): + """The result of an auto-interpretation of a SAE feature.""" + + steps: list[Step] + """The steps of the chain-of-thought process.""" + + evaluation_results: list[bool] + """The evaluation results for each example. Should be a list of YES/NO values.""" + + +AutoInterpEvaluation_Schema = { + "type": "object", + "properties": { + "steps": {"type": "array", "items": Step_Schema}, + "evaluation_results": { + "type": "array", + "items": {"type": "boolean"}, + "description": "The evaluation results for each example. Should be a list of True/False values.", + }, + }, +} + + +def generate_activating_examples( + feature: FeatureRecord, + model: LanguageModel, + datasets: Callable[[str, int, int], Dataset], + analysis: FeatureAnalysis, + n: int = 10, + max_length: int = 50, +) -> list[TokenizedSample]: + """Generate examples where a feature strongly activates using database records. 
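In the body below, the stored activations are rebuilt from sparse COO indices and values before locating the max-activating position; a toy version of that reconstruction:

import torch

# Indices are (sample index, position) pairs and values the matching activations,
# mirroring sampling.feature_acts_indices / sampling.feature_acts_values; numbers are toy.
feature_acts_indices = torch.tensor([[0, 0, 1], [3, 5, 2]])
feature_acts_values = torch.tensor([4.0, 2.5, 7.0])
feature_acts = torch.sparse_coo_tensor(
    feature_acts_indices, feature_acts_values, (2, 2048)
).to_dense()
print(feature_acts.shape)                  # torch.Size([2, 2048])
print(int(torch.argmax(feature_acts[1])))  # 2, the max-activating position in sample 1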
+ + Args: + feature: FeatureRecord to analyze + model: Language model to use + datasets: Callable to fetch datasets + analysis: FeatureAnalysis to use + n: Maximum number of examples to generate + max_length: Maximum length of examples to generate + + Returns: + List of TokenizedExample with high activation for the feature + """ + samples: list[TokenizedSample] = [] + + # Get examples from top activations + sampling = analysis.samplings[0] + feature_acts_ = torch.sparse_coo_tensor( + torch.tensor(sampling.feature_acts_indices), + torch.tensor(sampling.feature_acts_values), + (int(np.max(sampling.feature_acts_indices[0])), 2048), + ) + feature_acts_ = feature_acts_.to_dense() + + # Lorsa z pattern data so we can explain which tokens are contributing to the activation + # We want to operate in coo format since zpattern is 3-d + # z_pattern_indices: [n_samples, n_ctx, n_ctx] + # z_pattern_values: [n_samples, n_ctx, n_ctx] + if sampling.z_pattern_indices is not None: + assert sampling.z_pattern_values is not None, "Z pattern values are not available" + z_pattern_indices = torch.tensor(sampling.z_pattern_indices).int() + z_pattern_values = torch.tensor(sampling.z_pattern_values) + else: + z_pattern_indices = None + z_pattern_values = None + + for i, (dataset_name, shard_idx, n_shards, context_idx, feature_acts) in enumerate( + zip( + sampling.dataset_name, + sampling.shard_idx if sampling.shard_idx is not None else [0] * len(sampling.dataset_name), + sampling.n_shards if sampling.n_shards is not None else [1] * len(sampling.dataset_name), + sampling.context_idx, + feature_acts_, + ) + ): + dataset = datasets(dataset_name, shard_idx, n_shards) + data = dataset[int(context_idx)] + + # Process the sample using model's trace method + try: + origins = model.trace({k: [v] for k, v in data.items()})[0] + except Exception: + continue + + max_act_pos = torch.argmax(feature_acts).item() + + left_end = max(0, max_act_pos - max_length // 2) + right_end = min(len(origins), max_act_pos + max_length // 2) + + # Create TokenizedExample using the trace information + sample = TokenizedSample.construct( + text=data["text"], + activations=feature_acts[left_end:right_end], + origins=origins[left_end:right_end], + max_activation=analysis.max_feature_acts, + ) + # Find max contributing previous tokens for Lorsa. + # We want to operate in coo format since zpattern is 3-d. + if z_pattern_indices is not None and z_pattern_values is not None: + current_sequence_mask = z_pattern_indices[0] == i + current_z_pattern_indices = z_pattern_indices[1:, current_sequence_mask] + current_z_pattern_values = z_pattern_values[current_sequence_mask] + # Need to adjust indices since the text has been cropped + # and remove negative indices + out_of_right_end_indices = current_z_pattern_indices.lt(right_end).all(dim=0) + current_z_pattern_indices -= left_end + z_pattern_with_negative_indices = current_z_pattern_indices.ge(0).all(dim=0) + mask = out_of_right_end_indices * z_pattern_with_negative_indices + sample.add_z_pattern_data( + current_z_pattern_indices[:, mask], + current_z_pattern_values[mask], + origins, + ) + + samples.append(sample) + + if len(samples) >= n: + break + + return samples + + +def generate_non_activating_examples( + feature: FeatureRecord, + model: LanguageModel, + datasets: Callable[[str, int, int], Dataset], + analysis: FeatureAnalysis, + n: int = 10, + max_length: int = 50, +) -> list[TokenizedSample]: + """Generate examples where a feature doesn't activate much. 
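Both example generators resolve contexts through the datasets callable, invoked as datasets(dataset_name, shard_idx, n_shards); a minimal in-memory stand-in (purely illustrative) that satisfies the Callable[[str, int, int], Dataset] signature:

from datasets import Dataset


def datasets_fn(dataset_name: str, shard_idx: int, n_shards: int) -> Dataset:
    # A real implementation would load the named dataset shard from disk or a hub;
    # a tiny in-memory table stands in for it here.
    full = Dataset.from_dict({"text": ["The cat sat on the mat.", "I enjoy watching movies."]})
    return full.shard(num_shards=n_shards, index=shard_idx)


print(datasets_fn("toy", 0, 1)[0]["text"])  # -> "The cat sat on the mat."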
+ + Args: + feature: FeatureRecord to analyze + model: Language model to use + datasets: Callable to fetch datasets + analysis: FeatureAnalysis to use + n: Maximum number of examples to generate + max_length: Maximum length of examples to generate + + Returns: + List of TokenizedExample with low activation for the feature + """ + + samples: list[TokenizedSample] = [] + if n == 0: + return samples + error_prefix = f"Error processing non-activating examples of feature {feature.index}:" + + sampling_idx = -1 + for i in range(len(analysis.samplings)): + if analysis.samplings[i].name == "non_activating": + sampling_idx = i + break + if sampling_idx == -1: + return samples + sampling = analysis.samplings[sampling_idx] + + assert sampling.name == "non_activating", f"{error_prefix} Sampling {sampling.name} is not non_activating" + for i, (dataset_name, shard_idx, n_shards, context_idx, feature_acts_indices, feature_acts_values) in enumerate( + zip( + sampling.dataset_name, + sampling.shard_idx if sampling.shard_idx else [0] * len(sampling.dataset_name), + sampling.n_shards if sampling.n_shards else [1] * len(sampling.dataset_name), + sampling.context_idx, + # sampling.feature_acts, + sampling.feature_acts_indices, + sampling.feature_acts_values, + ) + ): + try: + feature_acts = torch.sparse_coo_tensor( + torch.Tensor(feature_acts_indices), + torch.Tensor(feature_acts_values), + (1024, sampling.context_idx.shape[0]), + ) + feature_acts = feature_acts.to_dense() + + dataset = datasets(dataset_name, shard_idx, n_shards) + data = dataset[context_idx] + + # Process the sample using model's trace method + # lock.acquire() + origins = model.trace({k: [v] for k, v in data.items()})[0] + # lock.release() + + # Create TokenizedExample using the trace information + sample = TokenizedSample.construct( + text=data["text"], + activations=feature_acts[:max_length], + origins=origins[:max_length], + max_activation=analysis.max_feature_acts, + ) + + samples.append(sample) + + except Exception as e: + logger.error(f"{error_prefix} {e}") + continue + + if len(samples) >= n: + break + + return samples + + +class FeatureInterpreter: + """A class for generating and evaluating explanations for SAE features.""" + + def __init__(self, cfg: AutoInterpConfig, mongo_client: MongoClient): + """Initialize the feature interpreter. + + Args: + cfg: Configuration for interpreter + mongo_client: Optional MongoDB client for retrieving data + """ + self.cfg = cfg + self.mongo_client = mongo_client + # Set up LLM client for explanation generation + self._setup_llm_clients() + + def _setup_llm_clients(self): + """Set up async OpenAI client for explanation generation and evaluation.""" + try: + import httpx + from openai import AsyncOpenAI + + # Set up async HTTP client with proxy if needed + http_client = None + if self.cfg.openai_proxy: + http_client = httpx.AsyncClient( + proxy=self.cfg.openai_proxy, + transport=httpx.AsyncHTTPTransport(local_address="0.0.0.0"), + ) + + self.explainer_client = AsyncOpenAI( + base_url=self.cfg.openai_base_url, + api_key=self.cfg.openai_api_key, + http_client=http_client, + ) + except ImportError: + raise ImportError("OpenAI package not installed. 
Please install it with `uv add openai`.") + + def get_feature_examples( + self, + feature: FeatureRecord, + model: LanguageModel, + datasets: Callable[[str, int, int], Dataset], + analysis_name: str = "default", + max_length: int = 50, + ) -> tuple[list[TokenizedSample], list[TokenizedSample]]: + """Get activating and non-activating examples for a feature.""" + analysis = next((a for a in feature.analyses if a.name == analysis_name), None) + if not analysis: + raise ValueError(f"Analysis {analysis_name} not found for feature {feature.index}") + + if analysis.max_feature_acts == 0: + raise ValueError(f"Feature {feature.index} has no activation. Skipping interpretation.") + + # Get examples from each sampling + activating_examples = generate_activating_examples( + feature=feature, + model=model, + datasets=datasets, + analysis=analysis, + n=self.cfg.n_activating_examples, + max_length=max_length, + ) + non_activating_examples = generate_non_activating_examples( + feature=feature, + model=model, + datasets=datasets, + analysis=analysis, + n=self.cfg.n_non_activating_examples, + max_length=max_length, + ) + return activating_examples, non_activating_examples + + + async def generate_explanation(self, activating_examples: list[TokenizedSample], top_logits: dict[str, list[dict[str, Any]]] | None = None) -> dict[str, Any]: + """Generate an explanation for a feature based on activating examples. + + Args: + activating_examples: List of examples where the feature activates + top_positive_logits: Top positive logits for the feature + Returns: + Dictionary with explanation and metadata + """ + if self.cfg.explainer_type is ExplainerType.OPENAI: + system_prompt, user_prompt = generate_explanation_prompt(self.cfg, activating_examples) + else: + system_prompt, user_prompt = generate_explanation_prompt_neuronpedia(self.cfg, activating_examples, top_logits) + print(system_prompt, user_prompt) + start_time = time.time() + + if self.cfg.explainer_type is ExplainerType.OPENAI: + response = await self.explainer_client.chat.completions.create( + model=self.cfg.openai_model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + response_format={"type": "json_object"}, + ) + assert response.choices[0].message.content is not None, ( + f"No explanation returned from OpenAI\n\nsystem_prompt: {system_prompt}\n\nuser_prompt: {user_prompt}\n\nresponse: {response}" + ) + explanation = json_repair.loads(response.choices[0].message.content) + else: + response = await self.explainer_client.chat.completions.create( + model=self.cfg.openai_model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + stream=False, + ) + + explanation = { + "final_explanation": response.choices[0].message.content, + "activation_consistency": 5, + "complexity": 5, + } + response_time = time.time() - start_time + return { + "user_prompt": user_prompt, + "system_prompt": system_prompt, + "response": explanation, + "time": response_time, + } + + + async def evaluate_explanation_detection( + self, + explanation: dict[str, Any], + activating_examples: list[TokenizedSample], + non_activating_examples: list[TokenizedSample], + ) -> dict[str, Any]: + """Evaluate an explanation using the detection method. 
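The detection and fuzzing evaluators below expect the model to return a JSON object with steps and evaluation_results; a small sketch of how such a response is parsed (json_repair.loads also tolerates slightly malformed model output):

import json_repair

raw = '{"steps": ["Example 1 matches the explanation", "Example 2 does not"], "evaluation_results": [true, false]}'
parsed = json_repair.loads(raw)
predictions: list[bool] = parsed["evaluation_results"]
print(predictions)  # [True, False]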
+ + Args: + explanation: The explanation to evaluate + activating_examples: Examples where the feature activates + non_activating_examples: Examples where the feature doesn't activate + + Returns: + Dictionary with evaluation results + """ + # Select a subset of examples + n_activating = min(self.cfg.detection_n_examples, len(activating_examples)) + n_non_activating = min(self.cfg.detection_n_examples, len(non_activating_examples)) + + test_activating = random.sample(activating_examples, n_activating) if n_activating > 0 else [] + test_non_activating = random.sample(non_activating_examples, n_non_activating) if n_non_activating > 0 else [] + + # Mix and shuffle examples + all_examples = test_activating + test_non_activating + if len(all_examples) < self.cfg.detection_n_examples: + return { + "method": "detection", + "prompt": "", + "response": "", + "ground_truth": [], + "predictions": [], + "metrics": { + "accuracy": 0, + "precision": 0, + "recall": 0, + "f1": 0, + "balanced_accuracy": 0, + }, + "passed": False, + "time": 0, + } + + random.shuffle(all_examples) + + # Ground truth for each example (1 for activating, 0 for non-activating) + ground_truth = [1 if ex in test_activating else 0 for ex in all_examples] + + # Generate prompt + system_prompt, user_prompt = generate_detection_prompt(self.cfg, explanation, all_examples) + + # Get response from OpenAI + start_time = time.time() + response = await self.explainer_client.chat.completions.create( + model=self.cfg.openai_model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + response_format={"type": "json_object"}, + ) + assert response.choices[0].message.content is not None, ( + f"No detection response returned from OpenAI\n\nsystem_prompt: {system_prompt}\n\nuser_prompt: {user_prompt}\n\nresponse: {response}" + ) + detection_response: dict[str, Any] = json_repair.loads(response.choices[0].message.content) # type: ignore + predictions: list[bool] = detection_response["evaluation_results"] + response_time = time.time() - start_time + + # Pad predictions if needed + predictions = predictions[: len(ground_truth)] + if len(predictions) < len(ground_truth): + predictions.extend([False] * (len(ground_truth) - len(predictions))) + + # Calculate metrics + tp = sum(1 for gt, pred in zip(ground_truth, predictions) if gt == 1 and pred == 1) + tn = sum(1 for gt, pred in zip(ground_truth, predictions) if gt == 0 and pred == 0) + fp = sum(1 for gt, pred in zip(ground_truth, predictions) if gt == 0 and pred == 1) + fn = sum(1 for gt, pred in zip(ground_truth, predictions) if gt == 1 and pred == 0) + + accuracy = (tp + tn) / len(ground_truth) if ground_truth else 0 + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 + balanced_accuracy = ((tp / (tp + fn) if (tp + fn) > 0 else 0) + (tn / (tn + fp) if (tn + fp) > 0 else 0)) / 2 + + return { + "method": "detection", + "prompt": system_prompt + "\n\n" + user_prompt, + "response": detection_response, + "ground_truth": ground_truth, + "predictions": predictions, + "metrics": { + "accuracy": accuracy, + "precision": precision, + "recall": recall, + "f1": f1, + "balanced_accuracy": balanced_accuracy, + }, + "passed": balanced_accuracy >= 0.7, # Arbitrary threshold for passing + "time": response_time, + } + + + def _create_incorrectly_marked_example(self, sample: TokenizedSample) -> TokenizedSample: + """Create an 
incorrectly marked version of an example. + + Args: + sample: The original sample + + Returns: + A copy of the sample with incorrect highlighting + """ + # Count how many tokens would be highlighted in the correct example + threshold = self.cfg.activation_threshold + n_highlighted = sum(1 for seg in sample.segments if seg.activation > threshold * sample.max_activation) + + def highlight_random_tokens(sample: TokenizedSample, n_highlighted: int) -> TokenizedSample: + non_activating_indices = [ + i for i, seg in enumerate(sample.segments) if seg.activation < threshold * sample.max_activation + ] + highlight_indices = random.sample(non_activating_indices, min(n_highlighted, len(non_activating_indices))) + segments = [ + Segment(seg.text, sample.max_activation if i in highlight_indices else 0) + for i, seg in enumerate(sample.segments) + ] + return TokenizedSample(segments, sample.max_activation) + + n_to_highlight = max(3, n_highlighted) # Highlight at least 3 tokens + return highlight_random_tokens(sample, n_to_highlight) + + async def evaluate_explanation_fuzzing( + self, explanation: dict[str, Any], activating_examples: list[TokenizedSample] + ) -> dict[str, Any]: + """Evaluate an explanation using the fuzzing method. + + Args: + explanation: The explanation to evaluate + activating_examples: Examples where the feature activates + + Returns: + Dictionary with evaluation results + """ + if len(activating_examples) < self.cfg.fuzzing_n_examples: + # Not enough examples, return empty result + return { + "method": "fuzzing", + "prompt": "", + "response": "", + "ground_truth": [], + "predictions": [], + "metrics": { + "accuracy": 0, + "precision": 0, + "recall": 0, + "f1": 0, + "balanced_accuracy": 0, + }, + "passed": False, + "time": 0, + } + + # Prepare examples: + # - Correctly marked examples (original) + # - Incorrectly marked examples (with wrong parts highlighted) + n_correct = self.cfg.fuzzing_decile_correct + n_incorrect = self.cfg.fuzzing_decile_incorrect + + # Get a sample of activating examples + sample_examples = random.sample(activating_examples, min(n_correct + n_incorrect, len(activating_examples))) + + # Split into correct and incorrect + correct_examples = sample_examples[:n_correct] + incorrect_candidates = sample_examples[n_correct:] + + # Create incorrectly marked versions + incorrect_examples = [self._create_incorrectly_marked_example(ex) for ex in incorrect_candidates] + + # Combine and mark with ground truth + examples_with_labels = [(ex, True) for ex in correct_examples] + [(ex, False) for ex in incorrect_examples] + + # Shuffle + random.shuffle(examples_with_labels) + + # Generate prompt + system_prompt, user_prompt = generate_fuzzing_prompt(self.cfg, explanation, examples_with_labels) + + # Get response from OpenAI + start_time = time.time() + response = await self.explainer_client.chat.completions.create( + model=self.cfg.openai_model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + response_format={"type": "json_object"}, + ) + assert response.choices[0].message.content is not None, ( + f"No fuzzing response returned from OpenAI\n\nsystem_prompt: {system_prompt}\n\nuser_prompt: {user_prompt}\n\nresponse: {response}" + ) + fuzzing_response: dict[str, Any] = json_repair.loads(response.choices[0].message.content) # type: ignore + predictions: list[bool] = fuzzing_response["evaluation_results"] + response_time = time.time() - start_time + # Pad predictions if needed + predictions = predictions[: 
len(examples_with_labels)]
+        if len(predictions) < len(examples_with_labels):
+            predictions.extend([False] * (len(examples_with_labels) - len(predictions)))
+
+        # Extract ground truth
+        ground_truth = [is_correct for _, is_correct in examples_with_labels]
+
+        # Calculate metrics
+        tp = sum(1 for gt, pred in zip(ground_truth, predictions) if gt is True and pred is True)
+        tn = sum(1 for gt, pred in zip(ground_truth, predictions) if gt is False and pred is False)
+        fp = sum(1 for gt, pred in zip(ground_truth, predictions) if gt is False and pred is True)
+        fn = sum(1 for gt, pred in zip(ground_truth, predictions) if gt is True and pred is False)
+
+        accuracy = (tp + tn) / len(ground_truth) if ground_truth else 0
+        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
+        balanced_accuracy = ((tp / (tp + fn) if (tp + fn) > 0 else 0) + (tn / (tn + fp) if (tn + fp) > 0 else 0)) / 2
+
+        return {
+            "method": "fuzzing",
+            "prompt": user_prompt,
+            "response": fuzzing_response,
+            "ground_truth": ground_truth,
+            "predictions": predictions,
+            "metrics": {
+                "accuracy": accuracy,
+                "precision": precision,
+                "recall": recall,
+                "f1": f1,
+                "balanced_accuracy": balanced_accuracy,
+            },
+            "passed": balanced_accuracy >= 0.7,  # Arbitrary threshold for passing
+            "time": response_time,
+        }
+
+    async def interpret_single_feature(
+        self,
+        activating_examples: list[TokenizedSample],
+        non_activating_examples: list[TokenizedSample],
+        top_logits: dict[str, list[dict[str, Any]]] | None = None,
+    ) -> dict[str, Any]:
+        """Generate an explanation for a single feature and evaluate it with the configured scorers.
+
+        Args:
+            activating_examples: Examples where the feature activates
+            non_activating_examples: Examples where the feature doesn't activate
+            top_logits: Optional top logits of the feature, forwarded to explanation generation
+
+        Returns:
+            Dictionary with the explanation, its evaluation results, and timing information
+        """
+        start_time = time.time()
+        response_time = 0
+
+        # Generate explanation for the feature
+        explanation_result = await self.generate_explanation(activating_examples, top_logits)
+        explanation: dict[str, Any] = explanation_result["response"]
+        response_time += explanation_result["time"]
+        # Evaluate explanation
+        evaluation_results = []
+
+        if ScorerType.DETECTION in self.cfg.scorer_type:
+            detection_result = await self.evaluate_explanation_detection(
+                explanation, activating_examples, non_activating_examples
+            )
+            evaluation_results.append(detection_result)
+            response_time += detection_result["time"]
+
+        if ScorerType.FUZZING in self.cfg.scorer_type:
+            fuzzing_result = await self.evaluate_explanation_fuzzing(explanation, activating_examples)
+            evaluation_results.append(fuzzing_result)
+            response_time += fuzzing_result["time"]
+
+        total_time = time.time() - start_time
+
+        return {
+            "explanation": explanation["final_explanation"],
+            "complexity": explanation["complexity"],
+            "consistency": explanation["activation_consistency"],
+            "explanation_details": {
+                k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in explanation_result.items()
+            },
+            "evaluations": [
+                {k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in eval_result.items()}
+                for eval_result in evaluation_results
+            ],
+            "passed": any(eval_result["passed"] for eval_result in evaluation_results),
+            "time": {
+                "total": total_time,
+                "response": response_time,
+            },
+        }
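+    # Illustrative usage sketch (caller-side names are placeholders, not part of this module):
+    # `interpret_features` below fans `interpret_single_feature` out over many feature
+    # indices under a bounded-concurrency semaphore and yields results as they finish.
+    # A caller might consume it roughly like this, where `interpreter`, `language_model`,
+    # and `get_dataset` stand for an initialized FeatureInterpreter, a loaded
+    # LanguageModel, and a dataset-loading callable:
+    #
+    #     async def run_all():
+    #         async for result in interpreter.interpret_features(
+    #             sae_name="my-sae",
+    #             sae_series="default",
+    #             model=language_model,
+    #             datasets=get_dataset,
+    #             max_concurrent=10,
+    #         ):
+    #             print(result["feature_index"], result["explanation"])
+    #
+    #     # asyncio.run(run_all())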
+    async def interpret_features(
+        self,
+        sae_name: str,
+        sae_series: str,
+        model: LanguageModel,
+        datasets: Callable[[str, int, int], Dataset],
+        analysis_name: str = "default",
+        feature_indices: Optional[list[int]] = None,
+        max_concurrent: int = 10,
+        progress_callback: Optional[Callable[[int, int, int], None]] = None,
+    ) -> AsyncGenerator[dict[str, Any], None]:
+        """Generate and evaluate explanations for multiple features with async concurrency.
+
+        Args:
+            sae_name: Name of the SAE
+            sae_series: Series of the SAE
+            model: Language model to use for generating activations
+            datasets: Callable to fetch datasets
+            analysis_name: Name of the analysis to use
+            feature_indices: Optional list of specific feature indices to interpret. If None, interprets all features.
+            max_concurrent: Maximum number of concurrent API requests
+            progress_callback: Optional callback function(completed, total, current_feature_index) for progress updates
+
+        Yields:
+            Dictionary with interpretation results for each feature
+        """
+        if feature_indices is None:
+            sae_record = self.mongo_client.get_sae(sae_name, sae_series)
+            assert sae_record is not None, f"SAE {sae_name} {sae_series} not found"
+            feature_indices = list(range(sae_record.cfg.d_sae))
+
+        total_features = len(feature_indices)
+        completed = 0
+        skipped = 0
+        failed = 0
+
+        logger.info(f"Starting interpretation of {total_features} features (max concurrent: {max_concurrent})")
+
+        # Create semaphore to limit concurrent API requests
+        semaphore = asyncio.Semaphore(max_concurrent)
+
+        async def interpret_with_semaphore(feature_index: int) -> tuple[Optional[dict[str, Any]], int, bool, bool]:
+            """Interpret a single feature with semaphore control.
+
+            Returns:
+                Tuple of (result, feature_index, was_skipped, was_error)
+            """
+            async with semaphore:
+                try:
+                    # Fetch the feature record inside the try block so lookup errors are
+                    # counted as failures instead of aborting the whole run
+                    feature = self.mongo_client.get_feature(sae_name, sae_series, feature_index)
+                    # Only interpret features that exist, lack an interpretation (unless
+                    # overwriting is enabled), and have recorded activations
+                    if feature is not None and (
+                        self.cfg.overwrite_existing or feature.interpretation is None
+                    ) and feature.analyses[0].act_times > 0:
+                        activating_examples, non_activating_examples = self.get_feature_examples(
+                            feature=feature,
+                            model=model,
+                            datasets=datasets,
+                            analysis_name=analysis_name,
+                            max_length=self.cfg.max_length,
+                        )
+                        result = await self.interpret_single_feature(
+                            activating_examples=activating_examples,
+                            non_activating_examples=non_activating_examples,
+                            top_logits=feature.logits,
+                        )
+                        return (
+                            {
+                                "feature_index": feature.index,
+                                "sae_name": sae_name,
+                                "sae_series": sae_series,
+                            } | result,
+                            feature_index,
+                            False,
+                            False,
+                        )
+                    else:
+                        # Feature doesn't exist, already has an interpretation, or never activates
+                        return None, feature_index, True, False
+                except Exception as e:
+                    logger.error(f"Error interpreting feature {feature_index}:\n{e}\n{traceback.format_exc()}")
+                    return None, feature_index, False, True
+
+        # Process features concurrently
+        tasks = [interpret_with_semaphore(feature_index) for feature_index in feature_indices]
+
+        # Yield results as they complete
+        for coro in asyncio.as_completed(tasks):
+            result, feature_index, was_skipped, was_error = await coro
+
+            if was_skipped:
+                skipped += 1
+                logger.debug(f"Feature {feature_index} skipped (missing, already interpreted, or never activated)")
+            elif was_error:
+                failed += 1
+            elif result is not None:
+                completed += 1
+                logger.info(
+                    f"Completed feature {feature_index} ({completed}/{total_features} completed, "
+                    f"{skipped} skipped, {failed} failed)"
+                )
+                yield result
+            else:
+                skipped += 1
+
+            # Calculate total processed (completed + skipped + failed)
+            total_processed = completed + skipped + failed
+
+            # Call progress callback if provided
+            if progress_callback is not None:
+                progress_callback(total_processed, total_features, feature_index)
+
+        logger.info(
+            f"Interpretation complete: {completed} completed, {skipped} skipped, {failed} failed out of {total_features} total"
+        )
+
diff --git a/src/lm_saes/analysis/feature_interpreter.py b/src/lm_saes/analysis/feature_interpreter.py
deleted
file mode 100644 index bb066100..00000000 --- a/src/lm_saes/analysis/feature_interpreter.py +++ /dev/null @@ -1,1113 +0,0 @@ -"""Auto-interpretation functionality for SAE features. - -This module provides tools for automatically interpreting and evaluating sparse autoencoder features -based on the EleutherAI auto-interp approach (https://blog.eleuther.ai/autointerp/). - -It includes: -1. Methods for prompting LLMs to generate explanations for features -2. Methods for evaluating explanations via different techniques: - - Detection: Having LLMs identify if examples contain a feature - - Fuzzing: Having LLMs identify correctly marked activating tokens -""" - -import random -import time -from dataclasses import dataclass -from enum import Enum -from typing import Any, Callable, Generator, Literal, Optional - -import json_repair -import numpy as np -import torch -from datasets import Dataset -from pydantic import BaseModel, Field - -from lm_saes.backend.language_model import LanguageModel -from lm_saes.config import BaseConfig -from lm_saes.database import FeatureAnalysis, FeatureRecord, MongoClient -from lm_saes.utils.logging import get_logger - -logger = get_logger("analysis.feature_interpreter") - - -class ExplainerType(str, Enum): - """Types of LLM explainers supported.""" - - OPENAI = "openai" - NEURONPEDIA = "neuronpedia" - - -class ScorerType(str, Enum): - """Types of explanation scoring methods.""" - - DETECTION = "detection" - FUZZING = "fuzzing" - GENERATION = "generation" - SIMULATION = "simulation" - - -class Step(BaseModel): - """A step in the chain-of-thought process.""" - - thought: str - """The thought of the step.""" - - # output: str - # """The output of the step.""" - - -Step_Schema = { - "type": "object", - "properties": { - "thought": {"type": "string", "description": "The thought of the step."}, - }, -} - - -class AutoInterpExplanation(BaseModel): - """The result of an auto-interpretation of a SAE feature.""" - - steps: list[Step] - """The steps of the chain-of-thought process.""" - - final_explanation: str - """The explanation of the feature.""" - - activation_consistency: Literal[1, 2, 3, 4, 5] - """The consistency of the feature.""" - - complexity: Literal[1, 2, 3, 4, 5] - """The complexity of the feature.""" - - -AutoInterpExplanation_Schema = { - "type": "object", - "properties": { - "steps": {"type": "array", "items": Step_Schema}, - "final_explanation": { - "type": "string", - "description": "The explanation of the feature, in the form of 'This feature activates on... '", - }, - "activation_consistency": { - "type": "integer", - "description": "The consistency of the feature, on a scale of 1 to 5.", - }, - "complexity": {"type": "integer", "description": "The complexity of the feature, on a scale of 1 to 5."}, - }, -} - - -class AutoInterpEvaluation(BaseModel): - """The result of an auto-interpretation of a SAE feature.""" - - steps: list[Step] - """The steps of the chain-of-thought process.""" - - evaluation_results: list[bool] - """The evaluation results for each example. Should be a list of YES/NO values.""" - - -AutoInterpEvaluation_Schema = { - "type": "object", - "properties": { - "steps": {"type": "array", "items": Step_Schema}, - "evaluation_results": { - "type": "array", - "items": {"type": "boolean"}, - "description": "The evaluation results for each example. 
Should be a list of True/False values.", - }, - }, -} - - -class AutoInterpConfig(BaseConfig): - """Configuration for automatic interpretation of SAE features.""" - - # LLM settings - explainer_type: ExplainerType = ExplainerType.OPENAI - openai_api_key: Optional[str] = None - openai_model: str = "gpt-3.5-turbo" - openai_base_url: Optional[str] = None - openai_proxy: Optional[str] = None - - # Activation retrieval settings - n_activating_examples: int = 7 - n_non_activating_examples: int = 20 - activation_threshold: float = 0.7 # Threshold relative to max activation for highlighting tokens - max_length: int = 50 - - # Scoring settings - scorer_type: list[ScorerType] = Field(default_factory=lambda: [ScorerType.DETECTION, ScorerType.FUZZING]) - - # Detection settings - detection_n_examples: int = 5 # Number of examples to show for detection - - # Fuzzing settings - fuzzing_n_examples: int = 5 # Number of examples to use for fuzzing - fuzzing_decile_correct: int = 5 # Number of correctly marked examples per decile - fuzzing_decile_incorrect: int = 2 # Number of incorrectly marked examples per decile - - # Prompting settings - include_cot: bool = True # Whether to use chain-of-thought prompting - - -@dataclass -class Segment: - """A segment of text with its activation value.""" - - text: str - """The text of the segment.""" - - activation: float - """The activation value of the segment.""" - - def display(self, abs_threshold: float) -> str: - """Display the segment as a string with whether it's highlighted.""" - if self.activation > abs_threshold: - return f"<<{self.text}>>" - else: - return self.text - - def display_max(self, abs_threshold: float) -> str: - if self.activation > abs_threshold: - return f"{self.text}\n" - else: - return "" - - -@dataclass -class TokenizedSample: - """A tokenized sample with its activation pattern organized into segments.""" - - segments: list[Segment] - """List of segments, each containing start/end positions and activation values.""" - - max_activation: float - """Global maximum activation value.""" - - def display_highlighted(self, threshold: float = 0.7) -> str: - """Get the text with activating segments highlighted with << >> delimiters. 
- - Args: - threshold: Threshold relative to max activation for highlighting - - Returns: - Text with activating segments highlighted - """ - highlighted_text = "".join([seg.display(threshold * self.max_activation) for seg in self.segments]) - return highlighted_text - - def display_plain(self) -> str: - """Get the text with all segments displayed.""" - return "".join([seg.text for seg in self.segments]) - - def display_max(self, threshold: float = 0.7) -> str: - # max_activation_text = "".join([seg.display_max(threshold * self.max_activation) for seg in self.segments]) - max_activation_text = "" - hash_ = {} - for seg in self.segments: - if seg.activation > threshold * self.max_activation: - text = seg.text - if text != "" and hash_.get(text, None) is None: - hash_[text] = 1 - max_activation_text = text + "\n" - return max_activation_text - - def display_next(self, threshold: float = 0.7) -> str: - # max_activation_text = "".join([seg.display_max(threshold * self.max_activation) for seg in self.segments]) - next_activation_text = "" - hash_ = {} - Flag = False - for seg in self.segments: - if Flag: - text = seg.text - if text != "" and hash_.get(text, None) is None: - hash_[text] = 1 - next_activation_text = text + "\n" - if seg.activation > threshold * self.max_activation: - Flag = True - else: - Flag = False - return next_activation_text - - @staticmethod - def construct( - text: str, - activations: list[float], - origins: list[dict[str, Any]], - max_activation: float, - ) -> "TokenizedSample": - """Construct a TokenizedSample from text, activations, and origins.""" - positions: set[int] = set() - for origin in origins: - if origin and origin["key"] == "text": - assert "range" in origin, f"Origin {origin} does not have a range" - positions.add(origin["range"][0]) - positions.add(origin["range"][1]) - - sorted_positions = sorted(positions) - - segments = [] - for i in range(len(sorted_positions) - 1): - start, end = sorted_positions[i], sorted_positions[i + 1] - # try: - segment_activation = max( - act - for origin, act in zip(origins, activations) - if origin and origin["key"] == "text" and origin["range"][0] >= start and origin["range"][1] <= end - ) - # except Exception: - # logger.error(f"Error processing segment:\nstart={start}, end={end}, segment={text[start:end]}\n\n") - # continue - segments.append(Segment(text[start:end], segment_activation)) - - return TokenizedSample(segments, max_activation) - - -def generate_activating_examples( - feature: FeatureRecord, - model: LanguageModel, - datasets: Callable[[str, int, int], Dataset], - analysis: FeatureAnalysis, - n: int = 10, - max_length: int = 50, -) -> list[TokenizedSample]: - """Generate examples where a feature strongly activates using database records. 
- - Args: - feature: FeatureRecord to analyze - model: Language model to use - datasets: Callable to fetch datasets - analysis: FeatureAnalysis to use - n: Maximum number of examples to generate - max_length: Maximum length of examples to generate - - Returns: - List of TokenizedExample with high activation for the feature - """ - samples: list[TokenizedSample] = [] - error_prefix = f"Error processing activating examples of feature {feature.index}: " - - # Get examples from each sampling - sampling = analysis.samplings[0] - # print(f'{sampling.context_idx.shape=}') - # print(f'{sampling.feature_acts_values.shape=} {sampling.feature_acts_indices=}') - # feature_acts_ = torch.sparse_coo_tensor(torch.Tensor(sampling.feature_acts_indices), torch.Tensor(sampling.feature_acts_values), (1024, sampling.context_idx.shape[0])) - feature_acts_ = torch.sparse_coo_tensor( - torch.Tensor(sampling.feature_acts_indices), - torch.Tensor(sampling.feature_acts_values), - (int(np.max(sampling.feature_acts_indices[0])), 2048), - ) - feature_acts_ = feature_acts_.to_dense() - - for i, (dataset_name, shard_idx, n_shards, context_idx, feature_acts) in enumerate( - zip( - sampling.dataset_name, - sampling.shard_idx if sampling.shard_idx is not None else [0] * len(sampling.dataset_name), - sampling.n_shards if sampling.n_shards is not None else [1] * len(sampling.dataset_name), - sampling.context_idx, - feature_acts_, - ) - ): - try: - dataset = datasets(dataset_name, shard_idx, n_shards) - # context_idx = context_idx.astype(int) - data = dataset[int(context_idx)] - - # Process the sample using model's trace method - origins = model.trace({k: [v] for k, v in data.items()})[0] - - max_act_pos = torch.argmax(torch.tensor(feature_acts)).item() - # print(f'{max_act_pos=}') - # print(f'{feature_acts=}') - - left_end = max(0, max_act_pos - max_length // 2) - right_end = min(len(origins), max_act_pos + max_length // 2) - - # Create TokenizedExample using the trace information - sample = TokenizedSample.construct( - text=data["text"], - activations=feature_acts[left_end:right_end].tolist(), - origins=origins[left_end:right_end], - max_activation=analysis.max_feature_acts, - ) - # print('run activating') - samples.append(sample) - - except Exception as e: - logger.error(f"{error_prefix} {e}") - continue - - if len(samples) >= n: - break - - return samples - - -def generate_non_activating_examples( - feature: FeatureRecord, - model: LanguageModel, - datasets: Callable[[str, int, int], Dataset], - analysis: FeatureAnalysis, - n: int = 10, - max_length: int = 50, -) -> list[TokenizedSample]: - """Generate examples where a feature doesn't activate much. 
- - Args: - feature: FeatureRecord to analyze - model: Language model to use - datasets: Callable to fetch datasets - analysis: FeatureAnalysis to use - n: Maximum number of examples to generate - max_length: Maximum length of examples to generate - - Returns: - List of TokenizedExample with low activation for the feature - """ - - samples: list[TokenizedSample] = [] - error_prefix = f"Error processing non-activating examples of feature {feature.index}:" - - sampling_idx = -1 - for i in range(len(analysis.samplings)): - if analysis.samplings[i].name == "non_activating": - sampling_idx = i - break - if sampling_idx == -1: - return samples - sampling = analysis.samplings[sampling_idx] - # print(f'{len(analysis.samplings)=}') - # for sample in analysis.samplings: - # print(sample.name) - assert sampling.name == "non_activating", f"{error_prefix} Sampling {sampling.name} is not non_activating" - for i, (dataset_name, shard_idx, n_shards, context_idx, feature_acts_indices, feature_acts_values) in enumerate( - zip( - sampling.dataset_name, - sampling.shard_idx if sampling.shard_idx else [0] * len(sampling.dataset_name), - sampling.n_shards if sampling.n_shards else [1] * len(sampling.dataset_name), - sampling.context_idx, - # sampling.feature_acts, - sampling.feature_acts_indices, - sampling.feature_acts_values, - ) - ): - try: - feature_acts = torch.sparse_coo_tensor( - torch.Tensor(feature_acts_indices), - torch.Tensor(feature_acts_values), - (1024, sampling.context_idx.shape[0]), - ) - feature_acts = feature_acts.to_dense() - - dataset = datasets(dataset_name, shard_idx, n_shards) - data = dataset[context_idx] - - # Process the sample using model's trace method - # lock.acquire() - origins = model.trace({k: [v] for k, v in data.items()})[0] - # lock.release() - - # Create TokenizedExample using the trace information - sample = TokenizedSample.construct( - text=data["text"], - activations=feature_acts[:max_length].tolist(), - origins=origins[:max_length], - max_activation=analysis.max_feature_acts, - ) - - samples.append(sample) - - except Exception as e: - logger.error(f"{error_prefix} {e}") - continue - - if len(samples) >= n: - break - - return samples - - -class FeatureInterpreter: - """A class for generating and evaluating explanations for SAE features.""" - - def __init__(self, cfg: AutoInterpConfig, mongo_client: MongoClient): - """Initialize the feature interpreter. - - Args: - cfg: Configuration for interpreter - mongo_client: Optional MongoDB client for retrieving data - """ - self.cfg = cfg - self.mongo_client = mongo_client - self.logits = None - # Set up LLM client for explanation generation - self._setup_llm_clients() - - def _setup_llm_clients(self): - """Set up OpenAI client for explanation generation and evaluation.""" - try: - import httpx - import openai - from openai import DefaultHttpxClient - - self.explainer_client = openai.Client( - base_url=self.cfg.openai_base_url, - api_key=self.cfg.openai_api_key, - http_client=DefaultHttpxClient( - proxy=self.cfg.openai_proxy, - transport=httpx.HTTPTransport(local_address="0.0.0.0"), - ) - if self.cfg.openai_proxy - else None, - ) - except ImportError: - raise ImportError("OpenAI package not installed. 
Please install it with `uv add openai`.") - - def get_feature_examples( - self, - feature: FeatureRecord, - model: LanguageModel, - datasets: Callable[[str, int, int], Dataset], - analysis_name: str = "default", - max_length: int = 50, - ) -> tuple[list[TokenizedSample], list[TokenizedSample]]: - """Get activating and non-activating examples for a feature.""" - analysis = next((a for a in feature.analyses if a.name == analysis_name), None) - if not analysis: - raise ValueError(f"Analysis {analysis_name} not found for feature {feature.index}") - - if analysis.max_feature_acts == 0: - raise ValueError(f"Feature {feature.index} has no activation. Skipping interpretation.") - - # Get examples from each sampling - activating_examples = generate_activating_examples( - feature=feature, - model=model, - datasets=datasets, - analysis=analysis, - max_length=max_length, - ) - non_activating_examples = generate_non_activating_examples( - feature=feature, - model=model, - datasets=datasets, - analysis=analysis, - max_length=max_length, - ) - # self.logits = None - return activating_examples, non_activating_examples - - def _generate_explanation_prompt_neuronpedia(self, activating_examples: list[TokenizedSample]) -> tuple[str, str]: - """Generate a prompt for explanation generation with neuronpedia. - - Args: - activating_examples: List of activating examples - - Returns: - Prompt string for the LLM - """ - system_prompt = """You are explaining the behavior of a neuron in a neural network. Your response should be a very concise explanation (1-6 words) that captures what the neuron detects or predicts by finding patterns in lists.\n\n -To determine the explanation, you are given four lists:\n\n -- MAX_ACTIVATING_TOKENS, which are the top activating tokens in the top activating texts.\n -- TOKENS_AFTER_MAX_ACTIVATING_TOKEN, which are the tokens immediately after the max activating token.\n -- TOP_POSITIVE_LOGITS, which are the most likely words or tokens associated with this neuron.\n -- TOP_ACTIVATING_TEXTS, which are top activating texts.\n\n -You should look for a pattern by trying the following methods in order. Once you find a pattern, stop and return that pattern. Do not proceed to the later methods.\n -Method 1: Look at MAX_ACTIVATING_TOKENS. If they share something specific in common, or are all the same token or a variation of the same token (like different cases or conjugations), respond with that token.\n -Method 2: Look at TOKENS_AFTER_MAX_ACTIVATING_TOKEN. Try to find a specific pattern or similarity in all the tokens. A common pattern is that they all start with the same letter. If you find a pattern (like \'s word\', \'the ending -ing\', \'number 8\'), respond with \'say [the pattern]\'. You can ignore uppercase/lowercase differences for this.\n -Method 3: Look at TOP_POSITIVE_LOGITS for similarities and describe it very briefly (1-3 words).\n -Method 4: Look at TOP_ACTIVATING_TEXTS and make a best guess by describing the broad theme or context, ignoring the max activating tokens.\n\n -Rules:\n -- Keep your explanation extremely concise (1-6 words, mostly 1-3 words).\n -- Do not add unnecessary phrases like "words related to", "concepts related to", or "variations of the word".\n -- Do not mention "tokens" or "patterns" in your explanation.\n -- The explanation should be specific. 
For example, "unique words" is not a specific enough pattern, nor is "foreign words".\n -- Remember to use the \'say [the pattern]\' when using Method 2 above (pattern found in TOKENS_AFTER_MAX_ACTIVATING_TOKEN).\n -- If you absolutely cannot make any guesses, return the first token in MAX_ACTIVATING_TOKENS.\n\n -Respond by going through each method number until you find one that helps you find an explanation for what this neuron is detecting or predicting. If a method does not help you find an explanation, briefly explain why it does not, then go on to the next method. Finally, end your response with the method number you used, the reason for your explanation, and then the explanation.\n - -Exsample: -{ -\n\nwas\nwatching\n\n\n\n\n\n\nShe\nenjoy\n\n\n\n\n\n\nwalking\nWA\nwaiting\nwas\nwe\nWHAM\nwish\nwin\nwake\nwhisper\n\n\n\n\n\n\nShe was taking a nap when her phone started ringing.\nI enjoy watching movies with my family.\n\n\n\n\nExplanation of neuron behavior: \n -Method 1 fails: MAX_ACTIVATING_TOKENS (She, enjoy) are not similar tokens.\nMethod 2 succeeds: All TOKENS_AFTER_MAX_ACTIVATING_TOKEN have a pattern in common: they all start with "w".\nExplanation: say "w" words -} - -{ -\n\nwarm\nthe\n\n\n\n\n\n\nand\nAnd\n\n\n\n\n\n\nelephant\nguitar\nmountain\nbicycle\nocean\ntelescope\ncandle\numbrella\ntornado\nbutterfly\n\n\n\n\n\n\nIt was a beautiful day outside with clear skies and warm sunshine.\nAnd the garden has roses and tulips and daisies and sunflowers blooming together.\n\n\n\n\nExplanation of neuron behavior: \n -Method 1 succeeds: All MAX_ACTIVATING_TOKENS are the word "and".\nExplanation: and -} - -{ -\n\nare\n,\n\n\n\n\n\n\nbanana\nblueberries\n\n\n\n\n\n\napple\norange\npineapple\nwatermelon\nkiwi\npeach\npear\ngrape\ncherry\nplum\n\n\n\n\n\n\nThe apple and banana are delicious foods that provide essential vitamins and nutrients.\nI enjoy eating fresh strawberries, blueberries, and mangoes during the summer months.\n\n\n\n\nExplanation of neuron behavior: \n -Method 1 succeeds: All MAX_ACTIVATING_TOKENS (banana, blueberries) are fruits.\nExplanation: fruits\n -} - -{ -\n\nwas\nplaces\n\n\n\n\n\n\nwar\nsome\n\n\n\n\n\n\n4\nfour\nfourth\n4th\nIV\nFour\nFOUR\n~4\n4.0\nquartet\n\n\n\n\n\n\nthe civil war was a major topic in history class .\n seasons of the year are winter , spring , summer , and fall or autumn in some places .\n\n\n\n\nExplanation of neuron behavior: \n -Method 1 fails: MAX_ACTIVATING_TOKENS (war, some) are not all the same token.\nMethod 2 fails: TOKENS_AFTER_MAX_ACTIVATING_TOKEN (was, places) are not all similar tokens and don't have a text pattern in common.\nMethod 3 succeeds: All TOP_POSITIVE_LOGITS are the number 4.\nExplanation: 4\n -} -""" - examples_to_show = activating_examples[: self.cfg.n_activating_examples] - next_activating_tokens = "" - max_activating_tokens = "" - plain_activating_tokens = "" - logit_activating_tokens = "" - - for i, example in enumerate(examples_to_show, 1): - next_activating_tokens = next_activating_tokens + example.display_next(self.cfg.activation_threshold) - max_activating_tokens = max_activating_tokens + example.display_max(self.cfg.activation_threshold) - plain_activating_tokens = plain_activating_tokens + example.display_plain() + "\n" - - if self.logits is not None: - for text in self.logits["top_positive"]: - logit_activating_tokens = logit_activating_tokens + text["token"] + "\n" - else: - logit_activating_tokens = next_activating_tokens - - user_prompt: str = f""" 
-\n\n{next_activating_tokens}\n\n\n\n\n\n{max_activating_tokens}\n\n\n\n\n\n{logit_activating_tokens}\n<\\TOP_POSITIVE_LOGITS>\n\n\n\n\n{plain_activating_tokens}\n<\\TOP_ACTIVATING_TEXTS>\n\n\nExplanation of neuron behavior: \n -""" - return system_prompt, user_prompt - - def _generate_explanation_prompt(self, activating_examples: list[TokenizedSample]) -> tuple[str, str]: - """Generate a prompt for explanation generation. - - Args: - activating_examples: List of activating examples - - Returns: - Prompt string for the LLM - """ - cot_prompt = "" - if self.cfg.include_cot: - cot_prompt += "\n\nTo explain this feature, please follow these steps:\n" - cot_prompt += "Step 1: List a couple activating and contextual tokens you find interesting. " - cot_prompt += "Search for patterns in these tokens, if there are any. Don't list more than 5 tokens.\n" - cot_prompt += "Step 2: Write down general shared features of the text examples.\n" - cot_prompt += "Step 3: Write a concise explanation of what this feature detects.\n" - - examples_prompt = """Some examples: - -{ - "steps": ["Activating token: <>. Contextual tokens: Who, ?. Pattern: <> is consistently activated, often found in sentences starting with interrogative words like 'Who' and ending with a question mark.", "Shared features include consistent activation on the word 'knows'. The surrounding text always forms a question. The questions do not seem to expect a literal answer, suggesting they are rhetorical.", "This feature activates on the word knows in rhetorical questions"], - "final_explanation": "The feature activates on the word 'knows' in rhetorical questions.", - "activation_consistency": 5, - "complexity": 4 -} - -{ - "steps": ["Activating tokens: <>, <>, <>, <>, <>. Pattern: All activating instances are on words that begin with the specific substring 'Ent'. The activation is on the 'Ent' portion itself.", "The shared feature across all examples is the presence of words starting with the capitalized substring 'Ent'. The feature appears to be case-sensitive and position-specific (start of the word). No other contextual or semantic patterns are observed."], - "final_explanation": "The feature activates on the substring 'Ent' at the start of words", - "activation_consistency": 5, - "complexity": 1 -} - -{ - "steps": ["Activating tokens: <>, <>, <>, <>, <>. Pattern: Activations highlight phrases and concepts central to economic discussions and government actions.","The examples consistently involve discussions of economic indicators, government spending, financial regulation, or international trade agreements. While most activations clearly relate to economic policies enacted or debated by governmental bodies, some activations might be on broader economic news or expert commentary where the direct link to a specific government policy is less explicit, or on related but not identical topics like corporate financial health in response to policy."], - "final_explanation": "The feature activates on text about government economic policy", - "activation_consistency": 3, - "complexity": 5 -} - -""" - system_prompt: str = f"""We're studying features in a neural network. Each feature activates on some particular word/words/substring/concept in a short document. The activating words in each document are indicated with << ... >>. We will give you a list of documents on which the feature activates, in order from most strongly activating to least strongly activating. 
- -Your task is to: - -First, Summarize the Activation: Look at the parts of the document the feature activates for and summarize in a single sentence what the feature is activating on. Try not to be overly specific in your explanation. Note that some features will activate only on specific words or substrings, but others will activate on most/all words in a sentence provided that sentence contains some particular concept. Your explanation should cover most or all activating words (for example, don't give an explanation which is specific to a single word if all words in a sentence cause the feature to activate). Pay attention to things like the capitalization and punctuation of the activating words or concepts, if that seems relevant. Keep the explanation as short and simple as possible, limited to 20 words or less. Omit punctuation and formatting. You should avoid giving long lists of words.{cot_prompt} - -Second, Assess Activation Consistency: Based on your summary and the provided examples, evaluate the consistency of the feature's activation. Return your assessment as a single integer from the following scale: - -5: Clear pattern with no deviating examples -4: Clear pattern with one or two deviating examples -3: Clear overall pattern but quite a few examples not fitting that pattern -2: Broad consistent theme but lacking structure -1: No discernible pattern - -Third, Assess Feature Complexity: Based on your summary and the nature of the activation, evaluate the complexity of the feature. Return your assessment as a single integer from the following scale: - -5: Rich feature firing on diverse contexts with an interesting unifying theme, e.g., "feelings of togetherness" -4: Feature relating to high-level semantic structure, e.g., "return statements in code" -3: Moderate complexity, such as a phrase, category, or tracking sentence structure, e.g., "website URLs" -2: Single word or token feature but including multiple languages or spelling, e.g., "mentions of dog" -1: Single token feature, e.g., "the token '('" - -Your output should be a JSON object that has the following fields: `steps`, `final_explanation`, `activation_consistency`, `complexity`. `steps` should be an array of strings with a length not exceeding 3, each representing a step in the chain-of-thought process. `final_explanation` should be a string in the form of 'This feature activates on... '. `activation_consistency` should be an integer between 1 and 5, representing the consistency of the feature. `complexity` should be an integer between 1 and 5, representing the complexity of the feature. - -{examples_prompt} -""" - - user_prompt = "The activating documents are given below:\n\n" - # Select a subset of examples to show - examples_to_show = activating_examples[: self.cfg.n_activating_examples] - - for i, example in enumerate(examples_to_show, 1): - highlighted = example.display_highlighted(self.cfg.activation_threshold) - user_prompt += f"Example {i}: {highlighted}\n\n" - - return system_prompt, user_prompt - - def generate_explanation(self, activating_examples: list[TokenizedSample]) -> dict[str, Any]: - """Generate an explanation for a feature based on activating examples. 
- - Args: - activating_examples: List of examples where the feature activates - - Returns: - Dictionary with explanation and metadata - """ - if self.cfg.explainer_type is ExplainerType.OPENAI: - system_prompt, user_prompt = self._generate_explanation_prompt(activating_examples) - else: - system_prompt, user_prompt = self._generate_explanation_prompt_neuronpedia(activating_examples) - start_time = time.time() - # print(f'{system_prompt=}') - print(f"{user_prompt=}") - - if self.cfg.explainer_type is ExplainerType.OPENAI: - response = self.explainer_client.chat.completions.create( - model=self.cfg.openai_model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - response_format={"type": "json_object"}, - ) - - assert response.choices[0].message.content is not None, ( - f"No explanation returned from OpenAI\n\nsystem_prompt: {system_prompt}\n\nuser_prompt: {user_prompt}\n\nresponse: {response}" - ) - explanation = json_repair.loads(response.choices[0].message.content) - else: - response = self.explainer_client.chat.completions.create( - model=self.cfg.openai_model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - ) - - def extract_explanation(s: str | None): - if s is None: - return None - keyword = "Explanation: " - start_index = s.find(keyword) - if start_index == -1: - return None - else: - return s[start_index + len(keyword) :] - - explanation = { - "final_explanation": extract_explanation(response.choices[0].message.content), - "activation_consistency": 5, - "complexity": 5, - } - response_time = time.time() - start_time - return { - "user_prompt": user_prompt, - "system_prompt": system_prompt, - "response": explanation, - "time": response_time, - } - - def _generate_detection_prompt( - self, explanation: dict[str, Any], examples: list[TokenizedSample] - ) -> tuple[str, str]: - """Generate a prompt for detection evaluation. - - Args: - explanation: The explanation to evaluate - examples: List of examples (mix of activating and non-activating) - - Returns: - Prompt string for the LLM - """ - system_prompt = f"""We're studying features in a neural network. Each feature activates on some particular word/words/substring/concept in a short document. You will be given a short explanation of what this feature activates for, and then be shown {len(examples)} example sequences in random order. You will have to return a boolean list of the examples where you think the feature should activate at least once, on ANY of the words or substrings in the document, true if it does, false if it doesn't. Try not to be overly specific in your interpretation of the explanation.""" - system_prompt += """ -Your output should be a JSON object that has the following fields: `steps`, `evaluation_results`. `steps` should be an array of strings, each representing a step in the chain-of-thought process within 50 words. `evaluation_results` should be an array of booleans, each representing whether the feature should activate on the corresponding example. 
-""" - user_prompt = f"Here is the explanation:\n\n{explanation['final_explanation']}\n\nHere are the examples:\n\n" - - for i, example in enumerate(examples, 1): - user_prompt += f"Example {i}: {example.display_plain()}\n" - - return system_prompt, user_prompt - - def evaluate_explanation_detection( - self, - explanation: dict[str, Any], - activating_examples: list[TokenizedSample], - non_activating_examples: list[TokenizedSample], - ) -> dict[str, Any]: - """Evaluate an explanation using the detection method. - - Args: - explanation: The explanation to evaluate - activating_examples: Examples where the feature activates - non_activating_examples: Examples where the feature doesn't activate - - Returns: - Dictionary with evaluation results - """ - # Select a subset of examples - n_activating = min(self.cfg.detection_n_examples, len(activating_examples)) - n_non_activating = min(self.cfg.detection_n_examples, len(non_activating_examples)) - - test_activating = random.sample(activating_examples, n_activating) if n_activating > 0 else [] - test_non_activating = random.sample(non_activating_examples, n_non_activating) if n_non_activating > 0 else [] - - # Mix and shuffle examples - all_examples = test_activating + test_non_activating - if len(all_examples) < self.cfg.detection_n_examples: - return { - "method": "detection", - "prompt": "", - "response": "", - "ground_truth": [], - "predictions": [], - "metrics": { - "accuracy": 0, - "precision": 0, - "recall": 0, - "f1": 0, - "balanced_accuracy": 0, - }, - "passed": False, - "time": 0, - } - - random.shuffle(all_examples) - - # Ground truth for each example (1 for activating, 0 for non-activating) - ground_truth = [1 if ex in test_activating else 0 for ex in all_examples] - - # Generate prompt - system_prompt, user_prompt = self._generate_detection_prompt(explanation, all_examples) - - # Get response from OpenAI - start_time = time.time() - response = self.explainer_client.chat.completions.create( - model=self.cfg.openai_model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - response_format={"type": "json_object"}, - ) - assert response.choices[0].message.content is not None, ( - f"No detection response returned from OpenAI\n\nsystem_prompt: {system_prompt}\n\nuser_prompt: {user_prompt}\n\nresponse: {response}" - ) - detection_response: dict[str, Any] = json_repair.loads(response.choices[0].message.content) # type: ignore - # print(f"Detection for feature :\n{detection_response}\n\n") - predictions: list[bool] = detection_response["evaluation_results"] - response_time = time.time() - start_time - - # Pad predictions if needed - predictions = predictions[: len(ground_truth)] - if len(predictions) < len(ground_truth): - predictions.extend([False] * (len(ground_truth) - len(predictions))) - - # Calculate metrics - tp = sum(1 for gt, pred in zip(ground_truth, predictions) if gt == 1 and pred == 1) - tn = sum(1 for gt, pred in zip(ground_truth, predictions) if gt == 0 and pred == 0) - fp = sum(1 for gt, pred in zip(ground_truth, predictions) if gt == 0 and pred == 1) - fn = sum(1 for gt, pred in zip(ground_truth, predictions) if gt == 1 and pred == 0) - - accuracy = (tp + tn) / len(ground_truth) if ground_truth else 0 - precision = tp / (tp + fp) if (tp + fp) > 0 else 0 - recall = tp / (tp + fn) if (tp + fn) > 0 else 0 - f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 - balanced_accuracy = ((tp / (tp + fn) if (tp + fn) > 0 else 0) + (tn / (tn + 
fp) if (tn + fp) > 0 else 0)) / 2 - - return { - "method": "detection", - "prompt": system_prompt + "\n\n" + user_prompt, - "response": detection_response, - "ground_truth": ground_truth, - "predictions": predictions, - "metrics": { - "accuracy": accuracy, - "precision": precision, - "recall": recall, - "f1": f1, - "balanced_accuracy": balanced_accuracy, - }, - "passed": balanced_accuracy >= 0.7, # Arbitrary threshold for passing - "time": response_time, - } - - def _generate_fuzzing_prompt( - self, - explanation: dict[str, Any], - examples: list[tuple[TokenizedSample, bool]], # (sample, is_correctly_marked) - ) -> tuple[str, str]: - """Generate a prompt for fuzzing evaluation. - - Args: - explanation: The explanation to evaluate - examples: List of tuples (example, is_correctly_marked) - - Returns: - Prompt string for the LLM - """ - system_prompt = f"""We're studying features in a neural network. Each feature activates on some particular word/words/substring/concept in a short document. You will be given a short explanation of what this feature activates for, and then be shown {len(examples)} example sequences in random order. In each example, text segments highlighted with << >> are presented as activating the feature as described in the explanation. You will have to return a boolean list of the examples where you think the highlighted parts CORRECTLY correspond to the explanation, true if they do, false if they don't. Try not to be overly specific in your interpretation of the explanation.""" - system_prompt += """ -Your output should be a JSON object that has the following fields: `steps`, `evaluation_results`. `steps` should be an array of strings, each representing a step in the chain-of-thought process within 50 words. `evaluation_results` should be an array of booleans, each representing whether the feature should activate on the corresponding example. -""" - user_prompt = f"Here is the explanation:\n\n{explanation['final_explanation']}\n\nHere are the examples:\n\n" - - for i, (example, _) in enumerate(examples, 1): - highlighted = example.display_highlighted(self.cfg.activation_threshold) - user_prompt += f"Example {i}: {highlighted}\n" - - return system_prompt, user_prompt - - def _create_incorrectly_marked_example(self, sample: TokenizedSample) -> TokenizedSample: - """Create an incorrectly marked version of an example. 
- - Args: - sample: The original sample - - Returns: - A copy of the sample with incorrect highlighting - """ - # Count how many tokens would be highlighted in the correct example - threshold = self.cfg.activation_threshold - n_highlighted = sum(1 for seg in sample.segments if seg.activation > threshold * sample.max_activation) - - def highlight_random_tokens(sample: TokenizedSample, n_highlighted: int) -> TokenizedSample: - non_activating_indices = [ - i for i, seg in enumerate(sample.segments) if seg.activation < threshold * sample.max_activation - ] - highlight_indices = random.sample(non_activating_indices, min(n_highlighted, len(non_activating_indices))) - segments = [ - Segment(seg.text, sample.max_activation if i in highlight_indices else 0) - for i, seg in enumerate(sample.segments) - ] - return TokenizedSample(segments, sample.max_activation) - - n_to_highlight = max(3, n_highlighted) # Highlight at least 3 tokens - return highlight_random_tokens(sample, n_to_highlight) - - def evaluate_explanation_fuzzing( - self, explanation: dict[str, Any], activating_examples: list[TokenizedSample] - ) -> dict[str, Any]: - """Evaluate an explanation using the fuzzing method. - - Args: - explanation: The explanation to evaluate - activating_examples: Examples where the feature activates - - Returns: - Dictionary with evaluation results - """ - if len(activating_examples) < self.cfg.fuzzing_n_examples: - # Not enough examples, return empty result - return { - "method": "fuzzing", - "prompt": "", - "response": "", - "ground_truth": [], - "predictions": [], - "metrics": { - "accuracy": 0, - "precision": 0, - "recall": 0, - "f1": 0, - "balanced_accuracy": 0, - }, - "passed": False, - "time": 0, - } - - # Prepare examples: - # - Correctly marked examples (original) - # - Incorrectly marked examples (with wrong parts highlighted) - n_correct = self.cfg.fuzzing_decile_correct - n_incorrect = self.cfg.fuzzing_decile_incorrect - - # Get a sample of activating examples - sample_examples = random.sample(activating_examples, min(n_correct + n_incorrect, len(activating_examples))) - - # Split into correct and incorrect - correct_examples = sample_examples[:n_correct] - incorrect_candidates = sample_examples[n_correct:] - - # Create incorrectly marked versions - incorrect_examples = [self._create_incorrectly_marked_example(ex) for ex in incorrect_candidates] - - # Combine and mark with ground truth - examples_with_labels = [(ex, True) for ex in correct_examples] + [(ex, False) for ex in incorrect_examples] - - # Shuffle - random.shuffle(examples_with_labels) - - # Generate prompt - system_prompt, user_prompt = self._generate_fuzzing_prompt(explanation, examples_with_labels) - - # Get response from OpenAI - start_time = time.time() - response = self.explainer_client.chat.completions.create( - model=self.cfg.openai_model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - response_format={"type": "json_object"}, - ) - assert response.choices[0].message.content is not None, ( - f"No fuzzing response returned from OpenAI\n\nsystem_prompt: {system_prompt}\n\nuser_prompt: {user_prompt}\n\nresponse: {response}" - ) - fuzzing_response: dict[str, Any] = json_repair.loads(response.choices[0].message.content) # type: ignore - # print(f"Fuzzing for feature :\n{fuzzing_response}\n\n") - # Parse response (CORRECT/INCORRECT for each example) - predictions: list[bool] = fuzzing_response["evaluation_results"] - response_time = time.time() - start_time - # Pad 
predictions if needed
-        predictions = predictions[: len(examples_with_labels)]
-        if len(predictions) < len(examples_with_labels):
-            predictions.extend([False] * (len(examples_with_labels) - len(predictions)))
-
-        # Extract ground truth
-        ground_truth = [is_correct for _, is_correct in examples_with_labels]
-
-        # Calculate metrics
-        tp = sum(1 for gt, pred in zip(ground_truth, predictions) if gt is True and pred is True)
-        tn = sum(1 for gt, pred in zip(ground_truth, predictions) if gt is False and pred is False)
-        fp = sum(1 for gt, pred in zip(ground_truth, predictions) if gt is False and pred is True)
-        fn = sum(1 for gt, pred in zip(ground_truth, predictions) if gt is True and pred is False)
-
-        accuracy = (tp + tn) / len(ground_truth) if ground_truth else 0
-        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
-        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
-        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
-        balanced_accuracy = ((tp / (tp + fn) if (tp + fn) > 0 else 0) + (tn / (tn + fp) if (tn + fp) > 0 else 0)) / 2
-
-        return {
-            "method": "fuzzing",
-            "prompt": user_prompt,
-            "response": fuzzing_response,
-            "ground_truth": ground_truth,
-            "predictions": predictions,
-            "metrics": {
-                "accuracy": accuracy,
-                "precision": precision,
-                "recall": recall,
-                "f1": f1,
-                "balanced_accuracy": balanced_accuracy,
-            },
-            "passed": balanced_accuracy >= 0.7,  # Arbitrary threshold for passing
-            "time": response_time,
-        }
-
-    def interpret_single_feature(
-        self,
-        feature: FeatureRecord,
-        model: LanguageModel,
-        datasets: Callable[[str, int, int], Dataset],
-        analysis_name: str = "default",
-    ) -> dict[str, Any]:
-        """Generate and evaluate explanations for multiple features.
-
-        Args:
-            feature: Feature to interpret
-            model: Language model to use for generating activations
-            datasets: Dataset to sample non-activating examples from
-            analysis_name: Name of the analysis to use
-
-        Returns:
-            Dictionary mapping feature indices to their interpretation results
-        """
-
-        start_time = time.time()
-        response_time = 0
-
-        # if self.cfg.explainer_type is ExplainerType.NEURONPEDIA:
-        self.logits = feature.logits
-
-        activating_examples, non_activating_examples = self.get_feature_examples(
-            feature=feature,
-            model=model,
-            datasets=datasets,
-            analysis_name=analysis_name,
-            max_length=self.cfg.max_length,
-        )
-
-        # print(f'{len(activating_examples)=} {len(non_activating_examples)=}')
-
-        # Generate explanation for the feature
-        explanation_result = self.generate_explanation(activating_examples)
-        explanation: dict[str, Any] = explanation_result["response"]
-        response_time += explanation_result["time"]
-        # print(f"Explanation for feature {feature.index}:\n{explanation}\n\n")
-        # Evaluate explanation
-        evaluation_results = []
-
-        if ScorerType.DETECTION in self.cfg.scorer_type:
-            detection_result = self.evaluate_explanation_detection(
-                explanation, activating_examples, non_activating_examples
-            )
-            # print(f"Detection result for feature {feature.index}:\n{detection_result}\n\n")
-            evaluation_results.append(detection_result)
-            # print(detection_result)
-            response_time += detection_result["time"]
-
-        if ScorerType.FUZZING in self.cfg.scorer_type:
-            fuzzing_result = self.evaluate_explanation_fuzzing(explanation, activating_examples)
-            # print(f"Fuzzing result for feature {feature.index}:\n{fuzzing_result}\n\n")
-            evaluation_results.append(fuzzing_result)
-            response_time += fuzzing_result["time"]
-
-        total_time = time.time() - start_time
-
-        return {
-            "analysis_name": analysis_name,
-            "explanation": explanation["final_explanation"],
-            "complexity": explanation["complexity"],
-            "consistency": explanation["activation_consistency"],
-            "explanation_details": {
-                k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in explanation_result.items()
-            },
-            "evaluations": [
-                {k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in eval_result.items()}
-                for eval_result in evaluation_results
-            ],
-            "passed": any(eval_result["passed"] for eval_result in evaluation_results),
-            "time": {
-                "total": total_time,
-                "response": response_time,
-            },
-        }
-
-    def interpret_features(
-        self,
-        sae_name: str,
-        sae_series: str,
-        feature_indices: list[int],
-        model: LanguageModel,
-        datasets: Callable[[str, int, int], Dataset],
-        analysis_name: str = "default",
-    ) -> Generator[dict[str, Any], None, None]:
-        """Generate and evaluate explanations for multiple features.
-
-        Args:
-            sae_name: Name of the SAE
-            sae_series: Series of the SAE
-            feature_indices: Indices of the features to interpret
-            model: Language model to use for generating activations
-            datasets: Dataset to sample non-activating examples from
-            analysis_name: Name of the analysis to use
-
-        Returns:
-            Dictionary mapping feature indices to their interpretation results
-        """
-
-        for feature_index in feature_indices:
-            feature = self.mongo_client.get_feature(sae_name, sae_series, feature_index)
-            if feature is not None and feature.interpretation is None:
-                yield {
-                    "feature_index": feature.index,
-                    "sae_name": sae_name,
-                    "sae_series": sae_series,
-                } | self.interpret_single_feature(feature, model, datasets, analysis_name)
diff --git a/src/lm_saes/circuit/autointerp4graph.py b/src/lm_saes/circuit/autointerp4graph.py
deleted file mode 100644
index 46f121ea..00000000
--- a/src/lm_saes/circuit/autointerp4graph.py
+++ /dev/null
@@ -1,179 +0,0 @@
-import json
-from functools import lru_cache
-
-from datasets import Dataset
-from pydantic_settings import BaseSettings
-
-from lm_saes.analysis.feature_interpreter import AutoInterpConfig, FeatureInterpreter
-from lm_saes.config import LanguageModelConfig, MongoDBConfig
-from lm_saes.database import MongoClient
-from lm_saes.resource_loaders import load_dataset_shard, load_model
-from lm_saes.utils.logging import get_logger, setup_logging
-
-logger = get_logger("runners.autointerp4graph")
-
-
-class AutoInterp4GraphSettings(BaseSettings):
-    """Settings for automatic interpretation of SAE features."""
-
-    graph_path: str
-    """The json file path of graph to demonstrate."""
-
-    # sae_name: str
-    # """Name of the SAE model to interpret. Use as identifier for the SAE model in the database."""
-
-    sae_series: str
-    """Series of the SAE model to interpret. Use as identifier for the SAE model in the database."""
-
-    model: LanguageModelConfig
-    """Configuration for the language model used to generate activations."""
-
-    model_name: str
-    """Name of the model to load."""
-
-    auto_interp: AutoInterpConfig
-    """Configuration for the auto-interpretation process."""
-
-    mongo: MongoDBConfig
-    """Configuration for the MongoDB database."""
-
-    analysis_name: str = "default"
-    """Name of the analysis to use for interpretation."""
-
-    max_workers: int = 10
-    """Maximum number of workers to use for interpretation."""
-
-    cover: bool = False
-    """Whether cover the generated interp"""
-
-
-def get_feature(graph_path):
-    """
-    Process node data from JSON file, and organize into specified list based on feature_type
-
-    Args:
-        graph_path (str): Path to the JSON file
-
-    Returns:
-        list: List containing dictionaries, each with format {"sae_name": str, "feature_idx": int}
-    """
-    try:
-        # Read JSON file
-        with open(graph_path, "r", encoding="utf-8") as f:
-            data = json.load(f)
-    except FileNotFoundError:
-        print(f"Error: File {graph_path} does not exist")
-        return []
-    except json.JSONDecodeError:
-        print(f"Error: File {graph_path} is not a valid JSON format")
-        return []
-
-    # Extract metadata and nodes from data
-    metadata = data.get("metadata", {})
-    nodes = data.get("nodes", [])
-
-    # Get analysis names from metadata
-    lorsa_analysis_name = metadata.get("lorsa_analysis_name", "L{}Lorsa")
-    clt_analysis_name = metadata.get("clt_analysis_name", "L{}CLT-k1024")
-
-    result_list = []
-
-    for node in nodes:
-        node_id = node.get("node_id")
-        feature_type = node.get("feature_type")
-
-        # Skip nodes missing required fields
-        if not node_id or not feature_type:
-            continue
-
-        # Split node_id into components
-        parts = node_id.split("_")
-        if len(parts) < 3:
-            continue  # Skip if node_id format is invalid
-
-        try:
-            # Extract layer_id and feature_id from node_id parts
-            layer_id = int(parts[0])
-            feature_id = int(parts[1])
-        except ValueError:
-            continue  # Skip if conversion to integer fails
-
-        # Process based on feature_type
-        if feature_type == "lorsa":
-            # For lorsa: layer_id = layer_id // 2, use clt_analysis_name
-            new_layer_id = layer_id // 2
-            sae_name = lorsa_analysis_name.format(new_layer_id)
-        elif feature_type == "cross layer transcoder":
-            # For cross layer transcoder: layer_id = (layer_id - 1) // 2, use lorsa_analysis_name
-            new_layer_id = (layer_id - 1) // 2
-            sae_name = clt_analysis_name.format(new_layer_id)
-        else:
-            continue  # Skip other feature types
-
-        # Add processed result to the list
-        result_list.append({"sae_name": sae_name, "feature_idx": feature_id})
-
-    return result_list
-
-
-def auto_interp4graph(settings: AutoInterp4GraphSettings):
-    """Automatically interpret features using LLMs.
-
-    Args:
-        settings: Configuration
-    """
-    setup_logging(level="INFO")
-
-    # Set up MongoDB client
-    mongo_client = MongoClient(settings.mongo)
-
-    language_model = load_model(settings.model)
-
-    feature_list = get_feature(settings.graph_path)
-
-    interpreter = FeatureInterpreter(settings.auto_interp, mongo_client)
-
-    @lru_cache(maxsize=None)
-    def get_dataset(dataset_name: str, shard_idx: int, n_shards: int) -> Dataset:
-        dataset_cfg = mongo_client.get_dataset_cfg(dataset_name)
-        assert dataset_cfg is not None, f"Dataset {dataset_name} not found"
-        dataset = load_dataset_shard(dataset_cfg, shard_idx, n_shards)
-        return dataset
-
-    for todo_feature in feature_list:
-        sae_name = todo_feature["sae_name"]
-        feature_idx = todo_feature["feature_idx"]
-
-        feature = mongo_client.get_feature(sae_name, settings.sae_series, feature_idx)
-        if feature is not None:
-            if feature.interpretation is None or settings.cover:
-                result = {
-                    "feature_index": feature.index,
-                    "sae_name": sae_name,
-                    "sae_series": settings.sae_series,
-                } | interpreter.interpret_single_feature(feature, language_model, get_dataset, settings.analysis_name)
-
-                interpretation = {
-                    "text": result["explanation"],
-                    "validation": [
-                        {"method": eval_result["method"], "passed": eval_result["passed"], "detail": eval_result}
-                        for eval_result in result["evaluations"]
-                    ],
-                    "complexity": result["complexity"],
-                    "consistency": result["consistency"],
-                    "detail": result["explanation_details"],
-                    "passed": result["passed"],
-                    "time": result["time"],
-                }
-                logger.info(
-                    f"Updating feature {result['feature_index']} in {sae_name}\nTime: {result['time']}\nExplanation: {interpretation['text']}"
-                )
-                mongo_client.update_feature(
-                    sae_name, result["feature_index"], {"interpretation": interpretation}, settings.sae_series
-                )
-        elif feature is not None:
-            logger.info(
-                f"Already interp feature {feature_idx} in {sae_name}\nExplanation: {feature.interpretation}"
-            )
-        else:
-            logger.info(f"Feature {feature_idx} in {sae_name} does not exist. Please check it.")
diff --git a/src/lm_saes/database.py b/src/lm_saes/database.py
index 6eba34ca..d2cae0d6 100644
--- a/src/lm_saes/database.py
+++ b/src/lm_saes/database.py
@@ -8,6 +8,7 @@
 import pymongo.errors
 from bson import ObjectId
 from pydantic import BaseModel
+from tqdm import tqdm
 
 from lm_saes.config import (
     BaseSAEConfig,
@@ -387,7 +388,7 @@ def add_feature_analysis(self, name: str, sae_name: str, sae_series: str, analys
         self.enable_gridfs()
 
         operations = []
-        for i, feature_analysis in enumerate(analysis):
+        for i, feature_analysis in enumerate(tqdm(analysis, desc="Adding feature analyses to MongoDB...")):
             # Convert numpy arrays to GridFS references
             processed_analysis = self._to_gridfs(feature_analysis)
             update_operation = pymongo.UpdateOne(
@@ -452,7 +453,7 @@ def update_feature(self, sae_name: str, feature_index: int, update_data: dict, s
 
     def update_features(self, sae_name: str, sae_series: str, update_data: list[dict], start_idx: int = 0):
         operations = []
-        for i, feature_update in enumerate(update_data):
+        for i, feature_update in enumerate(tqdm(update_data, desc="Updating features in MongoDB...")):
             update_operation = pymongo.UpdateOne(
                 {"sae_name": sae_name, "sae_series": sae_series, "index": start_idx + i},
                 {"$set": feature_update},
diff --git a/src/lm_saes/runners/autointerp.py b/src/lm_saes/runners/autointerp.py
index 3b60a639..b5d45d29 100644
--- a/src/lm_saes/runners/autointerp.py
+++ b/src/lm_saes/runners/autointerp.py
@@ -1,17 +1,18 @@
 """Module for automatic interpretation of SAE features."""
 
-import concurrent.futures
+import asyncio
 from functools import lru_cache
-from typing import Any, Optional
+from typing import Optional
 
 from datasets import Dataset
 from pydantic_settings import BaseSettings
+from tqdm.asyncio import tqdm
 
-from lm_saes.analysis.feature_interpreter import AutoInterpConfig, FeatureInterpreter
+from lm_saes.analysis.autointerp import AutoInterpConfig, FeatureInterpreter
 from lm_saes.config import LanguageModelConfig, MongoDBConfig
 from lm_saes.database import MongoClient
 from lm_saes.resource_loaders import load_dataset_shard, load_model
-from lm_saes.utils.logging import get_logger, setup_logging
+from lm_saes.utils.logging import get_logger
 
 logger = get_logger("runners.autointerp")
 
@@ -40,12 +41,6 @@ class AutoInterpSettings(BaseSettings):
     features: Optional[list[int]] = None
     """List of specific feature indices to interpret. If None, will interpret all features."""
 
-    feature_range: Optional[list[int]] = None
-    """Range of feature indices to interpret [start, end]. If None, will interpret all features."""
-
-    top_k_features: Optional[int] = None
-    """Number of top activating features to interpret. If None, will use the features or feature_range."""
-
     analysis_name: str = "default"
     """Name of the analysis to use for interpretation."""
 
@@ -53,10 +48,13 @@
     """Maximum number of workers to use for interpretation."""
 
 
-def interpret_feature(args: dict[str, Any]):
-    settings: AutoInterpSettings = args["settings"]
-    feature_indices: list[int] = args["feature_indices"]
+async def interpret_feature(settings: AutoInterpSettings, show_progress: bool = True):
+    """Interpret features using async API calls for maximum concurrency.
+    Args:
+        settings: Configuration for feature interpretation
+        show_progress: Whether to show progress bar (requires tqdm)
+    """
     @lru_cache(maxsize=None)
     def get_dataset(dataset_name: str, shard_idx: int, n_shards: int) -> Dataset:
         dataset_cfg = mongo_client.get_dataset_cfg(dataset_name)
@@ -67,77 +65,64 @@ def get_dataset(dataset_name: str, shard_idx: int, n_shards: int) -> Dataset:
     mongo_client = MongoClient(settings.mongo)
     language_model = load_model(settings.model)
     interpreter = FeatureInterpreter(settings.auto_interp, mongo_client)
-    for result in interpreter.interpret_features(
+
+    # Set up progress tracking
+    progress_bar = None
+    processed_count = 0
+    total_count = None
+
+    def progress_callback(processed: int, total: int, current_feature: int) -> None:
+        """Update progress bar and log progress.
+
+        Args:
+            processed: Number of features processed (completed + skipped + failed)
+            total: Total number of features to process
+            current_feature: Index of the feature currently being processed
+        """
+        nonlocal processed_count, total_count, progress_bar
+        processed_count = processed
+        if total_count is None:
+            total_count = total
+            if show_progress:
+                progress_bar = tqdm(
+                    total=total,
+                    desc="Interpreting features",
+                    unit="feature",
+                    dynamic_ncols=True,
+                    initial=0,
+                )
+
+        if progress_bar is not None:
+            progress_bar.n = processed
+            progress_bar.refresh()
+            progress_bar.set_postfix({"current": current_feature})
+
+    async for result in interpreter.interpret_features(
         sae_name=settings.sae_name,
         sae_series=settings.sae_series,
-        feature_indices=feature_indices,
         model=language_model,
         datasets=get_dataset,
         analysis_name=settings.analysis_name,
+        feature_indices=settings.features,
+        max_concurrent=settings.max_workers,
+        progress_callback=progress_callback,
     ):
         interpretation = {
             "text": result["explanation"],
-            "validation": [
-                {"method": eval_result["method"], "passed": eval_result["passed"], "detail": eval_result}
-                for eval_result in result["evaluations"]
-            ],
-            "complexity": result["complexity"],
-            "consistency": result["consistency"],
-            "detail": result["explanation_details"],
-            "passed": result["passed"],
-            "time": result["time"],
         }
-        logger.info(
-            f"Updating feature {result['feature_index']}\nTime: {result['time']}\nExplanation: {interpretation['text']}\nComplexity: {interpretation['complexity']}\nConsistency: {interpretation['consistency']}\nPassed: {interpretation['passed']}\n\n"
-        )
+        assert interpretation['text'] is not None
         mongo_client.update_feature(
             settings.sae_name, result["feature_index"], {"interpretation": interpretation}, settings.sae_series
         )
+    if progress_bar is not None:
+        progress_bar.close()
+    logger.info(f"Completed interpretation: {processed_count}/{total_count} features processed")
 
 
-def auto_interp(settings: AutoInterpSettings) -> None:
-    """Automatically interpret SAE features using LLMs.
+def auto_interp(settings: AutoInterpSettings):
+    """Synchronous wrapper for interpret_feature.
 
     Args:
-        settings: Configuration settings for auto-interpretation
+        settings: Configuration for feature interpretation
     """
-    setup_logging(level="INFO")
-
-    # Set up MongoDB client
-    mongo_client = MongoClient(settings.mongo)
-
-    # Determine which features to interpret
-    if settings.top_k_features:
-        # Get top k most frequently activating features
-        act_times = mongo_client.get_feature_act_times(settings.sae_name, settings.sae_series, settings.analysis_name)
-        if not act_times:
-            raise ValueError(f"No feature activation times found for {settings.sae_name}/{settings.sae_series}")
-        sorted_features = sorted(act_times.items(), key=lambda x: x[1], reverse=True)
-        feature_indices = [idx for idx, _ in sorted_features[: settings.top_k_features]]
-    elif settings.feature_range:
-        # Use feature range
-        feature_indices = list(range(settings.feature_range[0], settings.feature_range[1] + 1))
-    elif settings.features:
-        # Use specific features
-        feature_indices = settings.features
-    else:
-        # Use all features (be careful, this could be a lot!)
-        max_feature_acts = mongo_client.get_max_feature_acts(
-            settings.sae_name, settings.sae_series, settings.analysis_name
-        )
-        if not max_feature_acts:
-            raise ValueError(f"No feature activations found for {settings.sae_name}/{settings.sae_series}")
-        feature_indices = list(max_feature_acts.keys())
-
-    # Load resources
-    logger.info(f"Loading SAE model: {settings.sae_name}/{settings.sae_series}")
-    logger.info(f"Loading language model: {settings.model_name}")
-
-    chunk_size = len(feature_indices) // settings.max_workers + 1
-    feature_batches = [feature_indices[i : i + chunk_size] for i in range(0, len(feature_indices), chunk_size)]
-    args_batches = [{"feature_indices": feature_indices, "settings": settings} for feature_indices in feature_batches]
-
-    with concurrent.futures.ThreadPoolExecutor(max_workers=settings.max_workers) as executor:
-        list(executor.map(interpret_feature, args_batches))
-
-    logger.info("Done!")
+    asyncio.run(interpret_feature(settings))
diff --git a/tests/unit/test_feature_interpreter.py b/tests/unit/test_feature_interpreter.py
index 197d5a53..182dd4c5 100644
--- a/tests/unit/test_feature_interpreter.py
+++ b/tests/unit/test_feature_interpreter.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from lm_saes.analysis.feature_interpreter import (
+from lm_saes.analysis.autointerp import (
     AutoInterpConfig,
     AutoInterpEvaluation,
     AutoInterpExplanation,
diff --git a/ui/bun.lockb b/ui/bun.lockb
index 23b3566c..d2862062 100755
Binary files a/ui/bun.lockb and b/ui/bun.lockb differ
diff --git a/ui/src/components/app/sample.tsx b/ui/src/components/app/sample.tsx
index 5d90ff17..5b39509a 100644
--- a/ui/src/components/app/sample.tsx
+++ b/ui/src/components/app/sample.tsx
@@ -26,7 +26,7 @@ export const Sample = ({
   const [folded, setFolded] = useState(true);
 
   return (
-
+
       setFolded(!folded) : undefined}
diff --git a/ui/src/components/feature/interpret.tsx b/ui/src/components/feature/interpret.tsx
index 6c2f379b..c9de1c6f 100644
--- a/ui/src/components/feature/interpret.tsx
+++ b/ui/src/components/feature/interpret.tsx
@@ -314,7 +314,7 @@ export const FeatureInterpretation = ({ feature }: { feature: Feature }) => {
-          {interpretation?.validation.map((validation, i) => (
+          {interpretation?.validation?.map((validation, i) => (
              {validation.passed ? (
diff --git a/ui/src/types/feature.ts b/ui/src/types/feature.ts
index 24b484ef..6ca3a08c 100644
--- a/ui/src/types/feature.ts
+++ b/ui/src/types/feature.ts
@@ -48,7 +48,7 @@ export const InterpretationSchema = z.object({
         })
         .optional(),
     })
-  ),
+  ).optional(),
   detail: z
     .object({
       userPrompt: z.string(),