Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,12 @@ class NewTeamRequest(TeamBase):
team_member_budget: Optional[float] = (
None # allow user to set a budget for all team members
)
team_member_rpm_limit: Optional[int] = (
None # allow user to set RPM limit for all team members
)
team_member_tpm_limit: Optional[int] = (
None # allow user to set TPM limit for all team members
)
team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m"

model_config = ConfigDict(protected_namespaces=())
Expand Down Expand Up @@ -1266,6 +1272,8 @@ class UpdateTeamRequest(LiteLLMPydanticObjectBase):
guardrails: Optional[List[str]] = None
object_permission: Optional[LiteLLM_ObjectPermissionBase] = None
team_member_budget: Optional[float] = None
team_member_rpm_limit: Optional[int] = None
team_member_tpm_limit: Optional[int] = None
team_member_key_duration: Optional[str] = None


Expand Down Expand Up @@ -1758,10 +1766,14 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
team_blocked: bool = False
soft_budget: Optional[float] = None
team_model_aliases: Optional[Dict] = None
team_member_spend: Optional[float] = None
team_member: Optional[Member] = None
team_metadata: Optional[Dict] = None

# Team Member Specific Params
team_member_spend: Optional[float] = None
team_member_tpm_limit: Optional[int] = None
team_member_rpm_limit: Optional[int] = None

# End User Params
end_user_id: Optional[str] = None
end_user_tpm_limit: Optional[int] = None
Expand Down Expand Up @@ -1850,8 +1862,7 @@ def get_litellm_internal_health_check_user_api_key_auth(cls) -> "UserAPIKeyAuth"
key_alias=LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME,
team_alias=LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME,
)



class UserInfoResponse(LiteLLMPydanticObjectBase):
user_id: Optional[str]
user_info: Optional[Union[dict, BaseModel]]
Expand Down Expand Up @@ -2620,6 +2631,16 @@ class LiteLLM_TeamMembership(LiteLLMPydanticObjectBase):
spend: Optional[float] = 0.0
litellm_budget_table: Optional[LiteLLM_BudgetTable]

def safe_get_team_member_rpm_limit(self) -> Optional[int]:
if self.litellm_budget_table is not None:
return self.litellm_budget_table.rpm_limit
return None

def safe_get_team_member_tpm_limit(self) -> Optional[int]:
if self.litellm_budget_table is not None:
return self.litellm_budget_table.tpm_limit
return None


#### Organization / Team Member Requests ####

Expand Down Expand Up @@ -2984,6 +3005,7 @@ class JWTAuthBuilderResult(TypedDict):
user_id: Optional[str]
end_user_id: Optional[str]
org_id: Optional[str]
team_membership: Optional[LiteLLM_TeamMembership]


class ClientSideFallbackModel(TypedDict, total=False):
Expand Down
55 changes: 55 additions & 0 deletions litellm/proxy/auth/auth_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
LiteLLM_ObjectPermissionTable,
LiteLLM_OrganizationMembershipTable,
LiteLLM_OrganizationTable,
LiteLLM_TeamMembership,
LiteLLM_TeamTable,
LiteLLM_TeamTableCachedObj,
LiteLLM_UserTable,
Expand Down Expand Up @@ -501,6 +502,60 @@ def check_in_budget(end_user_obj: LiteLLM_EndUserTable):
return None


@log_db_metrics
async def get_team_membership(
user_id: str,
team_id: str,
prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache,
parent_otel_span: Optional[Span] = None,
proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional["LiteLLM_TeamMembership"]:
"""
Returns team membership object if user is member of team.

Do a isolated check for team membership vs. doing a combined key + team + user + team-membership check, as key might come in frequently for different users/teams. Larger call will slowdown query time. This way we get to cache the constant (key/team/user info) and only update based on the changing value (team membership).
"""
from litellm.proxy._types import LiteLLM_TeamMembership

if prisma_client is None:
raise Exception("No db connected")

if user_id is None or team_id is None:
return None

_key = "team_membership:{}:{}".format(user_id, team_id)

# check if in cache
cached_membership_obj = await user_api_key_cache.async_get_cache(key=_key)
if cached_membership_obj is not None:
if isinstance(cached_membership_obj, dict):
return LiteLLM_TeamMembership(**cached_membership_obj)
elif isinstance(cached_membership_obj, LiteLLM_TeamMembership):
return cached_membership_obj

# else, check db
try:
response = await prisma_client.db.litellm_teammembership.find_unique(
where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}},
include={"litellm_budget_table": True},
)

if response is None:
return None

# save the team membership object to cache
await user_api_key_cache.async_set_cache(
key=_key, value=response
)

_response = LiteLLM_TeamMembership(**response.dict())

return _response
except Exception:
return None


