Skip to content

Commit af55fd6

Browse files
committed
PR #736 various tweaks based on review notes
- split `Task` hierarchy for better separation of concerns - more tests for some basic classes - clean up unused fixtures from tests - DummyBackend: add setup_job_start_failure
1 parent 1b6f6a2 commit af55fd6

File tree

5 files changed

+146
-80
lines changed

5 files changed

+146
-80
lines changed

openeo/extra/job_management/_thread_worker.py

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Internal utilities to handle job management tasks through threads.
3+
"""
4+
15
import concurrent.futures
26
import logging
37
from abc import ABC, abstractmethod
@@ -9,7 +13,7 @@
913
_log = logging.getLogger(__name__)
1014

1115

12-
@dataclass
16+
@dataclass(frozen=True)
1317
class _TaskResult:
1418
"""
1519
Container for the result of a task execution.
@@ -32,92 +36,86 @@ class _TaskResult:
3236
stats_update: Dict[str, int] = field(default_factory=dict) # Optional
3337

3438

39+
@dataclass(frozen=True)
3540
class Task(ABC):
3641
"""
37-
Abstract base class for asynchronous tasks.
42+
Abstract base class for a unit of work associated with a job (identified by a job id)
43+
and to be processed by :py:class:`_JobManagerWorkerThreadPool`.
44+
45+
Because the work is intended to be executed in a thread/process pool,
46+
it is recommended to keep the state of the task object as simple/immutable as possible
47+
(e.g. just some string/number attributes) and avoid sharing complex objects and state.
3848
39-
A task encapsulates a unit of work, typically executed asynchronously,
40-
and returns a `_TaskResult` with job-related metadata and updates.
49+
The main API for subclasses to implement is the `execute` method
50+
which should return a :py:class:`_TaskResult` object
51+
with job-related metadata and updates.
4152
42-
Implementations must override the `execute` method to define the task logic.
53+
:param job_id:
54+
Identifier of the job to start on the backend.
4355
"""
4456

57+
# TODO: strictly speaking, a job id does not unambiguously identify a job when multiple backends are in play.
58+
job_id: str
59+
4560
@abstractmethod
4661
def execute(self) -> _TaskResult:
4762
"""Execute the task and return a raw result"""
4863
pass
4964

5065

51-
@dataclass
52-
class _JobStartTask(Task):
66+
@dataclass(frozen=True)
67+
class ConnectedTask(Task):
5368
"""
54-
Task for starting a backend job asynchronously.
69+
Base class for tasks that involve an (authenticated) connection to a backend.
5570
56-
Connects to an OpenEO backend using the provided URL and optional token,
57-
retrieves the specified job, and attempts to start it.
58-
59-
Usage example:
60-
61-
.. code-block:: python
62-
63-
task = _JobStartTask(
64-
job_id="1234",
65-
root_url="https://openeo.test",
66-
bearer_token="secret"
67-
)
68-
result = task.execute()
69-
70-
:param job_id:
71-
Identifier of the job to start on the backend.
71+
Backend is specified by a root URL,
72+
and (optional) authentication is done through an openEO-style bearer token.
7273
7374
:param root_url:
7475
The root URL of the OpenEO backend to connect to.
7576
7677
:param bearer_token:
7778
Optional Bearer token used for authentication.
7879
79-
:raises ValueError:
80-
If any of the input parameters are invalid (e.g., empty strings).
8180
"""
8281

83-
job_id: str
8482
root_url: str
8583
bearer_token: Optional[str]
8684

87-
def __post_init__(self) -> None:
88-
# Validation remains unchanged
89-
if not isinstance(self.root_url, str) or not self.root_url.strip():
90-
raise ValueError(f"root_url must be a non-empty string, got {self.root_url!r}")
91-
if self.bearer_token is not None and (not isinstance(self.bearer_token, str) or not self.bearer_token.strip()):
92-
raise ValueError(f"bearer_token must be a non-empty string or None, got {self.bearer_token!r}")
93-
if not isinstance(self.job_id, str) or not self.job_id.strip():
94-
raise ValueError(f"job_id must be a non-empty string, got {self.job_id!r}")
85+
def get_connection(self) -> openeo.Connection:
86+
connection = openeo.connect(self.root_url)
87+
if self.bearer_token:
88+
connection.authenticate_bearer_token(self.bearer_token)
89+
return connection
90+
91+
92+
class _JobStartTask(ConnectedTask):
93+
"""
94+
Task for starting an openEO batch job (the `POST /jobs/<job_id>/result` request).
95+
"""
9596

