Commit 52cbedf

fix(auth_checks.py): ensure if key has access to aliased model, it still works
Parent: b0c98ef

File tree: 3 files changed, +245 -196 lines

litellm/proxy/_new_secret_config.yaml

Lines changed: 3 additions & 11 deletions
@@ -3,7 +3,7 @@ model_list:
   litellm_params:
     model: openai/fake
     api_key: fake-key
-    api_base: https://webhook.site/4feb0d46-4b23-468c-bf55-7008b5deb36d
+    api_base: https://exampleopenaiendpoint-production.up.railway.app/
 - model_name: gpt-5-mini
   litellm_params:
     model: azure/gpt-5-mini
@@ -14,13 +14,5 @@ model_list:
   model_info:
     mode: chat
 
-litellm_settings:
-  cache: true
-  cache_params:
-    type: redis
-    ttl: 600
-    supported_call_types: ["acompletion", "completion"]
-
-model_group_settings:
-  forward_client_headers_to_llm_api:
-    - fake-openai-endpoint
+router_settings:
+  model_group_alias: {"my-fake-gpt-4": "fake-openai-endpoint"}
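
The new router_settings.model_group_alias entry makes "my-fake-gpt-4" an alias for the "fake-openai-endpoint" model group. A minimal sketch of how a client would exercise the alias through the proxy's OpenAI-compatible API; the base_url and virtual key below are illustrative placeholders, not part of the commit:

# Sketch: call the proxy with the aliased model name.
from openai import OpenAI

client = OpenAI(
    api_key="sk-1234",                 # hypothetical LiteLLM virtual key
    base_url="http://localhost:4000",  # assumed local proxy address
)

# The router resolves "my-fake-gpt-4" -> "fake-openai-endpoint" via
# model_group_alias; this commit ensures the auth check also accepts a key
# whose allow-list names either the alias or the underlying model group.
response = client.chat.completions.create(
    model="my-fake-gpt-4",
    messages=[{"role": "user", "content": "hello"}],
)
print(response.choices[0].message.content)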

litellm/proxy/auth/auth_checks.py

Lines changed: 86 additions & 57 deletions
@@ -513,27 +513,27 @@ async def get_team_membership(
 ) -> Optional["LiteLLM_TeamMembership"]:
     """
     Returns team membership object if user is member of team.
-
+
     Do a isolated check for team membership vs. doing a combined key + team + user + team-membership check, as key might come in frequently for different users/teams. Larger call will slowdown query time. This way we get to cache the constant (key/team/user info) and only update based on the changing value (team membership).
     """
     from litellm.proxy._types import LiteLLM_TeamMembership
-
+
     if prisma_client is None:
         raise Exception("No db connected")
 
     if user_id is None or team_id is None:
         return None
-
+
     _key = "team_membership:{}:{}".format(user_id, team_id)
-
+
     # check if in cache
     cached_membership_obj = await user_api_key_cache.async_get_cache(key=_key)
     if cached_membership_obj is not None:
         if isinstance(cached_membership_obj, dict):
             return LiteLLM_TeamMembership(**cached_membership_obj)
         elif isinstance(cached_membership_obj, LiteLLM_TeamMembership):
             return cached_membership_obj
-
+
     # else, check db
     try:
         response = await prisma_client.db.litellm_teammembership.find_unique(
@@ -545,15 +545,17 @@ async def get_team_membership(
             return None
 
         # save the team membership object to cache
-        await user_api_key_cache.async_set_cache(
-            key=_key, value=response
-        )
+        await user_api_key_cache.async_set_cache(key=_key, value=response)
 
         _response = LiteLLM_TeamMembership(**response.dict())
 
         return _response
     except Exception:
-        verbose_proxy_logger.exception("Error getting team membership for user_id: %s, team_id: %s", user_id, team_id)
+        verbose_proxy_logger.exception(
+            "Error getting team membership for user_id: %s, team_id: %s",
+            user_id,
+            team_id,
+        )
         return None
 
 
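The docstring above spells out the caching design: team membership is cached under its own key, "team_membership:{user_id}:{team_id}", so the relatively static key/team/user objects stay cached and only the membership lookup varies per user/team pair. (The paired bare -/+ lines in these two hunks are trailing-whitespace cleanup.) A toy sketch of that cache-aside pattern, with a plain dict standing in for the proxy's shared async cache and a callback standing in for the Prisma query:

# Toy sketch of the cache-aside lookup used above (not the proxy code).
from typing import Callable, Optional

_cache: dict = {}

def get_membership(
    user_id: str, team_id: str, db_lookup: Callable[[str, str], Optional[dict]]
) -> Optional[dict]:
    key = "team_membership:{}:{}".format(user_id, team_id)  # same key scheme as the diff
    if key in _cache:                        # 1. cache hit: skip the database
        return _cache[key]
    value = db_lookup(user_id, team_id)      # 2. cache miss: query the database
    if value is not None:
        _cache[key] = value                  # 3. populate the cache for next time
    return value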
@@ -1203,57 +1205,14 @@ async def get_org_object(
     )
 
 
-def _can_object_call_model(
-    model: Union[str, List[str]],
+def _check_model_access_helper(
+    model: str,
     llm_router: Optional[Router],
     models: List[str],
     team_model_aliases: Optional[Dict[str, str]] = None,
     team_id: Optional[str] = None,
     object_type: Literal["user", "team", "key", "org"] = "user",
-    fallback_depth: int = 0,
 ) -> Literal[True]:
-    """
-    Checks if token can call a given model
-
-    Args:
-        - model: str
-        - llm_router: Optional[Router]
-        - models: List[str]
-        - team_model_aliases: Optional[Dict[str, str]]
-        - object_type: Literal["user", "team", "key", "org"]. We use the object type to raise the correct exception type
-
-    Returns:
-        - True: if token allowed to call model
-
-    Raises:
-        - Exception: If token not allowed to call model
-    """
-    if fallback_depth >= DEFAULT_MAX_RECURSE_DEPTH:
-        raise Exception(
-            "Unable to parse model, max fallback depth exceeded - received model: {}".format(
-                model
-            )
-        )
-    if isinstance(model, list):
-        for m in model:
-            _can_object_call_model(
-                model=m,
-                llm_router=llm_router,
-                models=models,
-                team_model_aliases=team_model_aliases,
-                team_id=team_id,
-                object_type=object_type,
-                fallback_depth=fallback_depth + 1,
-            )
-        return True
-
-    if model in litellm.model_alias_map:
-        model = litellm.model_alias_map[model]
-    elif llm_router and model in llm_router.model_group_alias:
-        _model = llm_router._get_model_from_alias(model)
-        if _model:
-            model = _model
-
     ## check if model in allowed model names
     from collections import defaultdict

@@ -1301,11 +1260,81 @@ def _can_object_call_model(
         param="model",
         code=status.HTTP_401_UNAUTHORIZED,
     )
+    return True
+
 
-    verbose_proxy_logger.debug(
-        f"filtered allowed_models: {filtered_models}; models: {models}"
+def _can_object_call_model(
+    model: Union[str, List[str]],
+    llm_router: Optional[Router],
+    models: List[str],
+    team_model_aliases: Optional[Dict[str, str]] = None,
+    team_id: Optional[str] = None,
+    object_type: Literal["user", "team", "key", "org"] = "user",
+    fallback_depth: int = 0,
+) -> Literal[True]:
+    """
+    Checks if token can call a given model
+
+    Args:
+        - model: str
+        - llm_router: Optional[Router]
+        - models: List[str]
+        - team_model_aliases: Optional[Dict[str, str]]
+        - object_type: Literal["user", "team", "key", "org"]. We use the object type to raise the correct exception type
+
+    Returns:
+        - True: if token allowed to call model
+
+    Raises:
+        - Exception: If token not allowed to call model
+    """
+    if fallback_depth >= DEFAULT_MAX_RECURSE_DEPTH:
+        raise Exception(
+            "Unable to parse model, max fallback depth exceeded - received model: {}".format(
+                model
+            )
+        )
+    if isinstance(model, list):
+        for m in model:
+            _can_object_call_model(
+                model=m,
+                llm_router=llm_router,
+                models=models,
+                team_model_aliases=team_model_aliases,
+                team_id=team_id,
+                object_type=object_type,
+                fallback_depth=fallback_depth + 1,
+            )
+        return True
+
+    potential_models = [model]
+    if model in litellm.model_alias_map:
+        potential_models.append(litellm.model_alias_map[model])
+    elif llm_router and model in llm_router.model_group_alias:
+        _model = llm_router._get_model_from_alias(model)
+        if _model:
+            potential_models.append(_model)
+
+    ## check model access for alias + underlying model - allow if either is in allowed models
+    for m in potential_models:
+        if _check_model_access_helper(
+            model=m,
+            llm_router=llm_router,
+            models=models,
+            team_model_aliases=team_model_aliases,
+            team_id=team_id,
+            object_type=object_type,
+        ):
+            return True
+
+    raise ProxyException(
+        message=f"{object_type} not allowed to access model. This {object_type} can only access models={models}. Tried to access {model}",
+        type=ProxyErrorTypes.get_model_access_error_type_for_object(
+            object_type=object_type
+        ),
+        param="model",
+        code=status.HTTP_401_UNAUTHORIZED,
     )
-    return True
 
 
 def _model_in_team_aliases(
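
The last two hunks carry the actual fix. Previously, _can_object_call_model overwrote the requested model with its alias resolution before checking the allow-list, so a key whose allow-list contained only the alias (e.g. "my-fake-gpt-4") was rejected once the name resolved to "fake-openai-endpoint". The membership check now lives in _check_model_access_helper, and _can_object_call_model collects both the alias and the resolved model in potential_models, granting access if either passes. A standalone sketch of the before/after behavior, with simplified stand-ins for the router and the access check (not the proxy code):

# Why checking both the alias and the resolved model fixes alias-only keys.
model_group_alias = {"my-fake-gpt-4": "fake-openai-endpoint"}  # from the config diff

def can_call_old(model: str, allowed: list[str]) -> bool:
    # old behavior: resolve the alias in place, then check only the result
    if model in model_group_alias:
        model = model_group_alias[model]
    return model in allowed

def can_call_new(model: str, allowed: list[str]) -> bool:
    # new behavior: check the alias AND the underlying model group
    potential_models = [model]
    if model in model_group_alias:
        potential_models.append(model_group_alias[model])
    return any(m in allowed for m in potential_models)

key_allowed_models = ["my-fake-gpt-4"]  # the key was granted the alias only
assert can_call_old("my-fake-gpt-4", key_allowed_models) is False  # the bug
assert can_call_new("my-fake-gpt-4", key_allowed_models) is True   # the fix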