def model_in_access_group(
model: str, team_models: Optional[List[str]], llm_router: Optional[Router]
) -> bool:
Expand Down
28 changes: 26 additions & 2 deletions litellm/proxy/auth/handle_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
LiteLLM_OrganizationTable,
LiteLLM_TeamMembership,
LiteLLM_TeamTable,
LiteLLM_UserTable,
LitellmUserRoles,
Expand All @@ -50,6 +51,7 @@
get_org_object,
get_role_based_models,
get_role_based_routes,
get_team_membership,
get_team_object,
get_user_object,
)
Expand Down Expand Up @@ -707,6 +709,7 @@ async def check_admin_access(
user_id=user_id,
end_user_id=None,
org_id=org_id,
team_membership=None,
)

@staticmethod
Expand Down Expand Up @@ -839,6 +842,7 @@ async def get_objects(
user_email: Optional[str],
org_id: Optional[str],
end_user_id: Optional[str],
team_id: Optional[str],
valid_user_email: Optional[bool],
jwt_handler: JWTHandler,
prisma_client: Optional[PrismaClient],
Expand All @@ -850,6 +854,7 @@ async def get_objects(
Optional[LiteLLM_UserTable],
Optional[LiteLLM_OrganizationTable],
Optional[LiteLLM_EndUserTable],
Optional[LiteLLM_TeamMembership],
]:
"""Get user, org, and end user objects"""
org_object: Optional[LiteLLM_OrganizationTable] = None
Expand Down Expand Up @@ -899,8 +904,23 @@ async def get_objects(
if end_user_id
else None
)

team_membership_object: Optional[LiteLLM_TeamMembership] = None
if user_id and team_id:
team_membership_object = (
await get_team_membership(
user_id=user_id,
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
if user_id and team_id
else None
)

return user_object, org_object, end_user_object
return user_object, org_object, end_user_object, team_membership_object

@staticmethod
def validate_object_id(
Expand Down Expand Up @@ -1125,11 +1145,12 @@ async def auth_builder(
)

# Get other objects
user_object, org_object, end_user_object = await JWTAuthManager.get_objects(
user_object, org_object, end_user_object, team_membership_object = await JWTAuthManager.get_objects(
user_id=user_id,
user_email=user_email,
org_id=org_id,
end_user_id=end_user_id,
team_id=team_id,
valid_user_email=valid_user_email,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
Expand Down Expand Up @@ -1165,6 +1186,8 @@ async def auth_builder(
is_proxy_admin = True
else:
is_proxy_admin = False



return JWTAuthBuilderResult(
is_proxy_admin=is_proxy_admin,
Expand All @@ -1177,4 +1200,5 @@ async def auth_builder(
end_user_id=end_user_id,
end_user_object=end_user_object,
token=api_key,
team_membership=team_membership_object,
)
5 changes: 5 additions & 0 deletions litellm/proxy/auth/user_api_key_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
end_user_object = result["end_user_object"]
org_id = result["org_id"]
token = result["token"]
team_membership: Optional[LiteLLM_TeamMembership] = result.get("team_membership", None)

global_proxy_spend = await get_global_proxy_spend(
litellm_proxy_admin_name=litellm_proxy_admin_name,
Expand Down Expand Up @@ -536,6 +537,10 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
org_id=org_id,
parent_otel_span=parent_otel_span,
end_user_id=end_user_id,
user_tpm_limit=user_object.tpm_limit if user_object is not None else None,
user_rpm_limit=user_object.rpm_limit if user_object is not None else None,
team_member_rpm_limit=team_membership.safe_get_team_member_rpm_limit() if team_membership is not None else None,
team_member_tpm_limit=team_membership.safe_get_team_member_tpm_limit() if team_membership is not None else None,
)
# run through common checks
_ = await common_checks(
Expand Down
25 changes: 25 additions & 0 deletions litellm/proxy/hooks/parallel_request_limiter_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,21 @@ async def async_pre_call_hook(
},
)
)

# Team Member rate limits
if user_api_key_dict.user_id and (user_api_key_dict.team_member_rpm_limit is not None or user_api_key_dict.team_member_tpm_limit is not None):
team_member_value = f"{user_api_key_dict.team_id}:{user_api_key_dict.user_id}"
descriptors.append(
RateLimitDescriptor(
key="team_member",
value=team_member_value,
rate_limit={
"requests_per_unit": user_api_key_dict.team_member_rpm_limit,
"tokens_per_unit": user_api_key_dict.team_member_tpm_limit,
"window_size": self.window_size,
},
)
)

# End user rate limits
if user_api_key_dict.end_user_id and (
Expand Down Expand Up @@ -662,6 +677,16 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
total_tokens=total_tokens,
)
)
# Team Member TPM
if user_api_key_team_id and user_api_key_user_id:
pipeline_operations.extend(
self._create_pipeline_operations(
key="team_member",
value=f"{user_api_key_team_id}:{user_api_key_user_id}",
rate_limit_type="tokens",
total_tokens=total_tokens,
)
)

# End User TPM
if user_api_key_end_user_id:
Expand Down
Loading
Loading