9697
def execute(self) -> _TaskResult:
9798
"""
98-
Executes the job start process using the OpenEO connection.
99-
100-
Authenticates if a bearer token is provided, retrieves the job by ID,
101-
and attempts to start it.
99+
Start job identified by `job_id` on the backend.
102100
103101
:returns:
104102
A `_TaskResult` with status and statistics metadata, indicating
105103
success or failure of the job start.
106104
"""
105+
# TODO: move main try-except block to base class?
107106
try:
108-
conn = openeo.connect(self.root_url)
109-
if self.bearer_token:
110-
conn.authenticate_bearer_token(self.bearer_token)
111-
job = conn.job(self.job_id)
107+
job = self.get_connection().job(self.job_id)
108+
# TODO: only start when status is "queued"?
112109
job.start()
113-
_log.info(f"Job {self.job_id} started successfully")
110+
_log.info(f"Job {self.job_id!r} started successfully")
114111
return _TaskResult(
115112
job_id=self.job_id,
116113
db_update={"status": "queued"},
117114
stats_update={"job start": 1},
118115
)
119116
except Exception as e:
120-
_log.error(f"Failed to start job {self.job_id}: {e}")
117+
_log.error(f"Failed to start job {self.job_id!r}: {e!r}")
118+
# TODO: more insights about the failure (e.g. the exception) are just logged, but lost from the result
121119
return _TaskResult(
122120
job_id=self.job_id, db_update={"status": "start_failed"}, stats_update={"start_job error": 1}
123121
)
@@ -175,13 +173,13 @@ def process_futures(self) -> List[_TaskResult]:
175173
if future in done:
176174
try:
177175
result = future.result()
178-
179176
except Exception as e:
180-
_log.exception(f"Error processing task: {e}")
177+
_log.exception(f"Failed to get result from future: {e}")
181178
result = _TaskResult(
182-
job_id=task.job_id, db_update={"status": "start_failed"}, stats_update={"start_job error": 1}
179+
job_id=task.job_id,
180+
db_update={"status": "future.result() failed"},
181+
stats_update={"future.result() error": 1},
183182
)
184-
185183
results.append(result)
186184
else:
187185
to_keep.append((future, task))

