
Commit be42cd7

Issue 793: Support partial updates in STACAPIJobDatabase.persist
related to PR #736

1 parent: 4dfdf77

3 files changed (+93, −36 lines)

openeo/extra/job_management/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -538,7 +538,7 @@ def _job_update_loop(
                         self._launch_job(start_job, df=not_started, i=i, backend_name=backend_name, stats=stats)
                         stats["job launch"] += 1
 
-                        job_db.persist(not_started.loc[i : i + 1])
+                        job_db.persist(not_started.loc[[i]])
                         stats["job_db persist"] += 1
                         total_added += 1

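A note on this one-liner: with the STAC job database now using string item ids as index labels (see the stac_job_db.py changes below), slice-based selection like `not_started.loc[i : i + 1]` breaks down, since `i + 1` assumes integer labels (and label-based `.loc` slices are end-inclusive anyway). Selecting with a list of labels returns exactly the one-row dataframe to persist. A minimal sketch, with an illustrative dataframe:

    import pandas as pd

    df = pd.DataFrame({"year": [2024, 2025]}, index=["item-2024", "item-2025"])

    # Label-list selection: a one-row DataFrame, regardless of index type.
    row = df.loc[["item-2024"]]
    assert row.shape == (1, 1)

    # Slice arithmetic like `df.loc[i : i + 1]` would raise a TypeError here,
    # as "item-2024" + 1 is not a valid label.
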
openeo/extra/job_management/stac_job_db.py

Lines changed: 38 additions & 17 deletions
@@ -57,12 +57,16 @@ def exists(self) -> bool:
     def _normalize_df(self, df: pd.DataFrame) -> pd.DataFrame:
         """
         Normalize the given dataframe to be compatible with :py:class:`MultiBackendJobManager`
-        by adding the default columns and setting the index.
+        by adding the default columns and using the STAC item ids as index values.
         """
         df = MultiBackendJobManager._normalize_df(df)
-        # If the user doesn't specify the item_id column, we will use the index.
-        if "item_id" not in df.columns:
-            df = df.reset_index(names=["item_id"])
+
+        if isinstance(df.index, pd.RangeIndex) and "item_id" in df.columns:
+            # Support legacy usage: default (autoincrement) index and an "item_id" column -> copy over as index
+            df.index = df["item_id"]
+
+        # Make sure the index (of item ids) are strings, to play well with (py)STAC schemas
+        df.index = df.index.astype(str)
         return df
 
     def initialize_from_df(self, df: pd.DataFrame, *, on_exists: str = "error"):
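To make the reworked `_normalize_df` concrete, here is a small standalone sketch of the index handling above (plain pandas, outside the actual class):

    import pandas as pd

    # Legacy usage: default (autoincrement) RangeIndex plus an "item_id" column.
    df = pd.DataFrame({"item_id": ["item-2024", "item-2025"], "year": [2024, 2025]})
    if isinstance(df.index, pd.RangeIndex) and "item_id" in df.columns:
        # Copy the item ids over as index.
        df.index = df["item_id"]

    # Cast index values to str so they play well with (py)STAC id schemas.
    df.index = df.index.astype(str)
    assert list(df.index) == ["item-2024", "item-2025"]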
@@ -128,7 +132,7 @@ def item_from(self, series: pd.Series) -> pystac.Item:
         :return: pystac.Item
         """
         series_dict = series.to_dict()
-        item_id = series_dict.pop("item_id")
+        item_id = str(series.name)
         item_dict = {}
         item_dict.setdefault("stac_version", pystac.get_stac_version())
         item_dict.setdefault("type", "Feature")
@@ -168,6 +172,13 @@ def count_by_status(self, statuses: Iterable[str] = ()) -> dict:
         else:
             return items["status"].value_counts().to_dict()
 
+    def _search_result_to_df(self, search_result: pystac_client.ItemSearch) -> pd.DataFrame:
+        series = [self.series_from(item) for item in search_result.items()]
+        # Note: `series_from` sets the item id as the series "name",
+        # which ends up in the index of the dataframe
+        df = pd.DataFrame(series)
+        return df
+
     def get_by_status(self, statuses: Iterable[str], max: Optional[int] = None) -> pd.DataFrame:
         if isinstance(statuses, str):
             statuses = {statuses}
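The new `_search_result_to_df` helper leans on a pandas detail that is easy to miss: when a dataframe is built from a list of series, each series' `name` becomes an index label of the result. A quick sketch with made-up statuses:

    import pandas as pd

    s1 = pd.Series({"status": "finished"}, name="item-2024")
    s2 = pd.Series({"status": "running"}, name="item-2025")

    df = pd.DataFrame([s1, s2])
    # The series names ("item-2024", "item-2025") end up as the index.
    assert list(df.index) == ["item-2024", "item-2025"]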
@@ -180,35 +191,45 @@ def get_by_status(self, statuses: Iterable[str], max: Optional[int] = None) -> pd.DataFrame:
             filter=status_filter,
             max_items=max,
         )
+        df = self._search_result_to_df(search_results)
 
-        series = [self.series_from(item) for item in search_results.items()]
-
-        df = pd.DataFrame(series).reset_index(names=["item_id"])
-        if len(series) == 0:
+        if df.shape[0] == 0:
             # TODO: What if default columns are overwritten by the user?
             df = self._normalize_df(
                 df
             )  # Even for an empty dataframe the default columns are required
         return df
 
     def persist(self, df: pd.DataFrame):
+        if df.empty:
+            _log.warning("No data to persist in STAC API job database, skipping.")
+            return
+
         if not self.exists():
             spatial_extent = pystac.SpatialExtent([[-180, -90, 180, 90]])
             temporal_extent = pystac.TemporalExtent([[None, None]])
             extent = pystac.Extent(spatial=spatial_extent, temporal=temporal_extent)
             c = pystac.Collection(id=self.collection_id, description="STAC API job database collection.", extent=extent)
             self._create_collection(c)
 
-        all_items = []
-        if not df.empty:
-
-            def handle_row(series):
-                item = self.item_from(series)
-                all_items.append(item)
+        # Merge updates with existing items (if any)
+        existing_items = self.client.search(
+            method="GET",
+            collections=[self.collection_id],
+            ids=[str(i) for i in df.index.tolist()],
+        )
+        existing_df = self._search_result_to_df(existing_items)
 
-            df.apply(handle_row, axis=1)
+        if existing_df.empty:
+            df_to_persist = df
+        else:
+            # Merge data on item_id (in the index)
+            df_to_persist = existing_df
+            df_to_persist.update(df, overwrite=True)
 
-        self._upload_items_bulk(self.collection_id, all_items)
+        items_to_persist = [self.item_from(s) for _, s in df_to_persist.iterrows()]
+        _log.info(f"Bulk upload of {len(items_to_persist)} items to STAC API collection {self.collection_id!r}")
+        self._upload_items_bulk(self.collection_id, items_to_persist)
 
     def _prepare_item(self, item: pystac.Item, collection_id: str):
         item.collection_id = collection_id
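The partial-update semantics of the new `persist` hinge on `pandas.DataFrame.update`: it aligns on the shared (item id) index and overwrites matching cells with non-NA values from the update frame, leaving all other columns of the existing rows untouched. A minimal sketch of that merge step, with made-up rows:

    import pandas as pd

    existing = pd.DataFrame(
        {"status": ["created"], "year": [2024]},
        index=["item-2024"],
    )
    update = pd.DataFrame({"status": ["finished"]}, index=["item-2024"])

    # In-place merge: "status" is overwritten, "year" is preserved.
    existing.update(update, overwrite=True)
    assert existing.loc["item-2024", "status"] == "finished"
    assert existing.loc["item-2024", "year"] == 2024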

tests/extra/job_management/test_stac_job_db.py

Lines changed: 54 additions & 18 deletions
@@ -479,21 +479,25 @@ def _post_collections_bulk_items(self, request, context):
     def _get_search(self, request, context):
         """Handler of `GET /search` requests."""
         collections = request.qs["collections"][0].split(",")
-        filter = request.qs["filter"][0] if "filter" in request.qs else None
-
-        if filter:
-            # TODO: use a more robust CQL2-text parser?
-            assert re.match(r"^\s*\"properties\.status\"='\w+'(\s+or\s+\"properties\.status\"='\w+')*\s*$", filter)
-            statuses = re.findall(r"\"properties\.status\"='(\w+)'", filter)
-        else:
-            statuses = None
-
         items = [
             item
             for cid in collections
             for item in self.items.get(cid, {}).values()
-            if statuses is None or item.get("properties", {}).get("status") in statuses
         ]
+        if "ids" in request.qs:
+            [ids] = request.qs["ids"]
+            ids = set(ids.split(","))
+            items = [i for i in items if i.get("id") in ids]
+        if "filter" in request.qs:
+            [property_filter] = request.qs["filter"]
+            # TODO: use a more robust CQL2-text parser?
+            assert request.qs["filter-lang"] == ["cql2-text"]
+            assert re.match(
+                r"^\s*\"properties\.status\"='\w+'(\s+or\s+\"properties\.status\"='\w+')*\s*$", property_filter
+            )
+            statuses = set(re.findall(r"\"properties\.status\"='(\w+)'", property_filter))
+            items = [i for i in items if i.get("properties", {}).get("status") in statuses]
+
         return {
             "type": "FeatureCollection",
             "features": items,
@@ -502,27 +506,59 @@ def _get_search(self, request, context):
 
 
 def test_run_jobs_basic(tmp_path, dummy_backend_foo, requests_mock, sleep_mock):
-    job_manager = MultiBackendJobManager(root_dir=tmp_path, poll_sleep=2)
-    job_manager.add_backend("foo", connection=dummy_backend_foo.connection)
-
     stac_api_url = "http://stacapi.test"
     dummy_stac_api = DummyStacApi(root_url=stac_api_url, requests_mock=requests_mock)
 
+    # Initialize job db
     job_db = STACAPIJobDatabase(collection_id="collection-123", stac_root_url=stac_api_url)
     df = pd.DataFrame(
-        {
-            "item_id": ["item-2024", "item-2025"],
-            "year": [2024, 2025],
-        }
+        {"year": [2024, 2025]},
+        index=["item-2024", "item-2025"],
     )
     job_db.initialize_from_df(df=df)
+    assert dummy_stac_api.items == {
+        "collection-123": {
+            "item-2024": dirty_equals.IsPartialDict(
+                {
+                    "type": "Feature",
+                    "id": "item-2024",
+                    "properties": dirty_equals.IsPartialDict(
+                        {
+                            "year": 2024,
+                            "id": None,
+                            "status": "not_started",
+                            "backend_name": None,
+                        }
+                    ),
+                }
+            ),
+            "item-2025": dirty_equals.IsPartialDict(
+                {
+                    "type": "Feature",
+                    "id": "item-2025",
+                    "properties": dirty_equals.IsPartialDict(
+                        {
+                            "year": 2025,
+                            "id": None,
+                            "status": "not_started",
+                            "backend_name": None,
+                        }
+                    ),
+                }
+            ),
+        }
+    }
 
+    # Set up job manager
+    job_manager = MultiBackendJobManager(root_dir=tmp_path, poll_sleep=2)
+    job_manager.add_backend("foo", connection=dummy_backend_foo.connection)
+
+    # Run job manager loop
     def create_job(row, connection, **kwargs):
         year = int(row["year"])
         pg = {"dummy1": {"process_id": "dummy", "arguments": {"year": year}, "result": True}}
         job = connection.create_job(pg)
         return job
-
     run_stats = job_manager.run_jobs(job_db=job_db, start_job=create_job)
 
     assert run_stats == dirty_equals.IsPartialDict(
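Note on the assertions: `dirty_equals.IsPartialDict` matches any dict that contains at least the given key/value pairs, which keeps these checks robust against extra STAC fields added by the implementation. A tiny sketch:

    import dirty_equals

    assert {"id": "item-2024", "stac_version": "1.0.0"} == dirty_equals.IsPartialDict({"id": "item-2024"})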
