1- import io
21import json
32import os
4- from functools import lru_cache
3+ import threading
4+ from functools import lru_cache , wraps
55from typing import Any , Generator , Optional
66
77import msgpack
3131
3232app .add_middleware (GZipMiddleware , minimum_size = 1000 )
3333
34- print (os .environ .get ("MONGO_URI" , "mongodb://localhost:27017/" ))
35- print (os .environ .get ("MONGO_DB" , "mechinterp" ))
3634client = MongoClient (MongoDBConfig ())
3735sae_series = os .environ .get ("SAE_SERIES" , "default" )
3836tokenizer_only = os .environ .get ("TOKENIZER_ONLY" , "false" ).lower () == "true"
39- if tokenizer_only :
40- print ("WARNING: Tokenizer only mode is enabled, some features may not be available" )
4137
42- # Remove global caches in favor of LRU cache
43- # sae_cache: dict[str, SparseAutoEncoder] = {}
44- # lm_cache: dict[str, LanguageModel] = {}
45- # dataset_cache: dict[tuple[str, int, int], Dataset] = {}
4638
def synchronized(func):
    """Decorator that serializes calls sharing the same keyword arguments.

    Calls made with different keyword arguments may run in parallel; calls
    made with identical keyword arguments are forced to execute one at a
    time behind a per-key lock. Only keyword arguments are accepted.

    NOTE(review): lock objects accumulate per distinct key and are never
    evicted — acceptable for a small, bounded key space (model/SAE names).
    """
    # One lock per distinct kwargs key; guarded by a registry-wide lock.
    per_key_locks: dict[frozenset[tuple[str, Any]], threading.Lock] = {}
    registry_lock = threading.Lock()

    @wraps(func)
    def wrapper(*args, **kwargs):
        assert len(args) == 0, "Positional arguments are not supported"
        call_key = frozenset(kwargs.items())

        # Lock creation itself must be serialized to avoid racing on the registry.
        with registry_lock:
            call_lock = per_key_locks.setdefault(call_key, threading.Lock())

        with call_lock:
            return func(*args, **kwargs)

    return wrapper
63+
64+
65+ @lru_cache (maxsize = 8 )
66+ @synchronized
67+ def get_model (* , name : str ) -> LanguageModel :
68+ """Load and cache a language model."""
6169 cfg = client .get_model_cfg (name )
6270 if cfg is None :
6371 raise ValueError (f"Model { name } not found" )
@@ -66,38 +74,18 @@ def get_model(name: str) -> LanguageModel:
6674
6775
@lru_cache(maxsize=16)
@synchronized
def get_dataset(*, name: str, shard_idx: int = 0, n_shards: int = 1) -> Dataset:
    """Return a dataset shard, cached per (name, shard_idx, n_shards).

    The `synchronized` decorator prevents the same shard from being loaded
    concurrently; `lru_cache` serves repeat lookups without reloading.
    """
    dataset_cfg = client.get_dataset_cfg(name)
    assert dataset_cfg is not None, f"Dataset {name} not found"
    return load_dataset_shard(dataset_cfg, shard_idx, n_shards)
