Skip to content

Commit 8e76f8e

Browse files
authored
[Feat] Team Member Rate Limits + Support for using with JWT Auth (#13601)
* fix - assign tpm/rpm limit on JWT * add team member rpm/tpm limits * update - rate limiter v3 with team member rate limits * update utils * fixes for LiteLLM_BudgetTable * undo change * add TeamMemberBudgetHandler * add _process_team_member_budget_data * add get_team_membership * add safe_get_team_member_rpm_limit and safe_get_team_member_tpm_limit * LiteLLM_TeamMembership * add LiteLLM_TeamMembership rate limit for JWTs * fix * tests
1 parent 76d2592 commit 8e76f8e

File tree

8 files changed

+461
-99
lines changed

8 files changed

+461
-99
lines changed

litellm/proxy/_types.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,6 +1223,12 @@ class NewTeamRequest(TeamBase):
12231223
team_member_budget: Optional[float] = (
12241224
None # allow user to set a budget for all team members
12251225
)
1226+
team_member_rpm_limit: Optional[int] = (
1227+
None # allow user to set RPM limit for all team members
1228+
)
1229+
team_member_tpm_limit: Optional[int] = (
1230+
None # allow user to set TPM limit for all team members
1231+
)
12261232
team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m"
12271233

12281234
model_config = ConfigDict(protected_namespaces=())
@@ -1266,6 +1272,8 @@ class UpdateTeamRequest(LiteLLMPydanticObjectBase):
12661272
guardrails: Optional[List[str]] = None
12671273
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
12681274
team_member_budget: Optional[float] = None
1275+
team_member_rpm_limit: Optional[int] = None
1276+
team_member_tpm_limit: Optional[int] = None
12691277
team_member_key_duration: Optional[str] = None
12701278

12711279

@@ -1758,10 +1766,14 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
17581766
team_blocked: bool = False
17591767
soft_budget: Optional[float] = None
17601768
team_model_aliases: Optional[Dict] = None
1761-
team_member_spend: Optional[float] = None
17621769
team_member: Optional[Member] = None
17631770
team_metadata: Optional[Dict] = None
17641771

1772+
# Team Member Specific Params
1773+
team_member_spend: Optional[float] = None
1774+
team_member_tpm_limit: Optional[int] = None
1775+
team_member_rpm_limit: Optional[int] = None
1776+
17651777
# End User Params
17661778
end_user_id: Optional[str] = None
17671779
end_user_tpm_limit: Optional[int] = None
@@ -1850,8 +1862,7 @@ def get_litellm_internal_health_check_user_api_key_auth(cls) -> "UserAPIKeyAuth"
18501862
key_alias=LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME,
18511863
team_alias=LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME,
18521864
)
1853-
1854-
1865+
18551866
class UserInfoResponse(LiteLLMPydanticObjectBase):
18561867
user_id: Optional[str]
18571868
user_info: Optional[Union[dict, BaseModel]]
@@ -2620,6 +2631,16 @@ class LiteLLM_TeamMembership(LiteLLMPydanticObjectBase):
26202631
spend: Optional[float] = 0.0
26212632
litellm_budget_table: Optional[LiteLLM_BudgetTable]
26222633

2634+
def safe_get_team_member_rpm_limit(self) -> Optional[int]:
    """
    Return the RPM limit stored on this membership's budget table.

    Returns None when no budget table is attached to the membership.
    """
    budget_table = self.litellm_budget_table
    return None if budget_table is None else budget_table.rpm_limit
2638+
2639+
def safe_get_team_member_tpm_limit(self) -> Optional[int]:
    """
    Return the TPM limit stored on this membership's budget table.

    Returns None when no budget table is attached to the membership.
    """
    budget_table = self.litellm_budget_table
    return None if budget_table is None else budget_table.tpm_limit
2643+
26232644

26242645
#### Organization / Team Member Requests ####
26252646

@@ -2984,6 +3005,7 @@ class JWTAuthBuilderResult(TypedDict):
29843005
user_id: Optional[str]
29853006
end_user_id: Optional[str]
29863007
org_id: Optional[str]
3008+
team_membership: Optional[LiteLLM_TeamMembership]
29873009

29883010

29893011
class ClientSideFallbackModel(TypedDict, total=False):

