diff --git a/runpod/cli/groups/config/functions.py b/runpod/cli/groups/config/functions.py index ba6adffe..6b7c21bb 100644 --- a/runpod/cli/groups/config/functions.py +++ b/runpod/cli/groups/config/functions.py @@ -9,6 +9,7 @@ from pathlib import Path import tomli as toml +import tomlkit CREDENTIAL_FILE = os.path.expanduser("~/.runpod/config.toml") @@ -30,16 +31,23 @@ def set_credentials(api_key: str, profile: str = "default", overwrite=False) -> os.makedirs(os.path.dirname(CREDENTIAL_FILE), exist_ok=True) Path(CREDENTIAL_FILE).touch(exist_ok=True) + with open(CREDENTIAL_FILE, "r", encoding="UTF-8") as cred_file: + try: + content = cred_file.read() + config = tomlkit.parse(content) if content.strip() else tomlkit.document() + except tomlkit.exceptions.ParseError as exc: + raise ValueError("~/.runpod/config.toml is not a valid TOML file.") from exc + if not overwrite: - with open(CREDENTIAL_FILE, "rb") as cred_file: - if profile in toml.load(cred_file): - raise ValueError( - "Profile already exists. Use `update_credentials` instead." - ) + if profile in config: + raise ValueError( + "Profile already exists. Use `update_credentials` instead." + ) + + config[profile] = {"api_key": api_key} with open(CREDENTIAL_FILE, "w", encoding="UTF-8") as cred_file: - cred_file.write("[" + profile + "]\n") - cred_file.write('api_key = "' + api_key + '"\n') + tomlkit.dump(config, cred_file) def check_credentials(profile: str = "default"): diff --git a/tests/test_cli/test_cli_groups/test_config_functions.py b/tests/test_cli/test_cli_groups/test_config_functions.py index 14c8418a..8c29b8f8 100644 --- a/tests/test_cli/test_cli_groups/test_config_functions.py +++ b/tests/test_cli/test_cli_groups/test_config_functions.py @@ -14,19 +14,31 @@ class TestConfig(unittest.TestCase): def setUp(self) -> None: self.sample_credentials = "[default]\n" 'api_key = "RUNPOD_API_KEY"\n' - @patch("runpod.cli.groups.config.functions.toml.load") - @patch("builtins.open", new_callable=mock_open()) - def test_set_credentials(self, mock_file, mock_toml_load): + @patch("runpod.cli.groups.config.functions.tomlkit.dump") + @patch("runpod.cli.groups.config.functions.tomlkit.document") + @patch("builtins.open", new_callable=mock_open, read_data="") + def test_set_credentials(self, mock_file, mock_document, mock_dump): """ Tests the set_credentials function. """ - mock_toml_load.return_value = "" + mock_document.side_effect = [{}, {"default": True}] functions.set_credentials("RUNPOD_API_KEY") - mock_file.assert_called_with(functions.CREDENTIAL_FILE, "w", encoding="UTF-8") + assert any( + call.args[0] == functions.CREDENTIAL_FILE + and call.args[1] == "r" + and call.kwargs.get("encoding") == "UTF-8" + for call in mock_file.call_args_list + ) + assert any( + call.args[0] == functions.CREDENTIAL_FILE + and call.args[1] == "w" + and call.kwargs.get("encoding") == "UTF-8" + for call in mock_file.call_args_list + ) + assert mock_dump.called with self.assertRaises(ValueError) as context: - mock_toml_load.return_value = {"default": True} functions.set_credentials("RUNPOD_API_KEY") self.assertEqual(