Skip to content

Commit 651fa72

Browse files
authored
refactor: use tanstack start for frontend; make a more neuronpedia-like ui (#146)
1 parent: cd0f71e · commit: 651fa72

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+5078
-151
lines changed

server/app.py

Lines changed: 68 additions & 149 deletions
Original file line number · Diff line number · Diff line change
@@ -1,7 +1,7 @@
1-
import io
21
import json
32
import os
4-
from functools import lru_cache
3+
import threading
4+
from functools import lru_cache, wraps
55
from typing import Any, Generator, Optional
66

77
import msgpack
@@ -31,33 +31,41 @@
3131

3232
app.add_middleware(GZipMiddleware, minimum_size=1000)
3333

34-
print(os.environ.get("MONGO_URI", "mongodb://localhost:27017/"))
35-
print(os.environ.get("MONGO_DB", "mechinterp"))
3634
client = MongoClient(MongoDBConfig())
3735
sae_series = os.environ.get("SAE_SERIES", "default")
3836
tokenizer_only = os.environ.get("TOKENIZER_ONLY", "false").lower() == "true"
39-
if tokenizer_only:
40-
print("WARNING: Tokenizer only mode is enabled, some features may not be available")
4137

42-
# Remove global caches in favor of LRU cache
43-
# sae_cache: dict[str, SparseAutoEncoder] = {}
44-
# lm_cache: dict[str, LanguageModel] = {}
45-
# dataset_cache: dict[tuple[str, int, int], Dataset] = {}
4638

39+
def synchronized(func):
40+
"""Decorator to ensure sequential execution of a function based on parameters.
4741
48-
@lru_cache(maxsize=8)
49-
def get_model(name: str) -> LanguageModel:
50-
"""Load and cache a language model.
42+
Different parameters can be acquired in parallel, but the same parameters
43+
will be executed sequentially.
44+
"""
45+
locks: dict[frozenset[tuple[str, Any]], threading.Lock] = {}
46+
global_lock = threading.Lock()
5147

52-
Args:
53-
name: Name of the model to load
48+
@wraps(func)
49+
def wrapper(*args, **kwargs):
50+
assert len(args) == 0, "Positional arguments are not supported"
51+
key = frozenset(kwargs.items())
5452

55-
Returns:
56-
LanguageModel: The loaded model
53+
# The lock creation is locked by the global lock to avoid race conditions on locks.
54+
with global_lock:
55+
if key not in locks:
56+
locks[key] = threading.Lock()
57+
lock = locks[key]
5758

58-
Raises:
59-
ValueError: If the model is not found
60-
"""
59+
with lock:
60+
return func(*args, **kwargs)
61+
62+
return wrapper
63+
64+
65+
@lru_cache(maxsize=8)
66+
@synchronized
67+
def get_model(*, name: str) -> LanguageModel:
68+
"""Load and cache a language model."""
6169
cfg = client.get_model_cfg(name)
6270
if cfg is None:
6371
raise ValueError(f"Model {name} not found")
@@ -66,38 +74,18 @@ def get_model(name: str) -> LanguageModel:
6674

6775

6876
@lru_cache(maxsize=16)
69-
def get_dataset(name: str, shard_idx: int = 0, n_shards: int = 1) -> Dataset:
70-
"""Load and cache a dataset shard.
71-
72-
Args:
73-
name: Name of the dataset
74-
shard_idx: Index of the shard to load
75-
n_shards: Total number of shards
76-
77-
Returns:
78-
Dataset: The loaded dataset shard
79-
80-
Raises:
81-
AssertionError: If the dataset is not found
82-
"""
77+
@synchronized
78+
def get_dataset(*, name: str, shard_idx: int = 0, n_shards: int = 1) -> Dataset:
79+
"""Load and cache a dataset shard."""
8380
cfg = client.get_dataset_cfg(name)
8481
assert cfg is not None, f"Dataset {name} not found"
8582
return load_dataset_shard(cfg, shard_idx, n_shards)
8683

8784

8885
@lru_cache(maxsize=8)
89-
def get_sae(name: str) -> SparseAutoEncoder:
90-
"""Load and cache a sparse autoencoder.
91-
92-
Args:
93-
name: Name of the SAE to load
94-
95-
Returns:
96-
SparseAutoEncoder: The loaded SAE
97-
98-
Raises:
99-
AssertionError: If the SAE is not found
100-
"""
86+
@synchronized
87+
def get_sae(*, name: str) -> SparseAutoEncoder:
88+
"""Load and cache a sparse autoencoder."""
10189
path = client.get_sae_path(name, sae_series)
10290
assert path is not None, f"SAE {name} not found"
10391
cfg = SAEConfig.from_pretrained(path)
@@ -161,29 +149,6 @@ def list_dictionaries():
161149
return client.list_saes(sae_series=sae_series, has_analyses=True)
162150

163151

164-
@app.get("/images/{dataset_name}")
165-
def get_image(dataset_name: str, context_idx: int, image_idx: int, shard_idx: int = 0, n_shards: int = 1):
166-
assert transforms is not None, "torchvision not found, image processing will be disabled"
167-
dataset = get_dataset(dataset_name, shard_idx, n_shards)
168-
data = dataset[int(context_idx)]
169-
170-
image_key = next((key for key in ["image", "images"] if key in data), None)
171-
if image_key is None:
172-
return Response(content="Image not found", status_code=404)
173-
174-
if len(data[image_key]) <= image_idx:
175-
return Response(content="Image not found", status_code=404)
176-
177-
image_tensor = data[image_key][image_idx]
178-
179-
# Convert tensor to PIL Image and then to bytes
180-
image = transforms.ToPILImage()(image_tensor.to(torch.uint8))
181-
img_byte_arr = io.BytesIO()
182-
image.save(img_byte_arr, format="PNG")
183-
184-
return Response(content=img_byte_arr.getvalue(), media_type="image/png")
185-
186-
187152
@app.get("/dictionaries/{name}/metrics")
188153
def get_available_metrics(name: str):
189154
"""Get available metrics for a dictionary.
@@ -311,9 +276,9 @@ def process_sample(
311276
Returns:
312277
dict: Processed sample data
313278
""" # Get model and dataset
314-
model = get_model(model_name)
279+
model = get_model(name=model_name)
315280
# model = None
316-
data = get_dataset(dataset_name, shard_idx, n_shards)[context_idx.item()]
281+
data = get_dataset(name=dataset_name, shard_idx=shard_idx, n_shards=n_shards)[context_idx.item()]
317282

318283
# Get origins for the features
319284
origins = model.trace({k: [v] for k, v in data.items()})[0]
@@ -477,84 +442,6 @@ def process_sparse_feature_acts(
477442
)
478443

479444

480-
# @app.post("/dictionaries/{name}/cache_features")
481-
# def cache_features(
482-
# name: str,
483-
# features: list[dict[str, Any]] = Body(..., embed=True),
484-
# output_dir: str = Body(...),
485-
# ):
486-
# """Batch-fetch and persist feature payloads for offline reuse.
487-
488-
# Args:
489-
# name: Dictionary/SAE name.
490-
# features: List of feature specs currently on screen. Each item should contain
491-
# - feature_id: int
492-
# - layer: int
493-
# - is_lorsa: bool
494-
# - analysis_name: Optional[str] (overrides auto selection)
495-
# output_dir: Directory on the server filesystem to write files into.
496-
497-
# Returns:
498-
# Dict with count and directory path.
499-
# """
500-
# os.makedirs(output_dir, exist_ok=True)
501-
502-
# saved = 0
503-
# for f in features:
504-
# feature_id = int(f["feature_id"]) # may raise KeyError which FastAPI will surface
505-
# layer = int(f["layer"]) # required for formatting analysis name
506-
# is_lorsa = bool(f.get("is_lorsa", False))
507-
# analysis_name_override = f.get("analysis_name")
508-
509-
# # Determine analysis name for this feature
510-
# formatted_analysis_name: str | None = None
511-
# if analysis_name_override is not None:
512-
# formatted_analysis_name = analysis_name_override
513-
# else:
514-
# try:
515-
# base_name = (
516-
# client.get_lorsa_analysis_name(name, sae_series)
517-
# if is_lorsa
518-
# else client.get_clt_analysis_name(name, sae_series)
519-
# )
520-
# except AttributeError:
521-
# base_name = None
522-
# if base_name is None:
523-
# feat = client.get_random_alive_feature(sae_name=name, sae_series=sae_series)
524-
# if feat is None:
525-
# return Response(content=f"Dictionary {name} not found", status_code=404)
526-
# available = [a.name for a in feat.analyses]
527-
# preferred = [a for a in available if ("lorsa" in a) == is_lorsa]
528-
# base_name = preferred[0] if preferred else available[0]
529-
# formatted_analysis_name = base_name.replace("{}", str(layer))
530-
531-
# # Reuse existing single-feature endpoint logic. Align with frontend usage where
532-
# # the path 'name' is the formatted analysis name used by GET /dictionaries/{name}/features/{id}.
533-
# res = get_feature(name=formatted_analysis_name, feature_index=feature_id, feature_analysis_name=None)
534-
# if isinstance(res, Response) and res.status_code != 200:
535-
# # Skip but continue
536-
# continue
537-
538-
# payload = res.body if isinstance(res, Response) else res
539-
# # Write as msgpack for fidelity and also a JSON alongside for convenience
540-
# base = os.path.join(output_dir, f"layer{layer}__feature{feature_id}__{formatted_analysis_name}.msgpack")
541-
# with open(base, "wb") as fbin:
542-
# fbin.write(payload)
543-
# try:
544-
# decoded = msgpack.unpackb(payload, raw=False)
545-
# json_path = base.replace(".msgpack", ".json")
546-
# # make_serializable handles tensors/np arrays
547-
# import json as _json
548-
549-
# with open(json_path, "w") as fj:
550-
# _json.dump(make_serializable(decoded), fj)
551-
# except Exception:
552-
# pass
553-
# saved += 1
554-
555-
# return {"saved": saved, "output_dir": output_dir}
556-
557-
558445
@app.get("/dictionaries/{name}")
559446
def get_dictionary(name: str):
560447
# Get feature activation times
@@ -610,6 +497,38 @@ def get_analyses(name: str):
610497
return analyses
611498

612499

500+
@app.post("/dictionaries/{name}/features/{feature_index}/infer")
501+
def infer_feature(name: str, feature_index: int, text: str):
502+
"""Infer feature activations for a given text."""
503+
model_name = client.get_sae_model_name(name, sae_series)
504+
assert model_name is not None, f"SAE {name} not found or no model name is associated with it"
505+
model = get_model(name=model_name)
506+
sae = get_sae(name=name)
507+
508+
activations = model.to_activations({"text": [text]}, hook_points=sae.cfg.associated_hook_points)
509+
activations = sae.normalize_activations(activations)
510+
x, encoder_kwargs, _ = sae.prepare_input(activations)
511+
feature_acts = sae.encode(x, **encoder_kwargs)[0, :, feature_index]
512+
513+
feature_acts = feature_acts.to_sparse()
514+
515+
origins = model.trace({"text": [text]})[0]
516+
517+
return Response(
518+
content=msgpack.packb(
519+
make_serializable(
520+
{
521+
"text": text,
522+
"origins": origins,
523+
"feature_acts_indices": feature_acts.indices()[0],
524+
"feature_acts_values": feature_acts.values(),
525+
}
526+
)
527+
),
528+
media_type="application/x-msgpack",
529+
)
530+
531+
613532
@app.post("/dictionaries/{name}/features/{feature_index}/bookmark")
614533
def add_bookmark(name: str, feature_index: int):
615534
"""Add a bookmark for a feature.

src/lm_saes/database.py

Lines changed: 5 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -191,9 +191,9 @@ def _remove_gridfs_objs(self, data: Any) -> None:
191191
if isinstance(data, ObjectId) and self.fs.exists(data):
192192
self.fs.delete(data)
193193

194-
def create_sae(self, name: str, series: str, path: str, cfg: BaseSAEConfig):
194+
def create_sae(self, name: str, series: str, path: str, cfg: BaseSAEConfig, model_name: str | None = None):
195195
inserted_id = self.sae_collection.insert_one(
196-
{"name": name, "series": series, "path": path, "cfg": cfg.model_dump()}
196+
{"name": name, "series": series, "path": path, "cfg": cfg.model_dump(), "model_name": model_name}
197197
).inserted_id
198198
self.feature_collection.insert_many(
199199
[{"sae_name": name, "sae_series": series, "index": i} for i in range(cfg.d_sae)]
@@ -364,6 +364,9 @@ def get_sae_path(self, sae_name: str, sae_series: str):
364364
return None
365365
return sae["path"]
366366

367+
def get_sae_model_name(self, sae_name: str, sae_series: str) -> Optional[str]:
368+
return self.sae_collection.find_one({"name": sae_name, "series": sae_series})["model_name"]
369+
367370
def add_dataset(self, name: str, cfg: DatasetConfig):
368371
self.dataset_collection.update_one({"name": name}, {"$set": {"cfg": cfg.model_dump()}}, upsert=True)
369372

ui-ssr/.cta.json

Lines changed: 21 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,21 @@
1+
{
2+
"projectName": "ui-ssr",
3+
"mode": "file-router",
4+
"typescript": true,
5+
"tailwind": true,
6+
"packageManager": "bun",
7+
"addOnOptions": {},
8+
"git": true,
9+
"version": 1,
10+
"framework": "react-cra",
11+
"chosenAddOns": [
12+
"eslint",
13+
"nitro",
14+
"start",
15+
"compiler",
16+
"shadcn",
17+
"table",
18+
"store",
19+
"tanstack-query"
20+
]
21+
}

ui-ssr/.cursorrules

Lines changed: 7 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,7 @@
1+
# shadcn instructions
2+
3+
Use the latest version of Shadcn to install new components, like this command to add a button component:
4+
5+
```bash
6+
pnpx shadcn@latest add button
7+
```

ui-ssr/.gitignore

Lines changed: 11 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,11 @@
1+
node_modules
2+
.DS_Store
3+
dist
4+
dist-ssr
5+
*.local
6+
.env
7+
.nitro
8+
.tanstack
9+
.wrangler
10+
.output
11+
.vinxi

ui-ssr/.prettierignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package-lock.json
2+
pnpm-lock.yaml
3+
yarn.lock

ui-ssr/.vscode/settings.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"files.watcherExclude": {
3+
"**/routeTree.gen.ts": true
4+
},
5+
"search.exclude": {
6+
"**/routeTree.gen.ts": true
7+
},
8+
"files.readonlyInclude": {
9+
"**/routeTree.gen.ts": true
10+
}
11+
}

0 commit comments

Comments (0)