litellm/proxy/auth/auth_checks.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
LiteLLM_ObjectPermissionTable,
3636
LiteLLM_OrganizationMembershipTable,
3737
LiteLLM_OrganizationTable,
38+
LiteLLM_TeamMembership,
3839
LiteLLM_TeamTable,
3940
LiteLLM_TeamTableCachedObj,
4041
LiteLLM_UserTable,
@@ -501,6 +502,60 @@ def check_in_budget(end_user_obj: LiteLLM_EndUserTable):
501502
return None
502503

503504

505+
@log_db_metrics
async def get_team_membership(
    user_id: str,
    team_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional["LiteLLM_TeamMembership"]:
    """
    Return the team membership object if the user is a member of the team.

    This is an isolated lookup rather than part of a combined
    key + team + user + team-membership query: the same key may arrive
    frequently for different users/teams, and a larger combined call would
    slow down query time. Keeping it separate lets the constant data
    (key/team/user info) stay cached while only the changing value (team
    membership) is refreshed.

    Args:
        user_id: ID of the user whose membership is looked up.
        team_id: ID of the team the user should belong to.
        prisma_client: DB client; required.
        user_api_key_cache: Cache checked before the DB and populated after
            a successful DB read.
        parent_otel_span: Accepted for signature parity; not read here.
        proxy_logging_obj: Accepted for signature parity; not read here.

    Returns:
        The membership (including its budget table), or None when either id
        is missing, no membership row exists, or the DB lookup fails.

    Raises:
        Exception: If ``prisma_client`` is None ("No db connected").
    """
    import logging

    if prisma_client is None:
        raise Exception("No db connected")

    if not user_id or not team_id:
        return None

    _key = f"team_membership:{user_id}:{team_id}"

    # check if in cache
    cached_membership_obj = await user_api_key_cache.async_get_cache(key=_key)
    if isinstance(cached_membership_obj, LiteLLM_TeamMembership):
        return cached_membership_obj
    if isinstance(cached_membership_obj, dict):
        return LiteLLM_TeamMembership(**cached_membership_obj)

    # else, check db
    try:
        response = await prisma_client.db.litellm_teammembership.find_unique(
            where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}},
            include={"litellm_budget_table": True},
        )

        if response is None:
            return None

        membership_dict = response.dict()
        # Validate before caching so a malformed row is never stored.
        _response = LiteLLM_TeamMembership(**membership_dict)

        # Cache the plain dict, not the raw DB object: the cache-hit path
        # above only recognises dict / LiteLLM_TeamMembership values, so a
        # raw DB object stored here would be ignored and the DB re-queried.
        await user_api_key_cache.async_set_cache(key=_key, value=membership_dict)

        return _response
    except Exception:
        # Best-effort lookup: callers treat None as "no membership", but
        # leave a trace so DB/validation failures are not silently hidden.
        logging.getLogger(__name__).warning(
            "get_team_membership failed for user_id=%s team_id=%s",
            user_id,
            team_id,
            exc_info=True,
        )
        return None
557+
558+
504559
def model_in_access_group(
505560
model: str, team_models: Optional[List[str]], llm_router: Optional[Router]
506561
) -> bool:

litellm/proxy/auth/handle_jwt.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
LiteLLM_EndUserTable,
2929
LiteLLM_JWTAuth,
3030
LiteLLM_OrganizationTable,
31+
LiteLLM_TeamMembership,
3132
LiteLLM_TeamTable,
3233
LiteLLM_UserTable,
3334
LitellmUserRoles,
@@ -50,6 +51,7 @@
5051
get_org_object,
5152
get_role_based_models,
5253
get_role_based_routes,
54+
get_team_membership,
5355
get_team_object,
5456
get_user_object,
5557
)
@@ -707,6 +709,7 @@ async def check_admin_access(
707709
user_id=user_id,
708710
end_user_id=None,
709711
org_id=org_id,
712+
team_membership=None,
710713
)
711714

