Skip to content

Commit 2f3d04d

Browse files
dest1n1sFrankstein73
authored and committed
feat(api): add metric filtering and retrieval for features
- Introduced new endpoints to retrieve available metrics and count features based on specified filters.
- Enhanced existing feature retrieval to support metric filters, allowing for more granular data access.
- Updated the `FeatureRecord` model to include an optional `metric` field for better data representation.
- Implemented frontend logic to manage metric filters and display filtering options in the UI.
1 parent deb5025 commit 2f3d04d

File tree

3 files changed

+343
-21
lines changed

3 files changed

+343
-21
lines changed

server/app.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
import json
23
import os
34
from functools import lru_cache
45
from typing import Any, Optional
@@ -166,11 +167,60 @@ def get_image(dataset_name: str, context_idx: int, image_idx: int, shard_idx: in
166167
return Response(content=img_byte_arr.getvalue(), media_type="image/png")
167168

168169

170+
@app.get("/dictionaries/{name}/metrics")
def get_available_metrics(name: str):
    """List the metric names that can be used to filter a dictionary's features.

    The lookup is delegated to ``client.get_available_metrics`` using the
    ``sae_series`` value in scope.

    Args:
        name: Name of the dictionary/SAE

    Returns:
        A JSON object of the form ``{"metrics": [...]}``
    """
    return {"metrics": client.get_available_metrics(name, sae_series=sae_series)}
182+
183+
184+
# Comparison operators a client is allowed to apply to a metric value.
_ALLOWED_METRIC_OPERATORS = frozenset({"$gt", "$gte", "$lt", "$lte", "$eq", "$ne"})


def _metric_filters_error(parsed: Any) -> str | None:
    """Describe what is wrong with decoded metric filters, or return None if they are valid.

    Valid filters are a JSON object mapping each metric name to a non-empty
    object of comparison operators with numeric operands, e.g.
    ``{"sparsity": {"$gte": 0.1, "$lte": 0.9}}``.
    """
    if not isinstance(parsed, dict):
        return "expected a JSON object mapping metric names to conditions"
    for metric_name, conditions in parsed.items():
        # The name is spliced into a MongoDB field path, so it must not be
        # empty or look like an operator.
        if not metric_name or metric_name.startswith("$"):
            return f"invalid metric name {metric_name!r}"
        if not isinstance(conditions, dict) or not conditions:
            return f"conditions for {metric_name!r} must be a non-empty JSON object"
        for operator, operand in conditions.items():
            if operator not in _ALLOWED_METRIC_OPERATORS:
                return f"unsupported operator {operator!r} for metric {metric_name!r}"
            # bool is a subclass of int, but true/false is not a meaningful bound.
            if isinstance(operand, bool) or not isinstance(operand, (int, float)):
                return f"operand of {operator!r} for metric {metric_name!r} must be a number"
    return None


@app.get("/dictionaries/{name}/features/count")
def count_features_with_filters(
    name: str,
    feature_analysis_name: str | None = None,
    metric_filters: str | None = None,
):
    """Count features that match the given filters.

    Args:
        name: Name of the dictionary/SAE
        feature_analysis_name: Optional analysis name
        metric_filters: Optional JSON string of metric filters in the format
            {"metric_name": {"$gte": value, "$lte": value}}

    Returns:
        ``{"count": <int>}`` on success, or a 400 response when
        ``metric_filters`` is not valid JSON or does not have the expected
        structure.
    """
    parsed_metric_filters = None
    if metric_filters:
        try:
            parsed_metric_filters = json.loads(metric_filters)
        except (json.JSONDecodeError, TypeError, RecursionError):
            # RecursionError: pathologically nested JSON from the client.
            return Response(
                content=f"Invalid metric_filters format: {metric_filters}",
                status_code=400,
            )

        # The decoded value ends up inside a MongoDB query, so reject anything
        # that is not a plain {metric: {comparison-operator: number}} mapping
        # here instead of failing (or misbehaving) in the database layer.
        problem = _metric_filters_error(parsed_metric_filters)
        if problem is not None:
            return Response(
                content=f"Invalid metric_filters: {problem}",
                status_code=400,
            )

    count = client.count_features_with_filters(
        sae_name=name, sae_series=sae_series, name=feature_analysis_name, metric_filters=parsed_metric_filters
    )

    return {"count": count}
216+
217+
169218
@app.get("/dictionaries/{name}/features/{feature_index}")
170219
def get_feature(
171220
name: str,
172221
feature_index: str | int,
173222
feature_analysis_name: str | None = None,
223+
metric_filters: str | None = None,
174224
):
175225
# Parse feature_index if it's a string
176226
if isinstance(feature_index, str) and feature_index != "random":
@@ -182,9 +232,22 @@ def get_feature(
182232
status_code=400,
183233
)
184234

235+
# Parse metric filters if provided
236+
parsed_metric_filters = None
237+
if metric_filters:
238+
try:
239+
parsed_metric_filters = json.loads(metric_filters)
240+
except (json.JSONDecodeError, TypeError):
241+
return Response(
242+
content=f"Invalid metric_filters format: {metric_filters}",
243+
status_code=400,
244+
)
245+
185246
# Get feature data
186247
feature = (
187-
client.get_random_alive_feature(sae_name=name, sae_series=sae_series, name=feature_analysis_name)
248+
client.get_random_alive_feature(
249+
sae_name=name, sae_series=sae_series, name=feature_analysis_name, metric_filters=parsed_metric_filters
250+
)
188251
if feature_index == "random"
189252
else client.get_feature(sae_name=name, sae_series=sae_series, index=feature_index)
190253
)

src/lm_saes/database.py

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class FeatureRecord(BaseModel):
5555
index: int
5656
analyses: list[FeatureAnalysis] = []
5757
interpretation: Optional[dict[str, Any]] = None
58+
metric: Optional[dict[str, float]] = None
5859

5960

6061
class AnalysisRecord(BaseModel):
@@ -235,14 +236,19 @@ def get_sae(self, sae_name: str, sae_series: str) -> Optional[SAERecord]:
235236
return SAERecord.model_validate(sae)
236237

237238
def get_random_alive_feature(
238-
self, sae_name: str, sae_series: str, name: str | None = None
239+
self,
240+
sae_name: str,
241+
sae_series: str,
242+
name: str | None = None,
243+
metric_filters: Optional[dict[str, dict[str, float]]] = None,
239244
) -> Optional[FeatureRecord]:
240245
"""Get a random feature that has non-zero activation.
241246
242247
Args:
243248
sae_name: Name of the SAE model
244249
sae_series: Series of the SAE model
245250
name: Name of the analysis
251+
metric_filters: Optional dict of metric filters in the format {"metric_name": {"$gte": value, "$lte": value}}
246252
247253
Returns:
248254
A random feature record with non-zero activation, or None if no such feature exists
@@ -251,14 +257,19 @@ def get_random_alive_feature(
251257
if name is not None:
252258
elem_match["name"] = name
253259

260+
match_filter: dict[str, Any] = {
261+
"sae_name": sae_name,
262+
"sae_series": sae_series,
263+
"analyses": {"$elemMatch": elem_match},
264+
}
265+
266+
# Add metric filters if provided
267+
if metric_filters:
268+
for metric_name, filters in metric_filters.items():
269+
match_filter[f"metric.{metric_name}"] = filters
270+
254271
pipeline = [
255-
{
256-
"$match": {
257-
"sae_name": sae_name,
258-
"sae_series": sae_series,
259-
"analyses": {"$elemMatch": elem_match},
260-
}
261-
},
272+
{"$match": match_filter},
262273
{"$sample": {"size": 1}},
263274
]
264275
feature = next(self.feature_collection.aggregate(pipeline), None)
@@ -590,3 +601,60 @@ def get_bookmark_count(self, sae_name: Optional[str] = None, sae_series: Optiona
590601
query["sae_series"] = sae_series
591602

592603
return self.bookmark_collection.count_documents(query)
604+
605+
def get_available_metrics(self, sae_name: str, sae_series: str) -> list[str]:
    """List the metric names stored on one feature document of an SAE.

    Only a single document (whichever ``find_one`` returns for the SAE) is
    inspected; its ``metric`` keys are taken as the available metrics.

    Args:
        sae_name: Name of the SAE model
        sae_series: Series of the SAE model

    Returns:
        The keys of that document's ``metric`` mapping, or an empty list when
        no document matches or it carries no ``metric`` value
    """
    query = {"sae_name": sae_name, "sae_series": sae_series}
    # Restrict the projection to `metric` so the rest of the document
    # (e.g. the analyses arrays) is not loaded.
    document = self.feature_collection.find_one(query, {"metric": 1})

    metric = None if document is None else document.get("metric")
    if metric is None:
        return []
    return list(metric.keys())
626+
627+
def count_features_with_filters(
628+
self,
629+
sae_name: str,
630+
sae_series: str,
631+
name: str | None = None,
632+
metric_filters: Optional[dict[str, dict[str, float]]] = None,
633+
) -> int:
634+
"""Count features that match the given filters.
635+
636+
Args:
637+
sae_name: Name of the SAE model
638+
sae_series: Series of the SAE model
639+
name: Name of the analysis
640+
metric_filters: Optional dict of metric filters in the format {"metric_name": {"$gte": value, "$lte": value}}
641+
642+
Returns:
643+
Number of features matching the filters
644+
"""
645+
elem_match: dict[str, Any] = {"max_feature_acts": {"$gt": 0}}
646+
if name is not None:
647+
elem_match["name"] = name
648+
649+
match_filter: dict[str, Any] = {
650+
"sae_name": sae_name,
651+
"sae_series": sae_series,
652+
"analyses": {"$elemMatch": elem_match},
653+
}
654+
655+
# Add metric filters if provided
656+
if metric_filters:
657+
for metric_name, filters in metric_filters.items():
658+
match_filter[f"metric.{metric_name}"] = filters
659+
660+
return self.feature_collection.count_documents(match_filter)

0 commit comments

Comments
 (0)