Skip to content

Commit 08238fe

Browse files
committed
make class structure in settings
1 parent cb3cba7 commit 08238fe

File tree

1 file changed

+71
-57
lines changed

1 file changed

+71
-57
lines changed

settings.py

Lines changed: 71 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -19,62 +19,76 @@
1919
import os
2020

2121
import yaml
22-
from pydantic_settings import BaseSettings, SettingsConfigDict
23-
24-
25-
def _yaml_config_settings_source(_: BaseSettings) -> Dict[str, Any]:
26-
"""Read config from config.yaml if it exists and map to flat settings
27-
keys."""
28-
path = os.path.join(os.getcwd(), "config.yaml")
29-
if not os.path.exists(path):
30-
return {}
31-
32-
try:
33-
with open(path, "r", encoding="utf-8") as f:
34-
y = yaml.safe_load(f) or {}
35-
except Exception:
36-
return {}
37-
38-
result: Dict[str, Any] = {}
39-
40-
# OpenAI block
41-
if isinstance(y.get("openai"), dict):
42-
o = y["openai"]
43-
if "api_key" in o:
44-
result["openai_api_key"] = o.get("api_key")
45-
if "base_url" in o:
46-
result["openai_base_url"] = o.get("base_url", "")
47-
if "model" in o:
48-
result["openai_model"] = o.get("model", "gpt-4o-mini")
49-
if "max_tokens" in o:
50-
result["max_tokens"] = int(o.get("max_tokens", 6000))
51-
if "temperature" in o:
52-
result["temperature"] = float(o.get("temperature", 0.3))
53-
54-
# Tavily block
55-
if isinstance(y.get("tavily"), dict):
56-
t = y["tavily"]
57-
if "api_key" in t:
58-
result["tavily_api_key"] = t.get("api_key")
59-
60-
# Search block
61-
if isinstance(y.get("search"), dict):
62-
s = y["search"]
63-
if "max_results" in s:
64-
result["max_search_results"] = int(s.get("max_results", 10))
65-
66-
# Execution block
67-
if isinstance(y.get("execution"), dict):
68-
ex = y["execution"]
69-
if "max_rounds" in ex:
70-
result["max_rounds"] = int(ex.get("max_rounds", 8))
71-
if "reports_dir" in ex:
72-
result["reports_directory"] = ex.get("reports_dir", "reports")
73-
if "max_searches_total" in ex:
74-
result["max_searches_total"] = int(ex.get("max_searches_total", 6))
75-
76-
# Note: so_temperature is controlled via env (SO_TEMPERATURE) or default below
77-
return result
22+
from pydantic_settings import (
23+
BaseSettings,
24+
SettingsConfigDict,
25+
PydanticBaseSettingsSource,
26+
)
27+
28+
29+
class YamlConfigSettingsSource(PydanticBaseSettingsSource):
30+
"""Custom settings source that reads from config.yaml file."""
31+
32+
def get_field_value(self, field_info, field_name: str):
33+
# Not used in this implementation
34+
return None
35+
36+
def prepare_field_value(self, field_name: str, value, value_is_complex: bool):
37+
return value
38+
39+
def __call__(self) -> Dict[str, Any]:
40+
"""Read config from config.yaml if it exists and map to flat settings
41+
keys."""
42+
path = os.path.join(os.getcwd(), "config.yaml")
43+
if not os.path.exists(path):
44+
return {}
45+
46+
try:
47+
with open(path, "r", encoding="utf-8") as f:
48+
y = yaml.safe_load(f) or {}
49+
except Exception:
50+
return {}
51+
52+
result: Dict[str, Any] = {}
53+
54+
# OpenAI block
55+
if isinstance(y.get("openai"), dict):
56+
o = y["openai"]
57+
if "api_key" in o:
58+
result["openai_api_key"] = o.get("api_key")
59+
if "base_url" in o:
60+
result["openai_base_url"] = o.get("base_url", "")
61+
if "model" in o:
62+
result["openai_model"] = o.get("model", "gpt-4o-mini")
63+
if "max_tokens" in o:
64+
result["max_tokens"] = int(o.get("max_tokens", 6000))
65+
if "temperature" in o:
66+
result["temperature"] = float(o.get("temperature", 0.3))
67+
68+
# Tavily block
69+
if isinstance(y.get("tavily"), dict):
70+
t = y["tavily"]
71+
if "api_key" in t:
72+
result["tavily_api_key"] = t.get("api_key")
73+
74+
# Search block
75+
if isinstance(y.get("search"), dict):
76+
s = y["search"]
77+
if "max_results" in s:
78+
result["max_search_results"] = int(s.get("max_results", 10))
79+
80+
# Execution block
81+
if isinstance(y.get("execution"), dict):
82+
ex = y["execution"]
83+
if "max_rounds" in ex:
84+
result["max_rounds"] = int(ex.get("max_rounds", 8))
85+
if "reports_dir" in ex:
86+
result["reports_directory"] = ex.get("reports_dir", "reports")
87+
if "max_searches_total" in ex:
88+
result["max_searches_total"] = int(ex.get("max_searches_total", 6))
89+
90+
# Note: so_temperature is controlled via env (SO_TEMPERATURE) or default below
91+
return result
7892

7993

8094
class AppSettings(BaseSettings):
@@ -116,7 +130,7 @@ def settings_customise_sources(
116130
return (
117131
env_settings,
118132
dotenv_settings,
119-
_yaml_config_settings_source,
133+
YamlConfigSettingsSource(settings_cls),
120134
init_settings,
121135
file_secret_settings,
122136
)

0 commit comments

Comments
 (0)