openeo/rest/_testing.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class DummyBackend:
4848
"next_result",
4949
"next_validation_errors",
5050
"_forced_job_status",
51+
"_fail_on_job_start",
5152
"job_status_updater",
5253
"job_id_generator",
5354
"extra_job_metadata_fields",
@@ -73,6 +74,7 @@ def __init__(
7374
self.next_validation_errors = []
7475
self.extra_job_metadata_fields = []
7576
self._forced_job_status: Dict[str, str] = {}
77+
self._fail_on_job_start = {}
7678

7779
# Job status update hook:
7880
# callable that is called on starting a job, and getting job metadata
@@ -221,24 +223,51 @@ def _get_job_id(self, request) -> str:
221223
assert job_id in self.batch_jobs
222224
return job_id
223225

226+
def _set_job_status(self, job_id: str, status: str):
227+
"""Forced override of job status (e.g. for "canceled" or "error")"""
228+
self.batch_jobs[job_id]["status"] = self._forced_job_status[job_id] = status
229+
224230
def _get_job_status(self, job_id: str, current_status: str) -> str:
225231
if job_id in self._forced_job_status:
226232
return self._forced_job_status[job_id]
227233
return self.job_status_updater(job_id=job_id, current_status=current_status)
228234

235+
def setup_job_start_failure(
236+
self,
237+
*,
238+
job_id: Union[str, None] = None,
239+
status_code: int = 500,
240+
response_body: Union[None, str, dict] = None,
241+
):
242+
"""
243+
Setup for failure when starting a job.
244+
:param job_id: job id to fail on, or None (wildcard) for all jobs
245+
"""
246+
if response_body is None:
247+
response_body = {"code": "Internal", "message": "No job starting for you, buddy"}
248+
if not isinstance(response_body, bytes):
249+
response_body = json.dumps(response_body).encode("utf-8")
250+
self._fail_on_job_start[job_id] = {"status_code": status_code, "response_body": response_body}
251+
229252
def _handle_post_job_results(self, request, context):
230253
"""Handler of `POST /job/{job_id}/results` (start batch job)."""
231254
job_id = self._get_job_id(request)
232255
assert self.batch_jobs[job_id]["status"] == "created"
233-
self.batch_jobs[job_id]["status"] = self._get_job_status(
234-
job_id=job_id, current_status=self.batch_jobs[job_id]["status"]
235-
)
236-
context.status_code = HTTP_202_ACCEPTED
256+
failure = self._fail_on_job_start.get(job_id) or self._fail_on_job_start.get(None)
257+
if not failure:
258+
self.batch_jobs[job_id]["status"] = self._get_job_status(
259+
job_id=job_id, current_status=self.batch_jobs[job_id]["status"]
260+
)
261+
context.status_code = HTTP_202_ACCEPTED
262+
else:
263+
self._set_job_status(job_id=job_id, status="error")
264+
context.status_code = failure["status_code"]
265+
return failure["response_body"]
237266

238267
def _handle_get_job(self, request, context):
239268
"""Handler of `GET /job/{job_id}` (get batch job status and metadata)."""
240269
job_id = self._get_job_id(request)
241-
# Allow updating status with `job_status_setter` once job got past status "created"
270+
# Allow updating status with `job_status_updater` once job got past status "created"
242271
if self.batch_jobs[job_id]["status"] != "created":
243272
self.batch_jobs[job_id]["status"] = self._get_job_status(
244273
job_id=job_id, current_status=self.batch_jobs[job_id]["status"]
@@ -269,8 +298,7 @@ def _handle_get_job_results(self, request, context):
269298
def _handle_delete_job_results(self, request, context):
270299
"""Handler of `DELETE /job/{job_id}/results` (cancel job)."""
271300
job_id = self._get_job_id(request)
272-
self.batch_jobs[job_id]["status"] = "canceled"
273-
self._forced_job_status[job_id] = "canceled"
301+
self._set_job_status(job_id=job_id, status="canceled")
274302
context.status_code = HTTP_204_NO_CONTENT
275303

276304
def _handle_get_job_result_asset(self, request, context):

tests/extra/job_management/test_job_management.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class DummyTask(Task):
9494
"""
9595

9696
def __init__(self, job_id, db_update, stats_update, delay=0.0):
97-
self.job_id = job_id
97+
super().__init__(job_id=job_id)
9898
self._db_update = db_update or {}
9999
self._stats_update = stats_update or {}
100100
self._delay = delay

tests/extra/job_management/test_thread_worker.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,65 @@
1+
import logging
12
import time
23

3-
import pandas as pd
44
import pytest
55
import requests
66

7-
# Import the refactored classes and helper functions from your codebase.
8-
# Adjust the import paths as needed.
97
from openeo.extra.job_management._thread_worker import (
108
_JobManagerWorkerThreadPool,
119
_JobStartTask,
10+
_TaskResult,
1211
)
12+
from openeo.rest._testing import DummyBackend
1313

14-
# --- Fixtures and Helpers ---
14+
15+
@pytest.fixture
16+
def dummy_backend(requests_mock) -> DummyBackend:
17+
dummy = DummyBackend.at_url("https://foo.test", requests_mock=requests_mock)
18+
dummy.setup_simple_job_status_flow(queued=3, running=5)
19+
return dummy
20+
21+
22+
class TestTaskResult:
23+
def test_default(self):
24+
result = _TaskResult(job_id="j-123")
25+
assert result.job_id == "j-123"
26+
assert result.db_update == {}
27+
assert result.stats_update == {}
28+
29+
30+
class TestJobStartTask:
31+
def test_start_success(self, dummy_backend, caplog):
32+
caplog.set_level(logging.WARNING)
33+
job = dummy_backend.connection.create_job(process_graph={})
34+
35+
task = _JobStartTask(job_id=job.job_id, root_url=dummy_backend.connection.root_url, bearer_token="h4ll0")
36+
result = task.execute()
37+
38+
assert result == _TaskResult(
39+
job_id="job-000",
40+
db_update={"status": "queued"},
41+
stats_update={"job start": 1},
42+
)
43+
assert job.status() == "queued"
44+
assert caplog.messages == []
45+
46+
def test_start_failure(self, dummy_backend, caplog):
47+
caplog.set_level(logging.WARNING)
48+
job = dummy_backend.connection.create_job(process_graph={})
49+
dummy_backend.setup_job_start_failure()
50+
51+
task = _JobStartTask(job_id=job.job_id, root_url=dummy_backend.connection.root_url, bearer_token="h4ll0")
52+
result = task.execute()
53+
54+
assert result == _TaskResult(
55+
job_id="job-000",
56+
db_update={"status": "start_failed"},
57+
stats_update={"start_job error": 1},
58+
)
59+
assert job.status() == "error"
60+
assert caplog.messages == [
61+
"Failed to start job 'job-000': OpenEoApiError('[500] Internal: No job starting " "for you, buddy')"
62+
]
1563

1664

1765
@pytest.fixture
@@ -22,23 +70,6 @@ def worker_pool():
2270
pool.shutdown()
2371

2472

25-
@pytest.fixture
26-
def sample_dataframe():
27-
"""Creates a pandas DataFrame for job tracking."""
28-
df = pd.DataFrame(
29-
[
30-
{"id": "job-123", "status": "queued_for_start", "other_field": "foo"},
31-
{"id": "job-456", "status": "queued_for_start", "other_field": "bar"},
32-
{"id": "job-789", "status": "other", "other_field": "baz"},
33-
]
34-
)
35-
return df
36-
37-
38-
@pytest.fixture
39-
def initial_stats():
40-
"""Returns a dictionary with initial stats counters."""
41-
return {"job start": 0, "job start failed": 0}
4273

4374

4475
@pytest.fixture
@@ -47,6 +78,7 @@ def successful_backend_mock(requests_mock):
4778
Returns a helper to set up a successful backend.
4879
Mocks a version check, job start, and job status check.
4980
"""
81+
# TODO: use DummyBackend here instead?
5082

5183
def _setup(root_url: str, job_id: str, status: str = "queued"):
5284
# Backend version check
@@ -67,7 +99,6 @@ def valid_task():
6799
return _JobStartTask(root_url="https://foo.test", bearer_token="test-token", job_id="test-job-123")
68100

69101

70-
import time
71102

72103

73104
def wait_for_results(worker_pool, timeout=3.0, interval=0.1):
@@ -85,7 +116,6 @@ def wait_for_results(worker_pool, timeout=3.0, interval=0.1):
85116
raise TimeoutError(f"Timed out after {timeout}s waiting for worker pool results.")
86117

87118

88-
# --- Tests for the Worker Thread Pool and Futures Postprocessing ---
89119

90120

91121
class TestJobManagerWorkerThreadPool:

tests/rest/test_testing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import re
2+
13
import pytest
24

5+
from openeo.rest import OpenEoApiError
36
from openeo.rest._testing import DummyBackend
47

58

@@ -94,3 +97,10 @@ def test_setup_simple_job_status_flow_final_per_job(self, dummy_backend, con120)
9497
assert job0.status() == "finished"
9598
assert job1.status() == "error"
9699
assert job2.status() == "finished"
100+
101+
def test_setup_job_start_failure(self, dummy_backend):
102+
job = dummy_backend.connection.create_job(process_graph={})
103+
dummy_backend.setup_job_start_failure()
104+
with pytest.raises(OpenEoApiError, match=re.escape("[500] Internal: No job starting for you, buddy")):
105+
job.start()
106+
assert job.status() == "error"

0 commit comments

Comments
 (0)