Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 190 additions & 146 deletions src/lm_saes/analysis/feature_interpreter.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions src/lm_saes/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pymongo.errors
from bson import ObjectId
from pydantic import BaseModel
from tqdm import tqdm

from lm_saes.config import (
BaseSAEConfig,
Expand Down Expand Up @@ -387,7 +388,7 @@ def add_feature_analysis(self, name: str, sae_name: str, sae_series: str, analys
self.enable_gridfs()

operations = []
for i, feature_analysis in enumerate(analysis):
for i, feature_analysis in enumerate(tqdm(analysis, desc="Adding feature analyses to MongoDB...")):
# Convert numpy arrays to GridFS references
processed_analysis = self._to_gridfs(feature_analysis)
update_operation = pymongo.UpdateOne(
Expand Down Expand Up @@ -452,7 +453,7 @@ def update_feature(self, sae_name: str, feature_index: int, update_data: dict, s

def update_features(self, sae_name: str, sae_series: str, update_data: list[dict], start_idx: int = 0):
operations = []
for i, feature_update in enumerate(update_data):
for i, feature_update in enumerate(tqdm(update_data, desc="Updating features in MongoDB...")):
update_operation = pymongo.UpdateOne(
{"sae_name": sae_name, "sae_series": sae_series, "index": start_idx + i},
{"$set": feature_update},
Expand Down
153 changes: 70 additions & 83 deletions src/lm_saes/runners/autointerp.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
"""Module for automatic interpretation of SAE features."""

import concurrent.futures
import asyncio
from functools import lru_cache
from typing import Any, Optional
from typing import Optional

from datasets import Dataset
from pydantic_settings import BaseSettings
from tqdm.asyncio import tqdm

from lm_saes.analysis.feature_interpreter import AutoInterpConfig, FeatureInterpreter
from lm_saes.config import LanguageModelConfig, MongoDBConfig
from lm_saes.database import MongoClient
from lm_saes.resource_loaders import load_dataset_shard, load_model
from lm_saes.utils.logging import get_logger, setup_logging
from lm_saes.utils.logging import get_logger

logger = get_logger("runners.autointerp")

Expand Down Expand Up @@ -40,23 +41,20 @@ class AutoInterpSettings(BaseSettings):
features: Optional[list[int]] = None
"""List of specific feature indices to interpret. If None, will interpret all features."""

feature_range: Optional[list[int]] = None
"""Range of feature indices to interpret [start, end]. If None, will interpret all features."""

top_k_features: Optional[int] = None
"""Number of top activating features to interpret. If None, will use the features or feature_range."""

analysis_name: str = "default"
"""Name of the analysis to use for interpretation."""

max_workers: int = 10
"""Maximum number of workers to use for interpretation."""


def interpret_feature(args: dict[str, Any]):
settings: AutoInterpSettings = args["settings"]
feature_indices: list[int] = args["feature_indices"]
async def interpret_feature(settings: AutoInterpSettings, show_progress: bool = True):
"""Interpret features using async API calls for maximum concurrency.

Args:
settings: Configuration for feature interpretation
show_progress: Whether to show progress bar (requires tqdm)
"""
@lru_cache(maxsize=None)
def get_dataset(dataset_name: str, shard_idx: int, n_shards: int) -> Dataset:
dataset_cfg = mongo_client.get_dataset_cfg(dataset_name)
Expand All @@ -67,77 +65,66 @@ def get_dataset(dataset_name: str, shard_idx: int, n_shards: int) -> Dataset:
mongo_client = MongoClient(settings.mongo)
language_model = load_model(settings.model)
interpreter = FeatureInterpreter(settings.auto_interp, mongo_client)
for result in interpreter.interpret_features(
sae_name=settings.sae_name,
sae_series=settings.sae_series,
feature_indices=feature_indices,
model=language_model,
datasets=get_dataset,
analysis_name=settings.analysis_name,
):
interpretation = {
"text": result["explanation"],
"validation": [
{"method": eval_result["method"], "passed": eval_result["passed"], "detail": eval_result}
for eval_result in result["evaluations"]
],
"complexity": result["complexity"],
"consistency": result["consistency"],
"detail": result["explanation_details"],
"passed": result["passed"],
"time": result["time"],
}
logger.info(
f"Updating feature {result['feature_index']}\nTime: {result['time']}\nExplanation: {interpretation['text']}\nComplexity: {interpretation['complexity']}\nConsistency: {interpretation['consistency']}\nPassed: {interpretation['passed']}\n\n"
)
mongo_client.update_feature(
settings.sae_name, result["feature_index"], {"interpretation": interpretation}, settings.sae_series
)


def auto_interp(settings: AutoInterpSettings) -> None:
"""Automatically interpret SAE features using LLMs.

# Set up progress tracking
progress_bar = None
processed_count = 0
total_count = None

def progress_callback(processed: int, total: int, current_feature: int) -> None:
"""Update progress bar and log progress.

Args:
processed: Number of features processed (completed + skipped + failed)
total: Total number of features to process
current_feature: Index of the feature currently being processed
"""
nonlocal processed_count, total_count, progress_bar
processed_count = processed
if total_count is None:
total_count = total
if show_progress:
progress_bar = tqdm(
total=total,
desc="Interpreting features",
unit="feature",
dynamic_ncols=True,
initial=0,
)

if progress_bar is not None:
progress_bar.n = processed
progress_bar.refresh()
progress_bar.set_postfix({"current": current_feature})

try:
async for result in interpreter.interpret_features(
sae_name=settings.sae_name,
sae_series=settings.sae_series,
model=language_model,
datasets=get_dataset,
analysis_name=settings.analysis_name,
feature_indices=settings.features,
max_concurrent=settings.max_workers,
progress_callback=progress_callback,
):
interpretation = {
"text": result["explanation"],
}
assert interpretation['text'] is not None
mongo_client.update_feature(
settings.sae_name, result["feature_index"], {"interpretation": interpretation}, settings.sae_series
)
finally:
if progress_bar is not None:
progress_bar.close()
logger.info(f"Completed interpretation: {processed_count}/{total_count} features processed")


def auto_interp(settings: AutoInterpSettings):
"""Synchronous wrapper for interpret_feature.

Args:
settings: Configuration settings for auto-interpretation
settings: Configuration for feature interpretation
"""
setup_logging(level="INFO")

# Set up MongoDB client
mongo_client = MongoClient(settings.mongo)

# Determine which features to interpret
if settings.top_k_features:
# Get top k most frequently activating features
act_times = mongo_client.get_feature_act_times(settings.sae_name, settings.sae_series, settings.analysis_name)
if not act_times:
raise ValueError(f"No feature activation times found for {settings.sae_name}/{settings.sae_series}")
sorted_features = sorted(act_times.items(), key=lambda x: x[1], reverse=True)
feature_indices = [idx for idx, _ in sorted_features[: settings.top_k_features]]
elif settings.feature_range:
# Use feature range
feature_indices = list(range(settings.feature_range[0], settings.feature_range[1] + 1))
elif settings.features:
# Use specific features
feature_indices = settings.features
else:
# Use all features (be careful, this could be a lot!)
max_feature_acts = mongo_client.get_max_feature_acts(
settings.sae_name, settings.sae_series, settings.analysis_name
)
if not max_feature_acts:
raise ValueError(f"No feature activations found for {settings.sae_name}/{settings.sae_series}")
feature_indices = list(max_feature_acts.keys())

# Load resources
logger.info(f"Loading SAE model: {settings.sae_name}/{settings.sae_series}")
logger.info(f"Loading language model: {settings.model_name}")

chunk_size = len(feature_indices) // settings.max_workers + 1
feature_batches = [feature_indices[i : i + chunk_size] for i in range(0, len(feature_indices), chunk_size)]
args_batches = [{"feature_indices": feature_indices, "settings": settings} for feature_indices in feature_batches]

with concurrent.futures.ThreadPoolExecutor(max_workers=settings.max_workers) as executor:
list(executor.map(interpret_feature, args_batches))

logger.info("Done!")
asyncio.run(interpret_feature(settings))
Binary file modified ui/bun.lockb
Binary file not shown.
2 changes: 1 addition & 1 deletion ui/src/components/app/sample.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export const Sample = <T extends { token: Uint8Array }>({
const [folded, setFolded] = useState(true);

return (
<div className={cn(folded && foldedStart !== undefined && "cursor-pointer -m-1 p-1 rounded-lg hover:bg-gray-100")}>
<div className={cn(folded && foldedStart !== undefined && "-m-1 p-1 rounded-lg hover:bg-gray-100")}>
<div
className={cn(folded && foldedStart !== undefined && "line-clamp-3 pb-[1px]")}
onClick={foldedStart !== undefined && folded ? () => setFolded(!folded) : undefined}
Expand Down
2 changes: 1 addition & 1 deletion ui/src/components/feature/interpret.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ export const FeatureInterpretation = ({ feature }: { feature: Feature }) => {
</div>
</div>
<div className="flex flex-col gap-4 basis-1/3 min-w-1/3">
{interpretation?.validation.map((validation, i) => (
{interpretation?.validation?.map((validation, i) => (
<div key={i} className="flex items-center gap-2">
{validation.passed ? (
<Check size={20} className="text-green-500" />
Expand Down
2 changes: 1 addition & 1 deletion ui/src/types/feature.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export const InterpretationSchema = z.object({
})
.optional(),
})
),
).optional(),
detail: z
.object({
userPrompt: z.string(),
Expand Down
Loading