712715
@staticmethod
@@ -839,6 +842,7 @@ async def get_objects(
839842
user_email: Optional[str],
840843
org_id: Optional[str],
841844
end_user_id: Optional[str],
845+
team_id: Optional[str],
842846
valid_user_email: Optional[bool],
843847
jwt_handler: JWTHandler,
844848
prisma_client: Optional[PrismaClient],
@@ -850,6 +854,7 @@ async def get_objects(
850854
Optional[LiteLLM_UserTable],
851855
Optional[LiteLLM_OrganizationTable],
852856
Optional[LiteLLM_EndUserTable],
857+
Optional[LiteLLM_TeamMembership],
853858
]:
854859
"""Get user, org, and end user objects"""
855860
org_object: Optional[LiteLLM_OrganizationTable] = None
@@ -899,8 +904,23 @@ async def get_objects(
899904
if end_user_id
900905
else None
901906
)
907+
908+
team_membership_object: Optional[LiteLLM_TeamMembership] = None
909+
if user_id and team_id:
910+
team_membership_object = (
911+
await get_team_membership(
912+
user_id=user_id,
913+
team_id=team_id,
914+
prisma_client=prisma_client,
915+
user_api_key_cache=user_api_key_cache,
916+
parent_otel_span=parent_otel_span,
917+
proxy_logging_obj=proxy_logging_obj,
918+
)
919+
if user_id and team_id
920+
else None
921+
)
902922

903-
return user_object, org_object, end_user_object
923+
return user_object, org_object, end_user_object, team_membership_object
904924

905925
@staticmethod
906926
def validate_object_id(
@@ -1125,11 +1145,12 @@ async def auth_builder(
11251145
)
11261146

11271147
# Get other objects
1128-
user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
1148+
user_object, org_object, end_user_object, team_membership_object = await JWTAuthManager.get_objects(
11291149
user_id=user_id,
11301150
user_email=user_email,
11311151
org_id=org_id,
11321152
end_user_id=end_user_id,
1153+
team_id=team_id,
11331154
valid_user_email=valid_user_email,
11341155
jwt_handler=jwt_handler,
11351156
prisma_client=prisma_client,
@@ -1165,6 +1186,8 @@ async def auth_builder(
11651186
is_proxy_admin = True
11661187
else:
11671188
is_proxy_admin = False
1189+
1190+
11681191

11691192
return JWTAuthBuilderResult(
11701193
is_proxy_admin=is_proxy_admin,
@@ -1177,4 +1200,5 @@ async def auth_builder(
11771200
end_user_id=end_user_id,
11781201
end_user_object=end_user_object,
11791202
token=api_key,
1203+
team_membership=team_membership_object,
11801204
)

litellm/proxy/auth/user_api_key_auth.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
502502
end_user_object = result["end_user_object"]
503503
org_id = result["org_id"]
504504
token = result["token"]
505+
team_membership: Optional[LiteLLM_TeamMembership] = result.get("team_membership", None)
505506

506507
global_proxy_spend = await get_global_proxy_spend(
507508
litellm_proxy_admin_name=litellm_proxy_admin_name,
@@ -536,6 +537,10 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
536537
org_id=org_id,
537538
parent_otel_span=parent_otel_span,
538539
end_user_id=end_user_id,
540+
user_tpm_limit=user_object.tpm_limit if user_object is not None else None,
541+
user_rpm_limit=user_object.rpm_limit if user_object is not None else None,
542+
team_member_rpm_limit=team_membership.safe_get_team_member_rpm_limit() if team_membership is not None else None,
543+
team_member_tpm_limit=team_membership.safe_get_team_member_tpm_limit() if team_membership is not None else None,
539544
)
540545
# run through common checks
541546
_ = await common_checks(

litellm/proxy/hooks/parallel_request_limiter_v3.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,21 @@ async def async_pre_call_hook(
448448
},
449449
)
450450
)
451+
452+
# Team Member rate limits
453+
if user_api_key_dict.user_id and (user_api_key_dict.team_member_rpm_limit is not None or user_api_key_dict.team_member_tpm_limit is not None):
454+
team_member_value = f"{user_api_key_dict.team_id}:{user_api_key_dict.user_id}"
455+
descriptors.append(
456+
RateLimitDescriptor(
457+
key="team_member",
458+
value=team_member_value,
459+
rate_limit={
460+
"requests_per_unit": user_api_key_dict.team_member_rpm_limit,
461+
"tokens_per_unit": user_api_key_dict.team_member_tpm_limit,
462+
"window_size": self.window_size,
463+
},
464+
)
465+
)
451466

452467
# End user rate limits
453468
if user_api_key_dict.end_user_id and (
@@ -662,6 +677,16 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
662677
total_tokens=total_tokens,
663678
)
664679
)
680+
# Team Member TPM
681+
if user_api_key_team_id and user_api_key_user_id:
682+
pipeline_operations.extend(
683+
self._create_pipeline_operations(
684+
key="team_member",
685+
value=f"{user_api_key_team_id}:{user_api_key_user_id}",
686+
rate_limit_type="tokens",
687+
total_tokens=total_tokens,
688+
)
689+
)
665690

666691
# End User TPM
667692
if user_api_key_end_user_id:

0 commit comments

Comments
 (0)