8683
8784
8885@lru_cache (maxsize = 8 )
89- def get_sae (name : str ) -> SparseAutoEncoder :
90- """Load and cache a sparse autoencoder.
91-
92- Args:
93- name: Name of the SAE to load
94-
95- Returns:
96- SparseAutoEncoder: The loaded SAE
97-
98- Raises:
99- AssertionError: If the SAE is not found
100- """
86+ @synchronized
87+ def get_sae (* , name : str ) -> SparseAutoEncoder :
88+ """Load and cache a sparse autoencoder."""
10189 path = client .get_sae_path (name , sae_series )
10290 assert path is not None , f"SAE { name } not found"
10391 cfg = SAEConfig .from_pretrained (path )
@@ -161,29 +149,6 @@ def list_dictionaries():
161149 return client .list_saes (sae_series = sae_series , has_analyses = True )
162150
163151
164- @app .get ("/images/{dataset_name}" )
165- def get_image (dataset_name : str , context_idx : int , image_idx : int , shard_idx : int = 0 , n_shards : int = 1 ):
166- assert transforms is not None , "torchvision not found, image processing will be disabled"
167- dataset = get_dataset (dataset_name , shard_idx , n_shards )
168- data = dataset [int (context_idx )]
169-
170- image_key = next ((key for key in ["image" , "images" ] if key in data ), None )
171- if image_key is None :
172- return Response (content = "Image not found" , status_code = 404 )
173-
174- if len (data [image_key ]) <= image_idx :
175- return Response (content = "Image not found" , status_code = 404 )
176-
177- image_tensor = data [image_key ][image_idx ]
178-
179- # Convert tensor to PIL Image and then to bytes
180- image = transforms .ToPILImage ()(image_tensor .to (torch .uint8 ))
181- img_byte_arr = io .BytesIO ()
182- image .save (img_byte_arr , format = "PNG" )
183-
184- return Response (content = img_byte_arr .getvalue (), media_type = "image/png" )
185-
186-
187152@app .get ("/dictionaries/{name}/metrics" )
188153def get_available_metrics (name : str ):
189154 """Get available metrics for a dictionary.
@@ -311,9 +276,9 @@ def process_sample(
311276 Returns:
312277 dict: Processed sample data
313278 """ # Get model and dataset
314- model = get_model (model_name )
279+ model = get_model (name = model_name )
315280 # model = None
316- data = get_dataset (dataset_name , shard_idx , n_shards )[context_idx .item ()]
281+ data = get_dataset (name = dataset_name , shard_idx = shard_idx , n_shards = n_shards )[context_idx .item ()]
317282
318283 # Get origins for the features
319284 origins = model .trace ({k : [v ] for k , v in data .items ()})[0 ]
@@ -477,84 +442,6 @@ def process_sparse_feature_acts(
477442 )
478443
479444
480- # @app.post("/dictionaries/{name}/cache_features")
481- # def cache_features(
482- # name: str,
483- # features: list[dict[str, Any]] = Body(..., embed=True),
484- # output_dir: str = Body(...),
485- # ):
486- # """Batch-fetch and persist feature payloads for offline reuse.
487-
488- # Args:
489- # name: Dictionary/SAE name.
490- # features: List of feature specs currently on screen. Each item should contain
491- # - feature_id: int
492- # - layer: int
493- # - is_lorsa: bool
494- # - analysis_name: Optional[str] (overrides auto selection)
495- # output_dir: Directory on the server filesystem to write files into.
496-
497- # Returns:
498- # Dict with count and directory path.
499- # """
500- # os.makedirs(output_dir, exist_ok=True)
501-
502- # saved = 0
503- # for f in features:
504- # feature_id = int(f["feature_id"]) # may raise KeyError which FastAPI will surface
505- # layer = int(f["layer"]) # required for formatting analysis name
506- # is_lorsa = bool(f.get("is_lorsa", False))
507- # analysis_name_override = f.get("analysis_name")
508-
509- # # Determine analysis name for this feature
510- # formatted_analysis_name: str | None = None
511- # if analysis_name_override is not None:
512- # formatted_analysis_name = analysis_name_override
513- # else:
514- # try:
515- # base_name = (
516- # client.get_lorsa_analysis_name(name, sae_series)
517- # if is_lorsa
518- # else client.get_clt_analysis_name(name, sae_series)
519- # )
520- # except AttributeError:
521- # base_name = None
522- # if base_name is None:
523- # feat = client.get_random_alive_feature(sae_name=name, sae_series=sae_series)
524- # if feat is None:
525- # return Response(content=f"Dictionary {name} not found", status_code=404)
526- # available = [a.name for a in feat.analyses]
527- # preferred = [a for a in available if ("lorsa" in a) == is_lorsa]
528- # base_name = preferred[0] if preferred else available[0]
529- # formatted_analysis_name = base_name.replace("{}", str(layer))
530-
531- # # Reuse existing single-feature endpoint logic. Align with frontend usage where
532- # # the path 'name' is the formatted analysis name used by GET /dictionaries/{name}/features/{id}.
533- # res = get_feature(name=formatted_analysis_name, feature_index=feature_id, feature_analysis_name=None)
534- # if isinstance(res, Response) and res.status_code != 200:
535- # # Skip but continue
536- # continue
537-
538- # payload = res.body if isinstance(res, Response) else res
539- # # Write as msgpack for fidelity and also a JSON alongside for convenience
540- # base = os.path.join(output_dir, f"layer{layer}__feature{feature_id}__{formatted_analysis_name}.msgpack")
541- # with open(base, "wb") as fbin:
542- # fbin.write(payload)
543- # try:
544- # decoded = msgpack.unpackb(payload, raw=False)
545- # json_path = base.replace(".msgpack", ".json")
546- # # make_serializable handles tensors/np arrays
547- # import json as _json
548-
549- # with open(json_path, "w") as fj:
550- # _json.dump(make_serializable(decoded), fj)
551- # except Exception:
552- # pass
553- # saved += 1
554-
555- # return {"saved": saved, "output_dir": output_dir}
556-
557-
558445@app .get ("/dictionaries/{name}" )
559446def get_dictionary (name : str ):
560447 # Get feature activation times
@@ -610,6 +497,38 @@ def get_analyses(name: str):
610497 return analyses
611498
612499
@app.post("/dictionaries/{name}/features/{feature_index}/infer")
def infer_feature(name: str, feature_index: int, text: str):
    """Infer feature activations for a given text."""
    # Resolve the language model associated with this SAE.
    model_name = client.get_sae_model_name(name, sae_series)
    assert model_name is not None, f"SAE {name} not found or no model name is associated with it"
    model = get_model(name=model_name)
    sae = get_sae(name=name)

    # Run the model on the single-text batch, then encode through the SAE and
    # keep only the requested feature's activations for batch element 0.
    activations = model.to_activations({"text": [text]}, hook_points=sae.cfg.associated_hook_points)
    activations = sae.normalize_activations(activations)
    x, encoder_kwargs, _ = sae.prepare_input(activations)
    feature_acts = sae.encode(x, **encoder_kwargs)[0, :, feature_index].to_sparse()

    origins = model.trace({"text": [text]})[0]

    # indices()[0] flattens the sparse index matrix for the 1-D activation
    # vector — presumably token positions; verify against the frontend consumer.
    payload = make_serializable(
        {
            "text": text,
            "origins": origins,
            "feature_acts_indices": feature_acts.indices()[0],
            "feature_acts_values": feature_acts.values(),
        }
    )
    return Response(content=msgpack.packb(payload), media_type="application/x-msgpack")
530+
531+
613532@app .post ("/dictionaries/{name}/features/{feature_index}/bookmark" )
614533def add_bookmark (name : str , feature_index : int ):
615534 """Add a bookmark for a feature.
0 commit comments