|
9 | 9 | import importlib |
10 | 10 | import warnings |
11 | 11 | import backoff |
| 12 | +from argon2 import PasswordHasher |
| 13 | + |
| 14 | +ph = PasswordHasher() |
12 | 15 |
|
13 | 16 |
|
14 | 17 | def showwarning(message, category, filename, lineno, file=None, line=None): |
@@ -257,6 +260,7 @@ async def openai_exception_handler(request: Request, exc: ProxyException): |
257 | 260 | ui_access_mode: Literal["admin", "all"] = "all" |
258 | 261 | proxy_budget_rescheduler_min_time = 597 |
259 | 262 | proxy_budget_rescheduler_max_time = 605 |
| 263 | +litellm_master_key_hash = None |
260 | 264 | ### INITIALIZE GLOBAL LOGGING OBJECT ### |
261 | 265 | proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) |
262 | 266 | ### REDIS QUEUE ### |
@@ -336,31 +340,36 @@ async def user_api_key_auth( |
336 | 340 | Unprotected endpoints |
337 | 341 | """ |
338 | 342 | return UserAPIKeyAuth() |
| 343 | + elif route.startswith("/config/"): |
| 344 | + raise Exception(f"Only admin can modify config") |
339 | 345 |
|
340 | 346 | if api_key is None: # only require api key if master key is set |
341 | 347 | raise Exception(f"No api key passed in.") |
342 | 348 |
|
343 | | - if secrets.compare_digest(api_key, ""): |
| 349 | + if api_key == "": |
344 | 350 | # missing 'Bearer ' prefix |
345 | 351 | raise Exception( |
346 | 352 | f"Malformed API Key passed in. Ensure Key has `Bearer ` prefix. Passed in: {passed_in_key}" |
347 | 353 | ) |
348 | 354 |
|
349 | 355 | ### CHECK IF ADMIN ### |
350 | 356 | # note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead |
351 | | - is_master_key_valid = secrets.compare_digest(api_key, master_key) |
| 357 | + try: |
| 358 | + is_master_key_valid = ph.verify(litellm_master_key_hash, api_key) |
| 359 | + except Exception as e: |
| 360 | + is_master_key_valid = False |
| 361 | + |
352 | 362 | if is_master_key_valid: |
353 | 363 | return UserAPIKeyAuth( |
354 | 364 | api_key=master_key, |
355 | 365 | user_role="proxy_admin", |
356 | 366 | user_id=litellm_proxy_admin_name, |
357 | 367 | ) |
| 368 | + |
358 | 369 | if isinstance( |
359 | 370 | api_key, str |
360 | 371 | ): # if generated token, make sure it starts with sk-. |
361 | 372 | assert api_key.startswith("sk-") # prevent token hashes from being used |
362 | | - if route.startswith("/config/") and not is_master_key_valid: |
363 | | - raise Exception(f"Only admin can modify config") |
364 | 373 |
|
365 | 374 | if ( |
366 | 375 | prisma_client is None and custom_db_client is None |
@@ -1494,7 +1503,7 @@ async def load_config( |
1494 | 1503 | """ |
1495 | 1504 | Load config values into proxy global state |
1496 | 1505 | """ |
1497 | | - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode |
| 1506 | + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash |
1498 | 1507 |
|
1499 | 1508 | # Load existing config |
1500 | 1509 | config = await self.get_config(config_file_path=config_file_path) |
@@ -1759,6 +1768,9 @@ async def load_config( |
1759 | 1768 | ) |
1760 | 1769 | if master_key and master_key.startswith("os.environ/"): |
1761 | 1770 | master_key = litellm.get_secret(master_key) |
| 1771 | + |
| 1772 | + if master_key is not None and isinstance(master_key, str): |
| 1773 | + litellm_master_key_hash = ph.hash(master_key) |
1762 | 1774 | ### CUSTOM API KEY AUTH ### |
1763 | 1775 | ## pass filepath |
1764 | 1776 | custom_auth = general_settings.get("custom_auth", None) |
@@ -2837,6 +2849,7 @@ async def chat_completion( |
2837 | 2849 | response = await proxy_logging_obj.post_call_success_hook( |
2838 | 2850 | user_api_key_dict=user_api_key_dict, response=response |
2839 | 2851 | ) |
| 2852 | + |
2840 | 2853 | return response |
2841 | 2854 | except Exception as e: |
2842 | 2855 | traceback.print_exc() |
@@ -7032,42 +7045,45 @@ async def health_endpoint( |
7032 | 7045 | else, the health checks will be run on models when /health is called. |
7033 | 7046 | """ |
7034 | 7047 | global health_check_results, use_background_health_checks, user_model |
| 7048 | + try: |
| 7049 | + if llm_model_list is None: |
| 7050 | + # if no router set, check if user set a model using litellm --model ollama/llama2 |
| 7051 | + if user_model is not None: |
| 7052 | + healthy_endpoints, unhealthy_endpoints = await perform_health_check( |
| 7053 | + model_list=[], cli_model=user_model |
| 7054 | + ) |
| 7055 | + return { |
| 7056 | + "healthy_endpoints": healthy_endpoints, |
| 7057 | + "unhealthy_endpoints": unhealthy_endpoints, |
| 7058 | + "healthy_count": len(healthy_endpoints), |
| 7059 | + "unhealthy_count": len(unhealthy_endpoints), |
| 7060 | + } |
| 7061 | + raise HTTPException( |
| 7062 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 7063 | + detail={"error": "Model list not initialized"}, |
| 7064 | + ) |
7035 | 7065 |
|
7036 | | - if llm_model_list is None: |
7037 | | - # if no router set, check if user set a model using litellm --model ollama/llama2 |
7038 | | - if user_model is not None: |
| 7066 | + ### FILTER MODELS FOR ONLY THOSE USER HAS ACCESS TO ### |
| 7067 | + if len(user_api_key_dict.models) > 0: |
| 7068 | + allowed_model_names = user_api_key_dict.models |
| 7069 | + else: |
| 7070 | + allowed_model_names = [] # |
| 7071 | + if use_background_health_checks: |
| 7072 | + return health_check_results |
| 7073 | + else: |
7039 | 7074 | healthy_endpoints, unhealthy_endpoints = await perform_health_check( |
7040 | | - model_list=[], cli_model=user_model |
| 7075 | + llm_model_list, model |
7041 | 7076 | ) |
| 7077 | + |
7042 | 7078 | return { |
7043 | 7079 | "healthy_endpoints": healthy_endpoints, |
7044 | 7080 | "unhealthy_endpoints": unhealthy_endpoints, |
7045 | 7081 | "healthy_count": len(healthy_endpoints), |
7046 | 7082 | "unhealthy_count": len(unhealthy_endpoints), |
7047 | 7083 | } |
7048 | | - raise HTTPException( |
7049 | | - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
7050 | | - detail={"error": "Model list not initialized"}, |
7051 | | - ) |
7052 | | - |
7053 | | - ### FILTER MODELS FOR ONLY THOSE USER HAS ACCESS TO ### |
7054 | | - if len(user_api_key_dict.models) > 0: |
7055 | | - allowed_model_names = user_api_key_dict.models |
7056 | | - else: |
7057 | | - allowed_model_names = [] # |
7058 | | - if use_background_health_checks: |
7059 | | - return health_check_results |
7060 | | - else: |
7061 | | - healthy_endpoints, unhealthy_endpoints = await perform_health_check( |
7062 | | - llm_model_list, model |
7063 | | - ) |
7064 | | - |
7065 | | - return { |
7066 | | - "healthy_endpoints": healthy_endpoints, |
7067 | | - "unhealthy_endpoints": unhealthy_endpoints, |
7068 | | - "healthy_count": len(healthy_endpoints), |
7069 | | - "unhealthy_count": len(unhealthy_endpoints), |
7070 | | - } |
| 7084 | + except Exception as e: |
| 7085 | + traceback.print_exc() |
| 7086 | + raise e |
7071 | 7087 |
|
7072 | 7088 |
|
7073 | 7089 | @router.get( |
|
0 commit comments