Skip to content

Commit ab8136d

Browse files
committed
feat: add ActionsClient.wait_for function
This function allows the users to wait for multiple actions in an efficient way. All actions are queried using a single call, which reduce the potential for running into rate limits.
1 parent 24f5008 commit ab8136d

File tree

4 files changed

+185
-1
lines changed

4 files changed

+185
-1
lines changed

hcloud/_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterable, Iterator
4+
from itertools import islice
5+
from typing import TypeVar
6+
7+
T = TypeVar("T")
8+
9+
10+
def batched(iterable: Iterable[T], size: int) -> Iterator[tuple[T, ...]]:
11+
"""
12+
Returns a batch of the provided size from the provided iterable.
13+
"""
14+
iterator = iter(iterable)
15+
while True:
16+
batch = tuple(islice(iterator, size))
17+
if not batch:
18+
break
19+
yield batch

hcloud/actions/client.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import time
44
import warnings
5-
from typing import TYPE_CHECKING, Any, NamedTuple
5+
from typing import TYPE_CHECKING, Any, Callable, NamedTuple
66

7+
from .._utils import batched
78
from ..core import BoundModelBase, ClientEntityBase, Meta
89
from .domain import Action, ActionFailedException, ActionTimeoutException
910

@@ -129,6 +130,102 @@ class ActionsClient(ResourceActionsClient):
129130
def __init__(self, client: Client):
130131
super().__init__(client, None)
131132

133+
def _get_list_by_ids(self, ids: list[int]) -> list[BoundAction]:
134+
"""
135+
Get a list of Actions by their IDs.
136+
137+
:param ids: List of Action IDs to get.
138+
:raises ValueError: Raise when Action IDs were not found.
139+
:return: List of Actions.
140+
"""
141+
actions: list[BoundAction] = []
142+
143+
for ids_batch in batched(ids, 25):
144+
params: dict[str, Any] = {
145+
"id": ids_batch,
146+
}
147+
148+
response = self._client.request(
149+
method="GET",
150+
url="/actions",
151+
params=params,
152+
)
153+
154+
actions.extend(
155+
BoundAction(self._client.actions, action_data)
156+
for action_data in response["actions"]
157+
)
158+
159+
if len(ids) != len(actions):
160+
found_ids = [a.id for a in actions]
161+
not_found_ids = list(set(ids) - set(found_ids))
162+
163+
raise ValueError(
164+
f"actions not found: {', '.join(str(o) for o in not_found_ids)}"
165+
)
166+
167+
return actions
168+
169+
def wait_for_function(
170+
self,
171+
handle_update: Callable[[BoundAction], None],
172+
actions: list[Action | BoundAction],
173+
) -> list[BoundAction]:
174+
"""
175+
Waits until all Actions succeed by polling the API at the interval defined by
176+
the client's poll interval and function. An Action is considered as complete
177+
when its status is either "success" or "error".
178+
179+
The handle_update callback is called every time an Action is updated.
180+
181+
:param handle_update: Function called every time an Action is updated.
182+
:param actions: List of Actions to wait for.
183+
:raises: ActionFailedException when an Action failed.
184+
:return: List of succeeded Actions.
185+
"""
186+
running_ids = [a.id for a in actions]
187+
188+
completed: list[BoundAction] = []
189+
190+
retries = 0
191+
while len(running_ids):
192+
# pylint: disable=protected-access
193+
time.sleep(self._client._poll_interval_func(retries))
194+
retries += 1
195+
196+
updates = self._get_list_by_ids(running_ids)
197+
198+
for update in updates:
199+
if update.status != Action.STATUS_RUNNING:
200+
running_ids.remove(update.id)
201+
completed.append(update)
202+
203+
handle_update(update)
204+
205+
return completed
206+
207+
def wait_for(
208+
self,
209+
actions: list[Action | BoundAction],
210+
) -> list[BoundAction]:
211+
"""
212+
Waits until all Actions succeed by polling the API at the interval defined by
213+
the client's poll interval and function. An Action is considered as complete
214+
when its status is either "success" or "error".
215+
216+
If a single Action fails, the function will stop waiting and raise ActionFailedException.
217+
218+
:param actions: List of Actions to wait for.
219+
:raises: ActionFailedException when an Action failed.
220+
:return: List of succeeded Actions.
221+
"""
222+
223+
def handle_update(update: BoundAction) -> None:
224+
if update.status == Action.STATUS_ERROR:
225+
raise ActionFailedException(action=update)
226+
227+
return self.wait_for_function(handle_update, actions)
228+
132229
def get_list(
133230
self,
134231
status: list[str] | None = None,

tests/unit/actions/test_client.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,64 @@ def test_get_all(self, actions_client, generic_action_list, params):
197197
assert action2._client == actions_client._client.actions
198198
assert action2.id == 2
199199
assert action2.command == "stop_server"
200+
201+
def test_wait_for(self, actions_client: ActionsClient):
202+
actions = [Action(id=1), Action(id=2)]
203+
204+
# Speed up test by not really waiting
205+
actions_client._client._poll_interval_func = mock.MagicMock()
206+
actions_client._client._poll_interval_func.return_value = 0.1
207+
208+
actions_client._client.request.side_effect = [
209+
{
210+
"actions": [
211+
{"id": 1, "status": "running"},
212+
{"id": 2, "status": "success"},
213+
]
214+
},
215+
{
216+
"actions": [
217+
{"id": 1, "status": "success"},
218+
]
219+
},
220+
]
221+
222+
actions = actions_client.wait_for(actions)
223+
224+
actions_client._client.request.assert_has_calls(
225+
[
226+
mock.call(method="GET", url="/actions", params={"id": (1, 2)}),
227+
mock.call(method="GET", url="/actions", params={"id": (1,)}),
228+
]
229+
)
230+
231+
assert len(actions) == 2
232+
233+
def test_wait_for_error(self, actions_client: ActionsClient):
234+
actions = [Action(id=1), Action(id=2)]
235+
236+
# Speed up test by not really waiting
237+
actions_client._client._poll_interval_func = mock.MagicMock()
238+
actions_client._client._poll_interval_func.return_value = 0.1
239+
240+
actions_client._client.request.side_effect = [
241+
{
242+
"actions": [
243+
{"id": 1, "status": "running"},
244+
{
245+
"id": 2,
246+
"status": "error",
247+
"error": {"code": "failed", "message": "Action failed"},
248+
},
249+
]
250+
},
251+
]
252+
253+
with pytest.raises(ActionFailedException):
254+
actions_client.wait_for(actions)
255+
256+
actions_client._client.request.assert_has_calls(
257+
[
258+
mock.call(method="GET", url="/actions", params={"id": (1, 2)}),
259+
]
260+
)

tests/unit/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from __future__ import annotations
2+
3+
from hcloud._utils import batched
4+
5+
6+
def test_batched():
7+
assert list(o for o in batched([1, 2, 3, 4, 5], 2)) == [(1, 2), (3, 4), (5,)]

0 commit comments

Comments
 (0)