@@ -513,27 +513,27 @@ async def get_team_membership(
513513) -> Optional ["LiteLLM_TeamMembership" ]:
514514 """
515515 Returns team membership object if user is member of team.
516-
516+
517517 Do a isolated check for team membership vs. doing a combined key + team + user + team-membership check, as key might come in frequently for different users/teams. Larger call will slowdown query time. This way we get to cache the constant (key/team/user info) and only update based on the changing value (team membership).
518518 """
519519 from litellm .proxy ._types import LiteLLM_TeamMembership
520-
520+
521521 if prisma_client is None :
522522 raise Exception ("No db connected" )
523523
524524 if user_id is None or team_id is None :
525525 return None
526-
526+
527527 _key = "team_membership:{}:{}" .format (user_id , team_id )
528-
528+
529529 # check if in cache
530530 cached_membership_obj = await user_api_key_cache .async_get_cache (key = _key )
531531 if cached_membership_obj is not None :
532532 if isinstance (cached_membership_obj , dict ):
533533 return LiteLLM_TeamMembership (** cached_membership_obj )
534534 elif isinstance (cached_membership_obj , LiteLLM_TeamMembership ):
535535 return cached_membership_obj
536-
536+
537537 # else, check db
538538 try :
539539 response = await prisma_client .db .litellm_teammembership .find_unique (
@@ -545,15 +545,17 @@ async def get_team_membership(
545545 return None
546546
547547 # save the team membership object to cache
548- await user_api_key_cache .async_set_cache (
549- key = _key , value = response
550- )
548+ await user_api_key_cache .async_set_cache (key = _key , value = response )
551549
552550 _response = LiteLLM_TeamMembership (** response .dict ())
553551
554552 return _response
555553 except Exception :
556- verbose_proxy_logger .exception ("Error getting team membership for user_id: %s, team_id: %s" , user_id , team_id )
554+ verbose_proxy_logger .exception (
555+ "Error getting team membership for user_id: %s, team_id: %s" ,
556+ user_id ,
557+ team_id ,
558+ )
557559 return None
558560
559561
@@ -1203,57 +1205,14 @@ async def get_org_object(
12031205 )
12041206
12051207
1206- def _can_object_call_model (
1207- model : Union [ str , List [ str ]] ,
1208+ def _check_model_access_helper (
1209+ model : str ,
12081210 llm_router : Optional [Router ],
12091211 models : List [str ],
12101212 team_model_aliases : Optional [Dict [str , str ]] = None ,
12111213 team_id : Optional [str ] = None ,
12121214 object_type : Literal ["user" , "team" , "key" , "org" ] = "user" ,
1213- fallback_depth : int = 0 ,
12141215) -> Literal [True ]:
1215- """
1216- Checks if token can call a given model
1217-
1218- Args:
1219- - model: str
1220- - llm_router: Optional[Router]
1221- - models: List[str]
1222- - team_model_aliases: Optional[Dict[str, str]]
1223- - object_type: Literal["user", "team", "key", "org"]. We use the object type to raise the correct exception type
1224-
1225- Returns:
1226- - True: if token allowed to call model
1227-
1228- Raises:
1229- - Exception: If token not allowed to call model
1230- """
1231- if fallback_depth >= DEFAULT_MAX_RECURSE_DEPTH :
1232- raise Exception (
1233- "Unable to parse model, max fallback depth exceeded - received model: {}" .format (
1234- model
1235- )
1236- )
1237- if isinstance (model , list ):
1238- for m in model :
1239- _can_object_call_model (
1240- model = m ,
1241- llm_router = llm_router ,
1242- models = models ,
1243- team_model_aliases = team_model_aliases ,
1244- team_id = team_id ,
1245- object_type = object_type ,
1246- fallback_depth = fallback_depth + 1 ,
1247- )
1248- return True
1249-
1250- if model in litellm .model_alias_map :
1251- model = litellm .model_alias_map [model ]
1252- elif llm_router and model in llm_router .model_group_alias :
1253- _model = llm_router ._get_model_from_alias (model )
1254- if _model :
1255- model = _model
1256-
12571216 ## check if model in allowed model names
12581217 from collections import defaultdict
12591218
@@ -1301,11 +1260,81 @@ def _can_object_call_model(
13011260 param = "model" ,
13021261 code = status .HTTP_401_UNAUTHORIZED ,
13031262 )
1263+ return True
1264+
13041265
1305- verbose_proxy_logger .debug (
1306- f"filtered allowed_models: { filtered_models } ; models: { models } "
1266+ def _can_object_call_model (
1267+ model : Union [str , List [str ]],
1268+ llm_router : Optional [Router ],
1269+ models : List [str ],
1270+ team_model_aliases : Optional [Dict [str , str ]] = None ,
1271+ team_id : Optional [str ] = None ,
1272+ object_type : Literal ["user" , "team" , "key" , "org" ] = "user" ,
1273+ fallback_depth : int = 0 ,
1274+ ) -> Literal [True ]:
1275+ """
1276+ Checks if token can call a given model
1277+
1278+ Args:
1279+ - model: str
1280+ - llm_router: Optional[Router]
1281+ - models: List[str]
1282+ - team_model_aliases: Optional[Dict[str, str]]
1283+ - object_type: Literal["user", "team", "key", "org"]. We use the object type to raise the correct exception type
1284+
1285+ Returns:
1286+ - True: if token allowed to call model
1287+
1288+ Raises:
1289+ - Exception: If token not allowed to call model
1290+ """
1291+ if fallback_depth >= DEFAULT_MAX_RECURSE_DEPTH :
1292+ raise Exception (
1293+ "Unable to parse model, max fallback depth exceeded - received model: {}" .format (
1294+ model
1295+ )
1296+ )
1297+ if isinstance (model , list ):
1298+ for m in model :
1299+ _can_object_call_model (
1300+ model = m ,
1301+ llm_router = llm_router ,
1302+ models = models ,
1303+ team_model_aliases = team_model_aliases ,
1304+ team_id = team_id ,
1305+ object_type = object_type ,
1306+ fallback_depth = fallback_depth + 1 ,
1307+ )
1308+ return True
1309+
1310+ potential_models = [model ]
1311+ if model in litellm .model_alias_map :
1312+ potential_models .append (litellm .model_alias_map [model ])
1313+ elif llm_router and model in llm_router .model_group_alias :
1314+ _model = llm_router ._get_model_from_alias (model )
1315+ if _model :
1316+ potential_models .append (_model )
1317+
1318+ ## check model access for alias + underlying model - allow if either is in allowed models
1319+ for m in potential_models :
1320+ if _check_model_access_helper (
1321+ model = m ,
1322+ llm_router = llm_router ,
1323+ models = models ,
1324+ team_model_aliases = team_model_aliases ,
1325+ team_id = team_id ,
1326+ object_type = object_type ,
1327+ ):
1328+ return True
1329+
1330+ raise ProxyException (
1331+ message = f"{ object_type } not allowed to access model. This { object_type } can only access models={ models } . Tried to access { model } " ,
1332+ type = ProxyErrorTypes .get_model_access_error_type_for_object (
1333+ object_type = object_type
1334+ ),
1335+ param = "model" ,
1336+ code = status .HTTP_401_UNAUTHORIZED ,
13071337 )
1308- return True
13091338
13101339
13111340def _model_in_team_aliases (
0 commit comments