|
19 | 19 | import os |
20 | 20 |
|
21 | 21 | import yaml |
22 | | -from pydantic_settings import BaseSettings, SettingsConfigDict |
23 | | - |
24 | | - |
25 | | -def _yaml_config_settings_source(_: BaseSettings) -> Dict[str, Any]: |
26 | | - """Read config from config.yaml if it exists and map to flat settings |
27 | | - keys.""" |
28 | | - path = os.path.join(os.getcwd(), "config.yaml") |
29 | | - if not os.path.exists(path): |
30 | | - return {} |
31 | | - |
32 | | - try: |
33 | | - with open(path, "r", encoding="utf-8") as f: |
34 | | - y = yaml.safe_load(f) or {} |
35 | | - except Exception: |
36 | | - return {} |
37 | | - |
38 | | - result: Dict[str, Any] = {} |
39 | | - |
40 | | - # OpenAI block |
41 | | - if isinstance(y.get("openai"), dict): |
42 | | - o = y["openai"] |
43 | | - if "api_key" in o: |
44 | | - result["openai_api_key"] = o.get("api_key") |
45 | | - if "base_url" in o: |
46 | | - result["openai_base_url"] = o.get("base_url", "") |
47 | | - if "model" in o: |
48 | | - result["openai_model"] = o.get("model", "gpt-4o-mini") |
49 | | - if "max_tokens" in o: |
50 | | - result["max_tokens"] = int(o.get("max_tokens", 6000)) |
51 | | - if "temperature" in o: |
52 | | - result["temperature"] = float(o.get("temperature", 0.3)) |
53 | | - |
54 | | - # Tavily block |
55 | | - if isinstance(y.get("tavily"), dict): |
56 | | - t = y["tavily"] |
57 | | - if "api_key" in t: |
58 | | - result["tavily_api_key"] = t.get("api_key") |
59 | | - |
60 | | - # Search block |
61 | | - if isinstance(y.get("search"), dict): |
62 | | - s = y["search"] |
63 | | - if "max_results" in s: |
64 | | - result["max_search_results"] = int(s.get("max_results", 10)) |
65 | | - |
66 | | - # Execution block |
67 | | - if isinstance(y.get("execution"), dict): |
68 | | - ex = y["execution"] |
69 | | - if "max_rounds" in ex: |
70 | | - result["max_rounds"] = int(ex.get("max_rounds", 8)) |
71 | | - if "reports_dir" in ex: |
72 | | - result["reports_directory"] = ex.get("reports_dir", "reports") |
73 | | - if "max_searches_total" in ex: |
74 | | - result["max_searches_total"] = int(ex.get("max_searches_total", 6)) |
75 | | - |
76 | | - # Note: so_temperature is controlled via env (SO_TEMPERATURE) or default below |
77 | | - return result |
| 22 | +from pydantic_settings import ( |
| 23 | + BaseSettings, |
| 24 | + SettingsConfigDict, |
| 25 | + PydanticBaseSettingsSource, |
| 26 | +) |
| 27 | + |
| 28 | + |
| 29 | +class YamlConfigSettingsSource(PydanticBaseSettingsSource): |
| 30 | + """Custom settings source that reads from config.yaml file.""" |
| 31 | + |
| 32 | + def get_field_value(self, field_info, field_name: str): |
| 33 | + # Not used in this implementation |
| 34 | + return None |
| 35 | + |
| 36 | + def prepare_field_value(self, field_name: str, value, value_is_complex: bool): |
| 37 | + return value |
| 38 | + |
| 39 | + def __call__(self) -> Dict[str, Any]: |
| 40 | + """Read config from config.yaml if it exists and map to flat settings |
| 41 | + keys.""" |
| 42 | + path = os.path.join(os.getcwd(), "config.yaml") |
| 43 | + if not os.path.exists(path): |
| 44 | + return {} |
| 45 | + |
| 46 | + try: |
| 47 | + with open(path, "r", encoding="utf-8") as f: |
| 48 | + y = yaml.safe_load(f) or {} |
| 49 | + except Exception: |
| 50 | + return {} |
| 51 | + |
| 52 | + result: Dict[str, Any] = {} |
| 53 | + |
| 54 | + # OpenAI block |
| 55 | + if isinstance(y.get("openai"), dict): |
| 56 | + o = y["openai"] |
| 57 | + if "api_key" in o: |
| 58 | + result["openai_api_key"] = o.get("api_key") |
| 59 | + if "base_url" in o: |
| 60 | + result["openai_base_url"] = o.get("base_url", "") |
| 61 | + if "model" in o: |
| 62 | + result["openai_model"] = o.get("model", "gpt-4o-mini") |
| 63 | + if "max_tokens" in o: |
| 64 | + result["max_tokens"] = int(o.get("max_tokens", 6000)) |
| 65 | + if "temperature" in o: |
| 66 | + result["temperature"] = float(o.get("temperature", 0.3)) |
| 67 | + |
| 68 | + # Tavily block |
| 69 | + if isinstance(y.get("tavily"), dict): |
| 70 | + t = y["tavily"] |
| 71 | + if "api_key" in t: |
| 72 | + result["tavily_api_key"] = t.get("api_key") |
| 73 | + |
| 74 | + # Search block |
| 75 | + if isinstance(y.get("search"), dict): |
| 76 | + s = y["search"] |
| 77 | + if "max_results" in s: |
| 78 | + result["max_search_results"] = int(s.get("max_results", 10)) |
| 79 | + |
| 80 | + # Execution block |
| 81 | + if isinstance(y.get("execution"), dict): |
| 82 | + ex = y["execution"] |
| 83 | + if "max_rounds" in ex: |
| 84 | + result["max_rounds"] = int(ex.get("max_rounds", 8)) |
| 85 | + if "reports_dir" in ex: |
| 86 | + result["reports_directory"] = ex.get("reports_dir", "reports") |
| 87 | + if "max_searches_total" in ex: |
| 88 | + result["max_searches_total"] = int(ex.get("max_searches_total", 6)) |
| 89 | + |
| 90 | + # Note: so_temperature is controlled via env (SO_TEMPERATURE) or default below |
| 91 | + return result |
78 | 92 |
|
79 | 93 |
|
80 | 94 | class AppSettings(BaseSettings): |
@@ -116,7 +130,7 @@ def settings_customise_sources( |
116 | 130 | return ( |
117 | 131 | env_settings, |
118 | 132 | dotenv_settings, |
119 | | - _yaml_config_settings_source, |
| 133 | + YamlConfigSettingsSource(settings_cls), |
120 | 134 | init_settings, |
121 | 135 | file_secret_settings, |
122 | 136 | ) |
|
0 commit comments