
Commit 5e5de21

PR #736: address some review notes (exclude bearer token from repr) and some code style tweaks

1 parent e33347c commit 5e5de21

File tree: 4 files changed, +70 -59 lines changed

openeo/extra/job_management/__init__.py

Lines changed: 20 additions & 22 deletions
@@ -553,6 +553,7 @@ def _job_update_loop(
        not_started = job_db.get_by_status(statuses=["not_started"], max=200).copy()
        if len(not_started) > 0:
            # Check number of jobs running at each backend
+           # TODO: should "created" be included in here? Calling this "running" is quite misleading then.
            running = job_db.get_by_status(statuses=["created", "queued", "queued_for_start", "running"])
            stats["job_db get_by_status"] += 1
            per_backend = running.groupby("backend_name").size().to_dict()
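
The `per_backend` line above tallies active jobs per backend with a pandas groupby. A minimal, self-contained sketch of that pattern, with made-up job rows standing in for the `job_db` query result:

```python
import pandas as pd

# Hypothetical stand-in for the "running" query result above.
running = pd.DataFrame(
    {
        "id": ["j-0", "j-1", "j-2"],
        "backend_name": ["foo", "bar", "foo"],
        "status": ["queued", "running", "created"],
    }
)

# Count jobs per backend, as in the `per_backend` line above.
per_backend = running.groupby("backend_name").size().to_dict()
assert per_backend == {"bar": 1, "foo": 2}
```
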
@@ -570,7 +571,7 @@ def _job_update_loop(
                        stats["job_db persist"] += 1
                        total_added += 1

-       self._process_threadworker_updates(self._worker_pool, job_db, stats)
+       self._process_threadworker_updates(self._worker_pool, job_db=job_db, stats=stats)

        # TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads?
        for job, row in jobs_done:
@@ -641,7 +642,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
                root_url=job_con.root_url,
                bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
                job_id=job.job_id,
-               df_idx = i
+               df_idx=i,
            )
            _log.info(f"Submitting task {task} to thread pool")
            self._worker_pool.submit_task(task)
@@ -659,8 +660,9 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No

    def _process_threadworker_updates(
        self,
-       worker_pool: '_JobManagerWorkerThreadPool',
-       job_db: 'JobDatabaseInterface',
+       worker_pool: _JobManagerWorkerThreadPool,
+       *,
+       job_db: JobDatabaseInterface,
        stats: Dict[str, int],
    ) -> None:
        """
@@ -669,8 +671,6 @@ def _process_threadworker_updates(
        (matched by df_idx) are upserted via job_db.persist(). Any results
        targeting unknown df_idx indices are logged as errors but not persisted.

-
-
        :param worker_pool: Thread-pool managing asynchronous Task executes
        :param job_db: Interface to append/upsert to the job database
        :param stats: Dictionary accumulating statistic counters
@@ -684,44 +684,43 @@ def _process_threadworker_updates(
            # Process database updates
            if res.db_update:
                try:
-                   updates.append({
-                       'id': res.job_id,
-                       'df_idx': res.df_idx,
-                       **res.db_update,
-                   })
+                   updates.append(
+                       {
+                           "id": res.job_id,
+                           "df_idx": res.df_idx,
+                           **res.db_update,
+                       }
+                   )
                except Exception as e:
-                   _log.error(f"Skipping invalid db_update '{res.db_update}' for job '{res.job_id}': {e}", )
-
+                   _log.error(f"Skipping invalid db_update {res.db_update!r} for job {res.job_id!r}: {e}")
+
            # Process stats updates
            if res.stats_update:
                try:
                    for key, val in res.stats_update.items():
                        count = int(val)
                        stats[key] = stats.get(key, 0) + count
                except Exception as e:
-                   _log.error(
-                       f"Skipping invalid stats_update {res.stats_update} for job '{res.job_id}': {e}"
-                   )
+                   _log.error(f"Skipping invalid stats_update {res.stats_update!r} for job {res.job_id!r}: {e}")

        # No valid updates: nothing to persist
        if not updates:
            return

        # Build DataFrame of updates indexed by df_idx
-       df_updates = pd.DataFrame(updates).set_index('df_idx', drop=True)
+       df_updates = pd.DataFrame(updates).set_index("df_idx", drop=True)

        # Determine which rows to upsert
        existing_indices = set(df_updates.index).intersection(job_db.read().index)
        if existing_indices:
            df_upsert = df_updates.loc[sorted(existing_indices)]
            job_db.persist(df_upsert)
-           stats['job_db persist'] = stats.get('job_db persist', 0) + 1
+           stats["job_db persist"] = stats.get("job_db persist", 0) + 1

        # Any df_idx not in original index are errors
        missing = set(df_updates.index) - existing_indices
        if missing:
-           _log.error(f"Skipping non-existing dataframe indiches: {sorted(missing)}")
-
+           _log.error(f"Skipping non-existing dataframe indices: {sorted(missing)}")

    def on_job_done(self, job: BatchJob, row):
        """
@@ -977,12 +976,11 @@ def get_by_status(self, statuses, max=None) -> pd.DataFrame:

    def _merge_into_df(self, df: pd.DataFrame):
        if self._df is not None:
-           self._df.update(df, overwrite=True)
+           self._df.update(df, overwrite=True)
        else:
            self._df = df


-
class CsvJobDatabase(FullDataFrameJobDatabase):
    """
    Persist/load job metadata with a CSV file.
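
`DataFrame.update`, as used in `_merge_into_df` above, mutates the caller in place, aligning on index and columns; `overwrite=True` is pandas' default, just spelled out explicitly. A small sketch of the semantics:

```python
import pandas as pd

df = pd.DataFrame({"id": ["j-0", "j-1"], "status": ["created", "created"]})
patch = pd.DataFrame({"status": ["queued"]}, index=[1])

# In-place merge: only overlapping index/columns are touched;
# NaN entries in `patch` would leave `df` unchanged.
df.update(patch, overwrite=True)
assert list(df["status"]) == ["created", "queued"]
```
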

openeo/extra/job_management/_thread_worker.py

Lines changed: 8 additions & 8 deletions
@@ -35,7 +35,7 @@ class _TaskResult:
    """

    job_id: str  # Mandatory
-   df_idx: int # Mandatory
+   df_idx: int  # Mandatory
    db_update: Dict[str, Any] = field(default_factory=dict)  # Optional
    stats_update: Dict[str, int] = field(default_factory=dict)  # Optional
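
Context for the `field(default_factory=dict)` defaults above: dataclasses reject a plain mutable default like `= {}`, and `default_factory` gives each instance its own fresh dict. A minimal sketch, with a hypothetical stand-in class:

```python
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class ResultSketch:  # hypothetical stand-in for _TaskResult
    job_id: str
    db_update: Dict[str, Any] = field(default_factory=dict)  # fresh dict per instance

a, b = ResultSketch("j-0"), ResultSketch("j-1")
a.db_update["status"] = "queued"
assert b.db_update == {}  # defaults are not shared between instances
```
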

@@ -56,14 +56,14 @@ class Task(ABC):

    :param job_id:
        Identifier of the job to start on the backend.
-
+
    :param df_idx:
        Index of the row of the job in the dataframe.

    """

    job_id: str
-   df_idx: int
+   df_idx: int

    @abstractmethod
    def execute(self) -> _TaskResult:
@@ -88,7 +88,7 @@ class ConnectedTask(Task):
    """

    root_url: str
-   bearer_token: Optional[str]
+   bearer_token: Optional[str] = field(default=None, repr=False)

    def get_connection(self) -> openeo.Connection:
        connection = openeo.connect(self.root_url)
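
This is the core review note of the commit: `field(default=None, repr=False)` keeps the bearer token out of the dataclass's generated `__repr__`, so stringifying a task (e.g. in the "Submitting task ..." log line) no longer leaks the secret. A minimal sketch of the effect, with a hypothetical class name:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class TaskSketch:  # hypothetical stand-in for ConnectedTask
    root_url: str
    bearer_token: Optional[str] = field(default=None, repr=False)

task = TaskSketch(root_url="https://example.com", bearer_token="Secret!")
assert "Secret!" not in repr(task)
# A dataclass's __str__ falls back to __repr__, hence the repr/str
# parametrization in the new test_hide_token test below.
assert "Secret!" not in str(task)
```
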
@@ -118,7 +118,7 @@ def execute(self) -> _TaskResult:
            _log.info(f"Job {self.job_id!r} started successfully")
            return _TaskResult(
                job_id=self.job_id,
-               df_idx = self.df_idx,
+               df_idx=self.df_idx,
                db_update={"status": "queued"},
                stats_update={"job start": 1},
            )
@@ -127,9 +127,9 @@ def execute(self) -> _TaskResult:
            # TODO: more insights about the failure (e.g. the exception) are just logged, but lost from the result
            return _TaskResult(
                job_id=self.job_id,
-               df_idx = self.df_idx,
+               df_idx=self.df_idx,
                db_update={"status": "start_failed"},
-               stats_update={"start_job error": 1}
+               stats_update={"start_job error": 1},
            )
@@ -188,7 +188,7 @@ def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskRe
                _log.exception(f"Threaded task {task!r} failed: {e!r}")
                result = _TaskResult(
                    job_id=task.job_id,
-                   df_idx = task.df_idx,
+                   df_idx=task.df_idx,
                    db_update={"status": "threaded task failed"},
                    stats_update={"threaded task failed": 1},
                )
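
For context, `process_futures` returns the results of finished tasks plus a count of still-pending ones, wrapping crashed tasks in the failure `_TaskResult` shown above. A rough sketch of such a poll-and-collect loop with `concurrent.futures` (a simplification under assumed semantics, not the actual pool implementation):

```python
import concurrent.futures
from typing import List, Tuple

def drain(futures: List[concurrent.futures.Future], timeout: float = 0) -> Tuple[list, int]:
    # Wait up to `timeout` seconds, then split futures into done/pending.
    done, not_done = concurrent.futures.wait(futures, timeout=timeout)
    results = []
    for future in done:
        try:
            results.append(future.result())
        except Exception as e:
            # A crashed task still yields a result, flagging the failure.
            results.append({"status": "threaded task failed", "error": repr(e)})
    return results, len(not_done)
```
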

tests/extra/job_management/test_job_management.py

Lines changed: 12 additions & 12 deletions
@@ -94,15 +94,14 @@ class DummyTask(Task):
    """

    def __init__(self, job_id, df_idx, db_update, stats_update):
-       super().__init__(job_id=job_id, df_idx = df_idx)
+       super().__init__(job_id=job_id, df_idx=df_idx)
        self._db_update = db_update or {}
        self._stats_update = stats_update or {}

    def execute(self) -> _TaskResult:
-
        return _TaskResult(
            job_id=self.job_id,
-           df_idx = self.df_idx,
+           df_idx=self.df_idx,
            db_update=self._db_update,
            stats_update=self._stats_update,
        )
@@ -739,7 +738,7 @@ def get_status(job_id, current_status):
        # Mock sleep() to skip one hour at a time instead of actually sleeping
        with mock.patch.object(openeo.extra.job_management.time, "sleep", new=lambda s: time_machine.shift(60 * 60)):
            job_manager.run_jobs(df=df, start_job=self._create_year_job, job_db=job_db_path)
-
+
        final_df = CsvJobDatabase(job_db_path).read()

        # Validate running_start_time is a valid datetime object
@@ -759,10 +758,12 @@ def test_process_threadworker_updates(self, tmp_path, caplog):
        # Invalid index (not in DB)
        pool.submit_task(DummyTask("j-missing", df_idx=4, db_update={"status": "created"}, stats_update=None))

-       df_initial = pd.DataFrame({
-           "id": ["j-0", "j-1", "j-2", "j-3"],
-           "status": ["created", "created", "created", "created"],
-       })
+       df_initial = pd.DataFrame(
+           {
+               "id": ["j-0", "j-1", "j-2", "j-3"],
+               "status": ["created", "created", "created", "created"],
+           }
+       )
        job_db = CsvJobDatabase(tmp_path / "jobs.csv").initialize_from_df(df_initial)

        mgr = MultiBackendJobManager(root_dir=tmp_path / "jobs")
@@ -786,7 +787,7 @@ def test_process_threadworker_updates(self, tmp_path, caplog):
        assert stats["job_db persist"] == 1

        # Assert error log for invalid index
-       assert any("Skipping non-existing dataframe indiches" in msg for msg in caplog.messages)
+       assert any("Skipping non-existing dataframe indices" in msg for msg in caplog.messages)

    def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog):
        pool = _JobManagerWorkerThreadPool(max_workers=2)
@@ -796,13 +797,12 @@ def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog):
        job_db = CsvJobDatabase(tmp_path / "jobs.csv").initialize_from_df(df_initial)
        mgr = MultiBackendJobManager(root_dir=tmp_path / "jobs")

-       mgr._process_threadworker_updates(pool, job_db, stats)
+       mgr._process_threadworker_updates(pool, job_db=job_db, stats=stats)

        df_final = job_db.read()
        assert df_final.loc[0, "status"] == "created"
        assert stats == {}

-
    def test_logs_on_invalid_update(self, tmp_path, caplog):
        pool = _JobManagerWorkerThreadPool(max_workers=2)
        stats = collections.defaultdict(int)
@@ -824,7 +824,7 @@ def execute(self):
        mgr = MultiBackendJobManager(root_dir=tmp_path / "jobs")

        with caplog.at_level(logging.ERROR):
-           mgr._process_threadworker_updates(pool, job_db, stats)
+           mgr._process_threadworker_updates(pool, job_db=job_db, stats=stats)

        # DB should remain unchanged
        df_final = job_db.read()

tests/extra/job_management/test_thread_worker.py

Lines changed: 30 additions & 17 deletions
@@ -25,9 +25,9 @@ def dummy_backend(requests_mock) -> DummyBackend:

class TestTaskResult:
    def test_default(self):
-       result = _TaskResult(job_id="j-123", df_idx = 0)
+       result = _TaskResult(job_id="j-123", df_idx=0)
        assert result.job_id == "j-123"
-       assert result.df_idx ==0
+       assert result.df_idx == 0
        assert result.db_update == {}
        assert result.stats_update == {}

@@ -37,12 +37,14 @@ def test_start_success(self, dummy_backend, caplog):
        caplog.set_level(logging.WARNING)
        job = dummy_backend.connection.create_job(process_graph={})

-       task = _JobStartTask(job_id=job.job_id, df_idx=0, root_url=dummy_backend.connection.root_url, bearer_token="h4ll0")
+       task = _JobStartTask(
+           job_id=job.job_id, df_idx=0, root_url=dummy_backend.connection.root_url, bearer_token="h4ll0"
+       )
        result = task.execute()

        assert result == _TaskResult(
            job_id="job-000",
-           df_idx = 0,
+           df_idx=0,
            db_update={"status": "queued"},
            stats_update={"job start": 1},
        )
@@ -54,7 +56,9 @@ def test_start_failure(self, dummy_backend, caplog):
        job = dummy_backend.connection.create_job(process_graph={})
        dummy_backend.setup_job_start_failure()

-       task = _JobStartTask(job_id=job.job_id, df_idx=0, root_url=dummy_backend.connection.root_url, bearer_token="h4ll0")
+       task = _JobStartTask(
+           job_id=job.job_id, df_idx=0, root_url=dummy_backend.connection.root_url, bearer_token="h4ll0"
+       )
        result = task.execute()

        assert result == _TaskResult(
@@ -68,7 +72,13 @@ def test_start_failure(self, dummy_backend, caplog):
            "Failed to start job 'job-000': OpenEoApiError('[500] Internal: No job starting " "for you, buddy')"
        ]

-
+   @pytest.mark.parametrize("serializer", [repr, str])
+   def test_hide_token(self, serializer):
+       secret = "Secret!"
+       task = _JobStartTask(job_id="job-123", df_idx=0, root_url="https://example.com", bearer_token=secret)
+       serialized = serializer(task)
+       assert "job-123" in serialized
+       assert secret not in serialized


class NopTask(Task):
@@ -107,8 +117,6 @@ def execute(self) -> _TaskResult:
        return _TaskResult(job_id=self.job_id, df_idx=self.df_idx, db_update={"status": "all fine"})


-
-
class TestJobManagerWorkerThreadPool:
    @pytest.fixture
    def worker_pool(self) -> Iterator[_JobManagerWorkerThreadPool]:
@@ -146,7 +154,7 @@ def test_submit_and_process_with_error(self, worker_pool):
        assert results == [
            _TaskResult(
                job_id="j-666",
-               df_idx = 0,
+               df_idx=0,
                db_update={"status": "threaded task failed"},
                stats_update={"threaded task failed": 1},
            ),
@@ -163,7 +171,7 @@ def test_submit_and_process_iterative(self, worker_pool):
        worker_pool.submit_task(NopTask(job_id="j-22", df_idx=22))
        worker_pool.submit_task(NopTask(job_id="j-222", df_idx=222))
        results, remaining = worker_pool.process_futures(timeout=1)
-       assert results == [_TaskResult(job_id="j-22", df_idx=22), _TaskResult(job_id="j-222", df_idx=222)]
+       assert results == [_TaskResult(job_id="j-22", df_idx=22), _TaskResult(job_id="j-222", df_idx=222)]
        assert remaining == 0

    def test_submit_multiple_simple(self, worker_pool):
@@ -204,7 +212,7 @@ def test_submit_multiple_blocking_and_failing(self, worker_pool):
        events[0].set()
        results, remaining = worker_pool.process_futures(timeout=0.1)
        assert results == [
-           _TaskResult(job_id="j-0", df_idx = 0, db_update={"status": "all fine"}),
+           _TaskResult(job_id="j-0", df_idx=0, db_update={"status": "all fine"}),
        ]
        assert remaining == n - 1

@@ -213,10 +221,13 @@ def test_submit_multiple_blocking_and_failing(self, worker_pool):
        events[j].set()
        results, remaining = worker_pool.process_futures(timeout=0.1)
        assert results == [
-           _TaskResult(job_id="j-1", df_idx = 1, db_update={"status": "all fine"}),
-           _TaskResult(job_id="j-2", df_idx = 2, db_update={"status": "all fine"}),
+           _TaskResult(job_id="j-1", df_idx=1, db_update={"status": "all fine"}),
+           _TaskResult(job_id="j-2", df_idx=2, db_update={"status": "all fine"}),
            _TaskResult(
-               job_id="j-3", df_idx = 3, db_update={"status": "threaded task failed"}, stats_update={"threaded task failed": 1}
+               job_id="j-3",
+               df_idx=3,
+               db_update={"status": "threaded task failed"},
+               stats_update={"threaded task failed": 1},
            ),
        ]
        assert remaining == 1
@@ -226,7 +237,7 @@ def test_submit_multiple_blocking_and_failing(self, worker_pool):
        events[j].set()
        results, remaining = worker_pool.process_futures(timeout=0.1)
        assert results == [
-           _TaskResult(job_id="j-4", df_idx = 4, db_update={"status": "all fine"}),
+           _TaskResult(job_id="j-4", df_idx=4, db_update={"status": "all fine"}),
        ]
        assert remaining == 0

@@ -252,7 +263,7 @@ def test_job_start_task(self, worker_pool, dummy_backend, caplog):
        assert results == [
            _TaskResult(
                job_id="job-000",
-               df_idx = 0,
+               df_idx=0,
                db_update={"status": "queued"},
                stats_update={"job start": 1},
            )
@@ -270,7 +281,9 @@ def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog):

        results, remaining = worker_pool.process_futures(timeout=1)
        assert results == [
-           _TaskResult(job_id="job-000", df_idx=0, db_update={"status": "start_failed"}, stats_update={"start_job error": 1})
+           _TaskResult(
+               job_id="job-000", df_idx=0, db_update={"status": "start_failed"}, stats_update={"start_job error": 1}
+           )
        ]
        assert remaining == 0
        assert caplog.messages == [
