217 changes: 68 additions & 149 deletions server/app.py
@@ -1,7 +1,7 @@
 import io
-import json
 import os
-from functools import lru_cache
+import threading
+from functools import lru_cache, wraps
 from typing import Any, Generator, Optional
 
 import msgpack
@@ -31,33 +31,41 @@
 
 app.add_middleware(GZipMiddleware, minimum_size=1000)
 
-print(os.environ.get("MONGO_URI", "mongodb://localhost:27017/"))
-print(os.environ.get("MONGO_DB", "mechinterp"))
 client = MongoClient(MongoDBConfig())
 sae_series = os.environ.get("SAE_SERIES", "default")
 tokenizer_only = os.environ.get("TOKENIZER_ONLY", "false").lower() == "true"
 if tokenizer_only:
     print("WARNING: Tokenizer only mode is enabled, some features may not be available")
 
-# Remove global caches in favor of LRU cache
-# sae_cache: dict[str, SparseAutoEncoder] = {}
-# lm_cache: dict[str, LanguageModel] = {}
-# dataset_cache: dict[tuple[str, int, int], Dataset] = {}
-
-@lru_cache(maxsize=8)
-def get_model(name: str) -> LanguageModel:
-    """Load and cache a language model.
-
-    Args:
-        name: Name of the model to load
-
-    Returns:
-        LanguageModel: The loaded model
-
-    Raises:
-        ValueError: If the model is not found
-    """
+
+def synchronized(func):
+    """Decorator to ensure sequential execution of a function based on parameters.
+
+    Different parameters can be acquired in parallel, but the same parameters
+    will be executed sequentially.
+    """
+    locks: dict[frozenset[tuple[str, Any]], threading.Lock] = {}
+    global_lock = threading.Lock()
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        assert len(args) == 0, "Positional arguments are not supported"
+        key = frozenset(kwargs.items())
+
+        # The lock creation is locked by the global lock to avoid race conditions on locks.
+        with global_lock:
+            if key not in locks:
+                locks[key] = threading.Lock()
+            lock = locks[key]
+
+        with lock:
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+@lru_cache(maxsize=8)
+@synchronized
+def get_model(*, name: str) -> LanguageModel:
+    """Load and cache a language model."""
     cfg = client.get_model_cfg(name)
     if cfg is None:
         raise ValueError(f"Model {name} not found")
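
The `synchronized` decorator above builds its lock key from `frozenset(kwargs.items())`, which is why the loaders now use keyword-only (`*,`) signatures and assert that no positional arguments are passed. Below is a minimal, self-contained sketch of how it composes with `@lru_cache`; the `slow_load` function and its timing are hypothetical stand-ins for the real loaders:

```python
import threading
import time
from functools import lru_cache, wraps


def synchronized(func):
    """Per-kwargs locks: identical kwargs serialize; different kwargs run in parallel."""
    locks: dict[frozenset, threading.Lock] = {}
    global_lock = threading.Lock()

    @wraps(func)
    def wrapper(*args, **kwargs):
        assert len(args) == 0, "Positional arguments are not supported"
        key = frozenset(kwargs.items())
        with global_lock:  # guard the locks dict itself against races
            lock = locks.setdefault(key, threading.Lock())
        with lock:
            return func(**kwargs)

    return wrapper


@lru_cache(maxsize=8)
@synchronized
def slow_load(*, name: str) -> str:
    time.sleep(0.1)  # stand-in for loading a model checkpoint from disk
    return f"model:{name}"


# Concurrent calls with the same name are serialized, so loads never overlap.
# On a cold cache the load may still execute once per racing thread; all later
# calls are served by lru_cache without touching the lock path at all.
threads = [threading.Thread(target=slow_load, kwargs={"name": "gpt2"}) for _ in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(slow_load(name="gpt2"))  # cache hit: returns immediately
```

Note the decorator order: `lru_cache` sits outermost, so once a result is cached the per-key lock is skipped entirely; only cache misses contend for it.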
@@ -66,38 +74,18 @@ def get_model(name: str) -> LanguageModel:
 
 
 @lru_cache(maxsize=16)
-def get_dataset(name: str, shard_idx: int = 0, n_shards: int = 1) -> Dataset:
-    """Load and cache a dataset shard.
-
-    Args:
-        name: Name of the dataset
-        shard_idx: Index of the shard to load
-        n_shards: Total number of shards
-
-    Returns:
-        Dataset: The loaded dataset shard
-
-    Raises:
-        AssertionError: If the dataset is not found
-    """
+@synchronized
+def get_dataset(*, name: str, shard_idx: int = 0, n_shards: int = 1) -> Dataset:
+    """Load and cache a dataset shard."""
     cfg = client.get_dataset_cfg(name)
     assert cfg is not None, f"Dataset {name} not found"
     return load_dataset_shard(cfg, shard_idx, n_shards)
 
 
 @lru_cache(maxsize=8)
-def get_sae(name: str) -> SparseAutoEncoder:
-    """Load and cache a sparse autoencoder.
-
-    Args:
-        name: Name of the SAE to load
-
-    Returns:
-        SparseAutoEncoder: The loaded SAE
-
-    Raises:
-        AssertionError: If the SAE is not found
-    """
+@synchronized
+def get_sae(*, name: str) -> SparseAutoEncoder:
+    """Load and cache a sparse autoencoder."""
     path = client.get_sae_path(name, sae_series)
     assert path is not None, f"SAE {name} not found"
     cfg = SAEConfig.from_pretrained(path)
@@ -161,29 +149,6 @@ def list_dictionaries():
     return client.list_saes(sae_series=sae_series, has_analyses=True)
 
 
-@app.get("/images/{dataset_name}")
-def get_image(dataset_name: str, context_idx: int, image_idx: int, shard_idx: int = 0, n_shards: int = 1):
-    assert transforms is not None, "torchvision not found, image processing will be disabled"
-    dataset = get_dataset(dataset_name, shard_idx, n_shards)
-    data = dataset[int(context_idx)]
-
-    image_key = next((key for key in ["image", "images"] if key in data), None)
-    if image_key is None:
-        return Response(content="Image not found", status_code=404)
-
-    if len(data[image_key]) <= image_idx:
-        return Response(content="Image not found", status_code=404)
-
-    image_tensor = data[image_key][image_idx]
-
-    # Convert tensor to PIL Image and then to bytes
-    image = transforms.ToPILImage()(image_tensor.to(torch.uint8))
-    img_byte_arr = io.BytesIO()
-    image.save(img_byte_arr, format="PNG")
-
-    return Response(content=img_byte_arr.getvalue(), media_type="image/png")
-
-
 @app.get("/dictionaries/{name}/metrics")
 def get_available_metrics(name: str):
     """Get available metrics for a dictionary.
@@ -311,9 +276,9 @@ def process_sample(
     Returns:
         dict: Processed sample data
     """ # Get model and dataset
-    model = get_model(model_name)
+    model = get_model(name=model_name)
     # model = None
-    data = get_dataset(dataset_name, shard_idx, n_shards)[context_idx.item()]
+    data = get_dataset(name=dataset_name, shard_idx=shard_idx, n_shards=n_shards)[context_idx.item()]
 
     # Get origins for the features
     origins = model.trace({k: [v] for k, v in data.items()})[0]
@@ -477,84 +442,6 @@ def process_sparse_feature_acts(
     )
 
 
-# @app.post("/dictionaries/{name}/cache_features")
-# def cache_features(
-#     name: str,
-#     features: list[dict[str, Any]] = Body(..., embed=True),
-#     output_dir: str = Body(...),
-# ):
-#     """Batch-fetch and persist feature payloads for offline reuse.
-
-#     Args:
-#         name: Dictionary/SAE name.
-#         features: List of feature specs currently on screen. Each item should contain
-#             - feature_id: int
-#             - layer: int
-#             - is_lorsa: bool
-#             - analysis_name: Optional[str] (overrides auto selection)
-#         output_dir: Directory on the server filesystem to write files into.
-
-#     Returns:
-#         Dict with count and directory path.
-#     """
-#     os.makedirs(output_dir, exist_ok=True)
-
-#     saved = 0
-#     for f in features:
-#         feature_id = int(f["feature_id"])  # may raise KeyError which FastAPI will surface
-#         layer = int(f["layer"])  # required for formatting analysis name
-#         is_lorsa = bool(f.get("is_lorsa", False))
-#         analysis_name_override = f.get("analysis_name")
-
-#         # Determine analysis name for this feature
-#         formatted_analysis_name: str | None = None
-#         if analysis_name_override is not None:
-#             formatted_analysis_name = analysis_name_override
-#         else:
-#             try:
-#                 base_name = (
-#                     client.get_lorsa_analysis_name(name, sae_series)
-#                     if is_lorsa
-#                     else client.get_clt_analysis_name(name, sae_series)
-#                 )
-#             except AttributeError:
-#                 base_name = None
-#             if base_name is None:
-#                 feat = client.get_random_alive_feature(sae_name=name, sae_series=sae_series)
-#                 if feat is None:
-#                     return Response(content=f"Dictionary {name} not found", status_code=404)
-#                 available = [a.name for a in feat.analyses]
-#                 preferred = [a for a in available if ("lorsa" in a) == is_lorsa]
-#                 base_name = preferred[0] if preferred else available[0]
-#             formatted_analysis_name = base_name.replace("{}", str(layer))
-
-#         # Reuse existing single-feature endpoint logic. Align with frontend usage where
-#         # the path 'name' is the formatted analysis name used by GET /dictionaries/{name}/features/{id}.
-#         res = get_feature(name=formatted_analysis_name, feature_index=feature_id, feature_analysis_name=None)
-#         if isinstance(res, Response) and res.status_code != 200:
-#             # Skip but continue
-#             continue
-
-#         payload = res.body if isinstance(res, Response) else res
-#         # Write as msgpack for fidelity and also a JSON alongside for convenience
-#         base = os.path.join(output_dir, f"layer{layer}__feature{feature_id}__{formatted_analysis_name}.msgpack")
-#         with open(base, "wb") as fbin:
-#             fbin.write(payload)
-#         try:
-#             decoded = msgpack.unpackb(payload, raw=False)
-#             json_path = base.replace(".msgpack", ".json")
-#             # make_serializable handles tensors/np arrays
-#             import json as _json
-
-#             with open(json_path, "w") as fj:
-#                 _json.dump(make_serializable(decoded), fj)
-#         except Exception:
-#             pass
-#         saved += 1
-
-#     return {"saved": saved, "output_dir": output_dir}
-
-
 @app.get("/dictionaries/{name}")
 def get_dictionary(name: str):
     # Get feature activation times
@@ -610,6 +497,38 @@ def get_analyses(name: str):
     return analyses
 
 
+@app.post("/dictionaries/{name}/features/{feature_index}/infer")
+def infer_feature(name: str, feature_index: int, text: str):
+    """Infer feature activations for a given text."""
+    model_name = client.get_sae_model_name(name, sae_series)
+    assert model_name is not None, f"SAE {name} not found or no model name is associated with it"
+    model = get_model(name=model_name)
+    sae = get_sae(name=name)
+
+    activations = model.to_activations({"text": [text]}, hook_points=sae.cfg.associated_hook_points)
+    activations = sae.normalize_activations(activations)
+    x, encoder_kwargs, _ = sae.prepare_input(activations)
+    feature_acts = sae.encode(x, **encoder_kwargs)[0, :, feature_index]
+
+    feature_acts = feature_acts.to_sparse()
+
+    origins = model.trace({"text": [text]})[0]
+
+    return Response(
+        content=msgpack.packb(
+            make_serializable(
+                {
+                    "text": text,
+                    "origins": origins,
+                    "feature_acts_indices": feature_acts.indices()[0],
+                    "feature_acts_values": feature_acts.values(),
+                }
+            )
+        ),
+        media_type="application/x-msgpack",
+    )
+
+
 @app.post("/dictionaries/{name}/features/{feature_index}/bookmark")
 def add_bookmark(name: str, feature_index: int):
     """Add a bookmark for a feature.
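For context on how the new `infer` endpoint might be consumed, here is a hedged client-side sketch; the base URL and dictionary name are hypothetical, and it assumes the `requests` and `msgpack` packages. Since `text` is declared as a plain `str` parameter rather than a Body field, FastAPI treats it as a query parameter on the POST:

```python
import msgpack
import requests

BASE_URL = "http://localhost:8000"  # hypothetical server address
DICTIONARY = "my-sae"               # hypothetical dictionary/SAE name

# POST /dictionaries/{name}/features/{feature_index}/infer, text as a query param
resp = requests.post(
    f"{BASE_URL}/dictionaries/{DICTIONARY}/features/42/infer",
    params={"text": "The quick brown fox"},
)
resp.raise_for_status()

# The server responds with msgpack, not JSON
payload = msgpack.unpackb(resp.content, raw=False)

# Sparse representation: token positions where feature 42 fired, and the values
for pos, val in zip(payload["feature_acts_indices"], payload["feature_acts_values"]):
    print(pos, val)
```
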
7 changes: 5 additions & 2 deletions src/lm_saes/database.py
@@ -190,9 +190,9 @@
         if isinstance(data, ObjectId) and self.fs.exists(data):
             self.fs.delete(data)
 
-    def create_sae(self, name: str, series: str, path: str, cfg: BaseSAEConfig):
+    def create_sae(self, name: str, series: str, path: str, cfg: BaseSAEConfig, model_name: str | None = None):
         inserted_id = self.sae_collection.insert_one(
-            {"name": name, "series": series, "path": path, "cfg": cfg.model_dump()}
+            {"name": name, "series": series, "path": path, "cfg": cfg.model_dump(), "model_name": model_name}
         ).inserted_id
         self.feature_collection.insert_many(
             [{"sae_name": name, "sae_series": series, "index": i} for i in range(cfg.d_sae)]
@@ -363,6 +363,9 @@
             return None
         return sae["path"]
 
+    def get_sae_model_name(self, sae_name: str, sae_series: str) -> Optional[str]:
+        return self.sae_collection.find_one({"name": sae_name, "series": sae_series})["model_name"]
+
Check failure on line 368 in src/lm_saes/database.py (GitHub Actions / Type Checks): Object of type "None" is not subscriptable (reportOptionalSubscript)
     def add_dataset(self, name: str, cfg: DatasetConfig):
         self.dataset_collection.update_one({"name": name}, {"$set": {"cfg": cfg.model_dump()}}, upsert=True)
 
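The type-check annotation above flags a real gap: `find_one` returns `None` when no SAE matches, so the chained `["model_name"]` subscript can raise. A hedged sketch of a None-safe variant, using the same collection layout as the diff (`.get` also tolerates documents created before `create_sae` stored a `model_name`):

```python
from typing import Optional


def get_sae_model_name(self, sae_name: str, sae_series: str) -> Optional[str]:
    sae = self.sae_collection.find_one({"name": sae_name, "series": sae_series})
    if sae is None:
        return None  # mirror get_sae_path, which also returns None when missing
    return sae.get("model_name")
```
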
21 changes: 21 additions & 0 deletions ui-ssr/.cta.json
@@ -0,0 +1,21 @@
+{
+  "projectName": "ui-ssr",
+  "mode": "file-router",
+  "typescript": true,
+  "tailwind": true,
+  "packageManager": "bun",
+  "addOnOptions": {},
+  "git": true,
+  "version": 1,
+  "framework": "react-cra",
+  "chosenAddOns": [
+    "eslint",
+    "nitro",
+    "start",
+    "compiler",
+    "shadcn",
+    "table",
+    "store",
+    "tanstack-query"
+  ]
+}
7 changes: 7 additions & 0 deletions ui-ssr/.cursorrules
@@ -0,0 +1,7 @@
+# shadcn instructions
+
+Use the latest version of Shadcn to install new components, like this command to add a button component:
+
+```bash
+pnpx shadcn@latest add button
+```
11 changes: 11 additions & 0 deletions ui-ssr/.gitignore
@@ -0,0 +1,11 @@
+node_modules
+.DS_Store
+dist
+dist-ssr
+*.local
+.env
+.nitro
+.tanstack
+.wrangler
+.output
+.vinxi
3 changes: 3 additions & 0 deletions ui-ssr/.prettierignore
@@ -0,0 +1,3 @@
+package-lock.json
+pnpm-lock.yaml
+yarn.lock
11 changes: 11 additions & 0 deletions ui-ssr/.vscode/settings.json
@@ -0,0 +1,11 @@
+{
+  "files.watcherExclude": {
+    "**/routeTree.gen.ts": true
+  },
+  "search.exclude": {
+    "**/routeTree.gen.ts": true
+  },
+  "files.readonlyInclude": {
+    "**/routeTree.gen.ts": true
+  }
+}