diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index cf8b3d147f03..92c5b5e4a799 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1223,6 +1223,12 @@ class NewTeamRequest(TeamBase): team_member_budget: Optional[float] = ( None # allow user to set a budget for all team members ) + team_member_rpm_limit: Optional[int] = ( + None # allow user to set RPM limit for all team members + ) + team_member_tpm_limit: Optional[int] = ( + None # allow user to set TPM limit for all team members + ) team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m" model_config = ConfigDict(protected_namespaces=()) @@ -1266,6 +1272,8 @@ class UpdateTeamRequest(LiteLLMPydanticObjectBase): guardrails: Optional[List[str]] = None object_permission: Optional[LiteLLM_ObjectPermissionBase] = None team_member_budget: Optional[float] = None + team_member_rpm_limit: Optional[int] = None + team_member_tpm_limit: Optional[int] = None team_member_key_duration: Optional[str] = None @@ -1758,10 +1766,14 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): team_blocked: bool = False soft_budget: Optional[float] = None team_model_aliases: Optional[Dict] = None - team_member_spend: Optional[float] = None team_member: Optional[Member] = None team_metadata: Optional[Dict] = None + # Team Member Specific Params + team_member_spend: Optional[float] = None + team_member_tpm_limit: Optional[int] = None + team_member_rpm_limit: Optional[int] = None + # End User Params end_user_id: Optional[str] = None end_user_tpm_limit: Optional[int] = None @@ -1850,8 +1862,7 @@ def get_litellm_internal_health_check_user_api_key_auth(cls) -> "UserAPIKeyAuth" key_alias=LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME, team_alias=LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME, ) - - + class UserInfoResponse(LiteLLMPydanticObjectBase): user_id: Optional[str] user_info: Optional[Union[dict, BaseModel]] @@ -2620,6 +2631,16 @@ class 
@log_db_metrics
async def get_team_membership(
    user_id: str,
    team_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
    parent_otel_span: Optional[Span] = None,
    proxy_logging_obj: Optional[ProxyLogging] = None,
) -> Optional["LiteLLM_TeamMembership"]:
    """
    Return the team membership object if the user is a member of the team.

    This is an isolated check for team membership instead of a combined
    key + team + user + team-membership query: a key may come in frequently
    for different users/teams, and the larger call would slow down query
    time. Keeping it separate lets the constant data (key/team/user info)
    stay cached while only the changing value (team membership) is refreshed.

    Args:
        user_id: ID of the user whose membership is looked up.
        team_id: ID of the team the user should belong to.
        prisma_client: DB client; required.
        user_api_key_cache: Cache consulted before, and populated after, the
            DB lookup under the key ``team_membership:<user_id>:<team_id>``.
        parent_otel_span: Not read by this function body.
        proxy_logging_obj: Not read by this function body.

    Returns:
        The membership (with its ``litellm_budget_table`` included), or
        ``None`` when either id is ``None``, no row exists, or the lookup
        fails.

    Raises:
        Exception: If ``prisma_client`` is ``None`` ("No db connected").
    """
    from litellm.proxy._types import LiteLLM_TeamMembership

    if prisma_client is None:
        raise Exception("No db connected")

    if user_id is None or team_id is None:
        return None

    _key = "team_membership:{}:{}".format(user_id, team_id)

    # Cache hit: accept either a plain dict or an already-built model.
    # Any other cached value is ignored and the DB is queried instead.
    cached_membership_obj = await user_api_key_cache.async_get_cache(key=_key)
    if isinstance(cached_membership_obj, dict):
        return LiteLLM_TeamMembership(**cached_membership_obj)
    if isinstance(cached_membership_obj, LiteLLM_TeamMembership):
        return cached_membership_obj

    # Cache miss: fall back to the DB.
    try:
        response = await prisma_client.db.litellm_teammembership.find_unique(
            where={"user_id_team_id": {"user_id": user_id, "team_id": team_id}},
            include={"litellm_budget_table": True},
        )

        if response is None:
            return None

        _response = LiteLLM_TeamMembership(**response.dict())

        # Cache the validated model rather than the raw DB row: the reader
        # above only recognises dict / LiteLLM_TeamMembership, so caching the
        # raw row would never produce a cache hit.
        await user_api_key_cache.async_set_cache(key=_key, value=_response)

        return _response
    except Exception:
        # Deliberately best-effort: a failed membership lookup is reported to
        # the caller as "no membership" instead of failing the request.
        return None
valid_user_email: Optional[bool], jwt_handler: JWTHandler, prisma_client: Optional[PrismaClient], @@ -850,6 +854,7 @@ async def get_objects( Optional[LiteLLM_UserTable], Optional[LiteLLM_OrganizationTable], Optional[LiteLLM_EndUserTable], + Optional[LiteLLM_TeamMembership], ]: """Get user, org, and end user objects""" org_object: Optional[LiteLLM_OrganizationTable] = None @@ -899,8 +904,23 @@ async def get_objects( if end_user_id else None ) + + team_membership_object: Optional[LiteLLM_TeamMembership] = None + if user_id and team_id: + team_membership_object = ( + await get_team_membership( + user_id=user_id, + team_id=team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + if user_id and team_id + else None + ) - return user_object, org_object, end_user_object + return user_object, org_object, end_user_object, team_membership_object @staticmethod def validate_object_id( @@ -1125,11 +1145,12 @@ async def auth_builder( ) # Get other objects - user_object, org_object, end_user_object = await JWTAuthManager.get_objects( + user_object, org_object, end_user_object, team_membership_object = await JWTAuthManager.get_objects( user_id=user_id, user_email=user_email, org_id=org_id, end_user_id=end_user_id, + team_id=team_id, valid_user_email=valid_user_email, jwt_handler=jwt_handler, prisma_client=prisma_client, @@ -1165,6 +1186,8 @@ async def auth_builder( is_proxy_admin = True else: is_proxy_admin = False + + return JWTAuthBuilderResult( is_proxy_admin=is_proxy_admin, @@ -1177,4 +1200,5 @@ async def auth_builder( end_user_id=end_user_id, end_user_object=end_user_object, token=api_key, + team_membership=team_membership_object, ) diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 9efa904574a3..6dea634a804a 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -502,6 +502,7 
@@ async def _user_api_key_auth_builder( # noqa: PLR0915 end_user_object = result["end_user_object"] org_id = result["org_id"] token = result["token"] + team_membership: Optional[LiteLLM_TeamMembership] = result.get("team_membership", None) global_proxy_spend = await get_global_proxy_spend( litellm_proxy_admin_name=litellm_proxy_admin_name, @@ -536,6 +537,10 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 org_id=org_id, parent_otel_span=parent_otel_span, end_user_id=end_user_id, + user_tpm_limit=user_object.tpm_limit if user_object is not None else None, + user_rpm_limit=user_object.rpm_limit if user_object is not None else None, + team_member_rpm_limit=team_membership.safe_get_team_member_rpm_limit() if team_membership is not None else None, + team_member_tpm_limit=team_membership.safe_get_team_member_tpm_limit() if team_membership is not None else None, ) # run through common checks _ = await common_checks( diff --git a/litellm/proxy/hooks/parallel_request_limiter_v3.py b/litellm/proxy/hooks/parallel_request_limiter_v3.py index dde0542c7d23..b73765781ec6 100644 --- a/litellm/proxy/hooks/parallel_request_limiter_v3.py +++ b/litellm/proxy/hooks/parallel_request_limiter_v3.py @@ -448,6 +448,21 @@ async def async_pre_call_hook( }, ) ) + + # Team Member rate limits + if user_api_key_dict.user_id and (user_api_key_dict.team_member_rpm_limit is not None or user_api_key_dict.team_member_tpm_limit is not None): + team_member_value = f"{user_api_key_dict.team_id}:{user_api_key_dict.user_id}" + descriptors.append( + RateLimitDescriptor( + key="team_member", + value=team_member_value, + rate_limit={ + "requests_per_unit": user_api_key_dict.team_member_rpm_limit, + "tokens_per_unit": user_api_key_dict.team_member_tpm_limit, + "window_size": self.window_size, + }, + ) + ) # End user rate limits if user_api_key_dict.end_user_id and ( @@ -662,6 +677,16 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti total_tokens=total_tokens, ) ) + # 
class TeamMemberBudgetHandler:
    """Helper class to handle team member budget, RPM, and TPM limit operations.

    The limits are stored on a budget row whose id is recorded in the team's
    ``metadata["team_member_budget_id"]``.
    """

    # Keys that configure the member budget row; `_clean_team_member_fields`
    # strips them from the team payload.
    _TEAM_MEMBER_FIELDS = (
        "team_member_budget",
        "team_member_rpm_limit",
        "team_member_tpm_limit",
    )

    @staticmethod
    def should_create_budget(
        team_member_budget: Optional[float] = None,
        team_member_rpm_limit: Optional[int] = None,
        team_member_tpm_limit: Optional[int] = None,
    ) -> bool:
        """Return True if at least one team member limit is provided.

        A limit counts as provided when it is not ``None``; ``0`` counts.
        """
        return (
            team_member_budget is not None
            or team_member_rpm_limit is not None
            or team_member_tpm_limit is not None
        )

    @staticmethod
    def _apply_limits(
        budget_request,
        team_member_budget: Optional[float] = None,
        team_member_rpm_limit: Optional[int] = None,
        team_member_tpm_limit: Optional[int] = None,
    ):
        """Copy only the limits that are not ``None`` onto ``budget_request``.

        Attributes for omitted limits are left untouched. Returns the same
        ``budget_request`` object for convenience.
        """
        if team_member_budget is not None:
            budget_request.max_budget = team_member_budget
        if team_member_rpm_limit is not None:
            budget_request.rpm_limit = team_member_rpm_limit
        if team_member_tpm_limit is not None:
            budget_request.tpm_limit = team_member_tpm_limit
        return budget_request

    @staticmethod
    async def create_team_member_budget_table(
        data: Union[NewTeamRequest, LiteLLM_TeamTable],
        new_team_data_json: dict,
        user_api_key_dict: UserAPIKeyAuth,
        team_member_budget: Optional[float] = None,
        team_member_rpm_limit: Optional[int] = None,
        team_member_tpm_limit: Optional[int] = None,
    ) -> dict:
        """Create the team member budget row and link it from the team payload.

        Args:
            data: Team request/row; supplies ``team_alias`` (used in the
                generated budget id) and ``budget_duration``.
            new_team_data_json: Team payload; mutated in place and returned.
            user_api_key_dict: Auth context forwarded to ``new_budget``.
            team_member_budget: Max budget per member, if provided.
            team_member_rpm_limit: RPM limit per member, if provided.
            team_member_tpm_limit: TPM limit per member, if provided.

        Returns:
            ``new_team_data_json`` with ``metadata["team_member_budget_id"]``
            set and the team member limit fields removed.
        """
        from litellm.proxy._types import BudgetNewRequest
        from litellm.proxy.management_endpoints.budget_management_endpoints import (
            new_budget,
        )

        if data.team_alias is not None:
            budget_id = (
                f"team-{data.team_alias.replace(' ', '-')}-budget-{uuid.uuid4().hex}"
            )
        else:
            budget_id = f"team-budget-{uuid.uuid4().hex}"

        # Create budget request with all provided limits
        budget_request = TeamMemberBudgetHandler._apply_limits(
            BudgetNewRequest(
                budget_id=budget_id,
                budget_duration=data.budget_duration,
            ),
            team_member_budget=team_member_budget,
            team_member_rpm_limit=team_member_rpm_limit,
            team_member_tpm_limit=team_member_tpm_limit,
        )

        team_member_budget_table = await new_budget(
            budget_obj=budget_request,
            user_api_key_dict=user_api_key_dict,
        )

        # Add team_member_budget_id as metadata field to team table
        if new_team_data_json.get("metadata") is None:
            new_team_data_json["metadata"] = {}
        new_team_data_json["metadata"][
            "team_member_budget_id"
        ] = team_member_budget_table.budget_id

        # Remove team member fields from new_team_data_json
        TeamMemberBudgetHandler._clean_team_member_fields(new_team_data_json)

        return new_team_data_json

    @staticmethod
    async def upsert_team_member_budget_table(
        team_table: LiteLLM_TeamTable,
        user_api_key_dict: UserAPIKeyAuth,
        updated_kv: dict,
        team_member_budget: Optional[float] = None,
        team_member_rpm_limit: Optional[int] = None,
        team_member_tpm_limit: Optional[int] = None,
    ) -> dict:
        """Update the team member budget row, or create it if none is linked.

        A budget counts as linked when ``team_table.metadata`` holds a string
        ``team_member_budget_id``. ``team_table.metadata`` is replaced with
        ``{}`` when it is ``None``.

        Args:
            team_table: Existing team row.
            user_api_key_dict: Auth context forwarded to the budget endpoints.
            updated_kv: Pending team update payload; mutated and returned.
            team_member_budget: Max budget per member, if provided.
            team_member_rpm_limit: RPM limit per member, if provided.
            team_member_tpm_limit: TPM limit per member, if provided.

        Returns:
            ``updated_kv`` with ``metadata["team_member_budget_id"]`` set and
            the team member limit fields removed.
        """
        from litellm.proxy._types import BudgetNewRequest
        from litellm.proxy.management_endpoints.budget_management_endpoints import (
            update_budget,
        )

        if team_table.metadata is None:
            team_table.metadata = {}

        team_member_budget_id = team_table.metadata.get("team_member_budget_id")
        if team_member_budget_id is not None and isinstance(team_member_budget_id, str):
            # Budget exists - create update request with only provided values
            budget_request = TeamMemberBudgetHandler._apply_limits(
                BudgetNewRequest(budget_id=team_member_budget_id),
                team_member_budget=team_member_budget,
                team_member_rpm_limit=team_member_rpm_limit,
                team_member_tpm_limit=team_member_tpm_limit,
            )

            budget_row = await update_budget(
                budget_obj=budget_request,
                user_api_key_dict=user_api_key_dict,
            )
            verbose_proxy_logger.info(
                f"Updated team member budget table: {budget_row.budget_id}, with team_member_budget={team_member_budget}, team_member_rpm_limit={team_member_rpm_limit}, team_member_tpm_limit={team_member_tpm_limit}"
            )
            # NOTE(review): when updated_kv carries no metadata, this writes a
            # metadata dict holding only team_member_budget_id - confirm the
            # caller merges it with the team's existing metadata.
            if updated_kv.get("metadata") is None:
                updated_kv["metadata"] = {}
            updated_kv["metadata"]["team_member_budget_id"] = budget_row.budget_id

        else:  # budget does not exist
            updated_kv = await TeamMemberBudgetHandler.create_team_member_budget_table(
                data=team_table,
                new_team_data_json=updated_kv,
                user_api_key_dict=user_api_key_dict,
                team_member_budget=team_member_budget,
                team_member_rpm_limit=team_member_rpm_limit,
                team_member_tpm_limit=team_member_tpm_limit,
            )

        # Remove team member fields from updated_kv
        TeamMemberBudgetHandler._clean_team_member_fields(updated_kv)
        return updated_kv

    @staticmethod
    def _clean_team_member_fields(data_dict: dict) -> None:
        """Remove team member fields from data dictionary (in place)."""
        for field_name in TeamMemberBudgetHandler._TEAM_MEMBER_FIELDS:
            data_dict.pop(field_name, None)
- ) - else: - budget_id = f"team-budget-{uuid.uuid4().hex}" - - team_member_budget_table = await new_budget( - budget_obj=BudgetNewRequest( - max_budget=team_member_budget, - budget_duration=data.budget_duration, - budget_id=budget_id, - ), - user_api_key_dict=user_api_key_dict, - ) - - # Add team_member_budget_id as metadata field to team table - if new_team_data_json.get("metadata") is None: - new_team_data_json["metadata"] = {} - new_team_data_json["metadata"][ - "team_member_budget_id" - ] = team_member_budget_table.budget_id - new_team_data_json.pop( - "team_member_budget", None - ) # remove team_member_budget from new_team_data_json - - return new_team_data_json - - -async def _upsert_team_member_budget_table( - team_table: LiteLLM_TeamTable, - user_api_key_dict: UserAPIKeyAuth, - team_member_budget: float, - updated_kv: dict, -) -> dict: - """ - Add budget if none exists - - If budget exists, update it - """ - from litellm.proxy._types import BudgetNewRequest - from litellm.proxy.management_endpoints.budget_management_endpoints import ( - update_budget, - ) - - if team_table.metadata is None: - team_table.metadata = {} - - team_member_budget_id = team_table.metadata.get("team_member_budget_id") - if team_member_budget_id is not None and isinstance(team_member_budget_id, str): - # Budget exists - budget_row = await update_budget( - budget_obj=BudgetNewRequest( - budget_id=team_member_budget_id, - max_budget=team_member_budget, - ), - user_api_key_dict=user_api_key_dict, - ) - verbose_proxy_logger.info( - f"Updated team member budget table: {budget_row.budget_id}, with team_member_budget={team_member_budget}" - ) - if updated_kv.get("metadata") is None: - updated_kv["metadata"] = {} - updated_kv["metadata"]["team_member_budget_id"] = budget_row.budget_id - - else: # budget does not exist - updated_kv = await _create_team_member_budget_table( - data=team_table, - new_team_data_json=updated_kv, - user_api_key_dict=user_api_key_dict, - 
team_member_budget=team_member_budget, - ) - updated_kv.pop("team_member_budget", None) - return updated_kv - #### TEAM MANAGEMENT #### @router.post( @@ -268,6 +312,8 @@ async def new_team( # noqa: PLR0915 - prompts: Optional[List[str]] - List of prompts that the team is allowed to use. - object_permission: Optional[LiteLLM_ObjectPermissionBase] - team-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"]}. IF null or {} then no object permission. - team_member_budget: Optional[float] - The maximum budget allocated to an individual team member. + - team_member_rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for individual team members. + - team_member_tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for individual team members. - team_member_key_duration: Optional[str] - The duration for a team member's key. e.g. "1d", "1w", "1mo" - prompts: Optional[List[str]] - List of allowed prompts for the team. If specified, the team will only be able to use these specific prompts. @@ -421,12 +467,18 @@ async def new_team( # noqa: PLR0915 ## Create Team Member Budget Table data_json = data.json() - if data.team_member_budget is not None: - data_json = await _create_team_member_budget_table( + if TeamMemberBudgetHandler.should_create_budget( + team_member_budget=data.team_member_budget, + team_member_rpm_limit=data.team_member_rpm_limit, + team_member_tpm_limit=data.team_member_tpm_limit, + ): + data_json = await TeamMemberBudgetHandler.create_team_member_budget_table( data=data, new_team_data_json=data_json, user_api_key_dict=user_api_key_dict, team_member_budget=data.team_member_budget, + team_member_rpm_limit=data.team_member_rpm_limit, + team_member_tpm_limit=data.team_member_tpm_limit, ) ## ADD TO TEAM TABLE @@ -705,6 +757,8 @@ async def update_team( - prompts: Optional[List[str]] - List of prompts that the team is allowed to use. 
- object_permission: Optional[LiteLLM_ObjectPermissionBase] - team-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"]}. IF null or {} then no object permission. - team_member_budget: Optional[float] - The maximum budget allocated to an individual team member. + - team_member_rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for individual team members. + - team_member_tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for individual team members. - team_member_key_duration: Optional[str] - The duration for a team member's key. e.g. "1d", "1w", "1mo" Example - update team TPM Limit @@ -797,15 +851,21 @@ async def update_team( # set the budget_reset_at in DB updated_kv["budget_reset_at"] = reset_at - if data.team_member_budget is not None: - updated_kv = await _upsert_team_member_budget_table( + if TeamMemberBudgetHandler.should_create_budget( + team_member_budget=data.team_member_budget, + team_member_rpm_limit=data.team_member_rpm_limit, + team_member_tpm_limit=data.team_member_tpm_limit, + ): + updated_kv = await TeamMemberBudgetHandler.upsert_team_member_budget_table( team_table=existing_team_row, + user_api_key_dict=user_api_key_dict, updated_kv=updated_kv, team_member_budget=data.team_member_budget, - user_api_key_dict=user_api_key_dict, + team_member_rpm_limit=data.team_member_rpm_limit, + team_member_tpm_limit=data.team_member_tpm_limit, ) else: - updated_kv.pop("team_member_budget", None) + TeamMemberBudgetHandler._clean_team_member_fields(updated_kv) # Check object permission if data.object_permission is not None: diff --git a/tests/test_litellm/proxy/auth/test_handle_jwt.py b/tests/test_litellm/proxy/auth/test_handle_jwt.py index 17efbdcf4b3e..10d43141c1fb 100644 --- a/tests/test_litellm/proxy/auth/test_handle_jwt.py +++ b/tests/test_litellm/proxy/auth/test_handle_jwt.py @@ -736,4 +736,114 @@ async def mock_get_team_object(*args, **kwargs): # type: ignore ) assert team_id == "team-1" - assert 
team_obj.team_id == "team-1" \ No newline at end of file + assert team_obj.team_id == "team-1" + + +@pytest.mark.asyncio +async def test_auth_builder_returns_team_membership_object(): + """ + Test that auth_builder returns the team_membership_object when user is a member of a team. + """ + # Setup test data + api_key = "test_jwt_token" + request_data = {"model": "gpt-4"} + general_settings = {"enforce_rbac": False} + route = "/chat/completions" + _team_id = "test_team_1" + _user_id = "test_user_1" + + # Create mock objects + from litellm.proxy._types import LiteLLM_BudgetTable, LiteLLM_TeamMembership + + mock_team_membership = LiteLLM_TeamMembership( + user_id=_user_id, + team_id=_team_id, + budget_id="budget_123", + spend=10.5, + litellm_budget_table=LiteLLM_BudgetTable( + budget_id="budget_123", + rpm_limit=100, + tpm_limit=5000 + ) + ) + + user_object = LiteLLM_UserTable( + user_id=_user_id, + user_role=LitellmUserRoles.INTERNAL_USER + ) + + team_object = LiteLLM_TeamTable(team_id=_team_id) + + # Create mock JWT handler + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth() + + # Mock all the dependencies and method calls + with patch.object( + jwt_handler, "auth_jwt", new_callable=AsyncMock + ) as mock_auth_jwt, patch.object( + JWTAuthManager, "check_rbac_role", new_callable=AsyncMock + ) as mock_check_rbac, patch.object( + jwt_handler, "get_rbac_role", return_value=None + ) as mock_get_rbac, patch.object( + jwt_handler, "get_scopes", return_value=[] + ) as mock_get_scopes, patch.object( + jwt_handler, "get_object_id", return_value=None + ) as mock_get_object_id, patch.object( + JWTAuthManager, + "get_user_info", + new_callable=AsyncMock, + return_value=(_user_id, "test@example.com", True), + ) as mock_get_user_info, patch.object( + jwt_handler, "get_org_id", return_value=None + ) as mock_get_org_id, patch.object( + jwt_handler, "get_end_user_id", return_value=None + ) as mock_get_end_user_id, patch.object( + JWTAuthManager, 
"check_admin_access", new_callable=AsyncMock, return_value=None + ) as mock_check_admin, patch.object( + JWTAuthManager, + "find_and_validate_specific_team_id", + new_callable=AsyncMock, + return_value=(_team_id, team_object), + ) as mock_find_team, patch.object( + JWTAuthManager, "get_all_team_ids", return_value=set() + ) as mock_get_all_team_ids, patch.object( + JWTAuthManager, + "find_team_with_model_access", + new_callable=AsyncMock, + return_value=(None, None), + ) as mock_find_team_access, patch.object( + JWTAuthManager, + "get_objects", + new_callable=AsyncMock, + return_value=(user_object, None, None, mock_team_membership), + ) as mock_get_objects, patch.object( + JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock + ) as mock_map_user, patch.object( + JWTAuthManager, "validate_object_id", return_value=True + ) as mock_validate_object, patch.object( + JWTAuthManager, "sync_user_role_and_teams", new_callable=AsyncMock + ) as mock_sync_user: + # Set up the mock return values + mock_auth_jwt.return_value = {"sub": _user_id, "scope": ""} + + # Call the auth_builder method + result = await JWTAuthManager.auth_builder( + api_key=api_key, + jwt_handler=jwt_handler, + request_data=request_data, + general_settings=general_settings, + route=route, + prisma_client=None, + user_api_key_cache=None, + parent_otel_span=None, + proxy_logging_obj=None, + ) + + # Verify that team_membership_object is returned + assert result["team_membership"] is not None, "team_membership should be present" + assert result["team_membership"] == mock_team_membership, "team_membership should match the mock object" + assert result["team_membership"].user_id == _user_id, "team_membership user_id should match" + assert result["team_membership"].team_id == _team_id, "team_membership team_id should match" + assert result["team_membership"].budget_id == "budget_123", "team_membership budget_id should match" + assert result["team_membership"].spend == 10.5, "team_membership spend should 
@pytest.mark.asyncio
async def test_team_member_rate_limits_v3():
    """
    Test that team member RPM/TPM rate limits are properly applied for team member combinations.

    The pre-call hook is run with `should_rate_limit` replaced by a stub that
    records the descriptors it receives; the test then checks the
    "team_member" descriptor's value and limits.
    """
    _team_id = "team_123"
    _user_id = "user_456"

    user_api_key_dict = UserAPIKeyAuth(
        api_key=hash_token("sk-12345"),
        team_id=_team_id,
        user_id=_user_id,
        team_member_rpm_limit=10,
        team_member_tpm_limit=1000,
    )

    local_cache = DualCache()
    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(local_cache)
    )

    # Stub should_rate_limit to capture the descriptors
    captured_descriptors = None

    async def mock_should_rate_limit(descriptors, **kwargs):
        nonlocal captured_descriptors
        captured_descriptors = descriptors
        # Return OK response to avoid HTTPException
        return {"overall_code": "OK", "statuses": []}

    parallel_request_handler.should_rate_limit = mock_should_rate_limit

    # Test the pre-call hook
    await parallel_request_handler.async_pre_call_hook(
        user_api_key_dict=user_api_key_dict,
        cache=local_cache,
        data={"model": "gpt-3.5-turbo"},
        call_type="",
    )

    # Verify team member descriptor was created
    assert captured_descriptors is not None, "Rate limit descriptors should be captured"

    team_member_descriptor = next(
        (
            descriptor
            for descriptor in captured_descriptors
            if descriptor["key"] == "team_member"
        ),
        None,
    )

    assert team_member_descriptor is not None, "Team member descriptor should be present"
    assert team_member_descriptor["value"] == f"{_team_id}:{_user_id}", "Team member value should combine team_id and user_id"
    assert team_member_descriptor["rate_limit"]["requests_per_unit"] == 10, "Team member RPM limit should be set"
    assert team_member_descriptor["rate_limit"]["tokens_per_unit"] == 1000, "Team member TPM limit should be set"