Skip to content

Commit 19cfd99

Browse files
committed
Support UDF based spatial/temporal extents in load_collection/load_stac
1 parent 9706528 commit 19cfd99

File tree

6 files changed

+221
-14
lines changed

6 files changed

+221
-14
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
### Added
1111

1212
- `MultiBackendJobManager`: add `download_results` option to enable/disable the automated download of job results once completed by the job manager ([#744](https://github.com/Open-EO/openeo-python-client/issues/744))
13+
- Support UDF-based spatial and temporal extents in `load_collection`, `load_stac` and `filter_temporal` ([#831](https://github.com/Open-EO/openeo-python-client/pull/831))
1314

1415
### Changed
1516

openeo/rest/connection.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,12 @@
3535
import openeo
3636
from openeo.config import config_log, get_config_option
3737
from openeo.internal.documentation import openeo_process
38-
from openeo.internal.graph_building import FlatGraphableMixin, PGNode, as_flat_graph
38+
from openeo.internal.graph_building import (
39+
FlatGraphableMixin,
40+
PGNode,
41+
_FromNodeMixin,
42+
as_flat_graph,
43+
)
3944
from openeo.internal.jupyter import VisualDict, VisualList
4045
from openeo.internal.processes.builder import ProcessBuilderBase
4146
from openeo.internal.warnings import deprecated, legacy_alias
@@ -1186,8 +1191,8 @@ def load_collection(
11861191
self,
11871192
collection_id: Union[str, Parameter],
11881193
spatial_extent: Union[dict, Parameter, shapely.geometry.base.BaseGeometry, str, Path, None] = None,
1189-
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
1190-
bands: Union[Iterable[str], Parameter, str, None] = None,
1194+
temporal_extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None,
1195+
bands: Union[Iterable[str], Parameter, str, _FromNodeMixin, None] = None,
11911196
properties: Union[
11921197
Dict[str, Union[PGNode, Callable]], List[CollectionProperty], CollectionProperty, None
11931198
] = None,
@@ -1287,8 +1292,10 @@ def load_result(
12871292
def load_stac(
12881293
self,
12891294
url: str,
1290-
spatial_extent: Union[dict, Parameter, shapely.geometry.base.BaseGeometry, str, Path, None] = None,
1291-
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
1295+
spatial_extent: Union[
1296+
dict, Parameter, shapely.geometry.base.BaseGeometry, str, Path, _FromNodeMixin, None
1297+
] = None,
1298+
temporal_extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None,
12921299
bands: Union[Iterable[str], Parameter, str, None] = None,
12931300
properties: Union[
12941301
Dict[str, Union[PGNode, Callable]], List[CollectionProperty], CollectionProperty, None

openeo/rest/datacube.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191

9292

9393
# Type annotation aliases
94-
InputDate = Union[str, datetime.date, Parameter, PGNode, ProcessBuilderBase, None]
94+
InputDate = Union[str, datetime.date, Parameter, PGNode, ProcessBuilderBase, _FromNodeMixin, None]
9595

9696

9797
class DataCube(_ProcessGraphAbstraction):
@@ -165,8 +165,10 @@ def load_collection(
165165
cls,
166166
collection_id: Union[str, Parameter],
167167
connection: Optional[Connection] = None,
168-
spatial_extent: Union[dict, Parameter, shapely.geometry.base.BaseGeometry, str, pathlib.Path, None] = None,
169-
temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None,
168+
spatial_extent: Union[
169+
dict, Parameter, shapely.geometry.base.BaseGeometry, str, pathlib.Path, _FromNodeMixin, None
170+
] = None,
171+
temporal_extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None,
170172
bands: Union[Iterable[str], Parameter, str, None] = None,
171173
fetch_metadata: bool = True,
172174
properties: Union[
@@ -505,22 +507,22 @@ def _get_temporal_extent(
505507
*args,
506508
start_date: InputDate = None,
507509
end_date: InputDate = None,
508-
extent: Union[Sequence[InputDate], Parameter, str, None] = None,
509-
) -> Union[List[Union[str, Parameter, PGNode, None]], Parameter]:
510+
extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None,
511+
) -> Union[List[Union[str, Parameter, PGNode, _FromNodeMixin, None]], Parameter, _FromNodeMixin]:
510512
"""Parameter aware temporal_extent normalizer"""
511513
# TODO: move this outside of DataCube class
512514
# TODO: return extent as tuple instead of list
513-
if len(args) == 1 and isinstance(args[0], Parameter):
515+
if len(args) == 1 and isinstance(args[0], (Parameter, _FromNodeMixin)):
514516
assert start_date is None and end_date is None and extent is None
515517
return args[0]
516-
elif len(args) == 0 and isinstance(extent, Parameter):
518+
elif len(args) == 0 and isinstance(extent, (Parameter, _FromNodeMixin)):
517519
assert start_date is None and end_date is None
518520
# TODO: warn about unexpected parameter schema
519521
return extent
520522
else:
521523
def convertor(d: Any) -> Any:
522524
# TODO: can this be generalized through _FromNodeMixin?
523-
if isinstance(d, Parameter) or isinstance(d, PGNode):
525+
if isinstance(d, Parameter) or isinstance(d, _FromNodeMixin):
524526
# TODO: warn about unexpected parameter schema
525527
return d
526528
elif isinstance(d, ProcessBuilderBase):
@@ -556,7 +558,7 @@ def filter_temporal(
556558
*args,
557559
start_date: InputDate = None,
558560
end_date: InputDate = None,
559-
extent: Union[Sequence[InputDate], Parameter, str, None] = None,
561+
extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None,
560562
) -> DataCube:
561563
"""
562564
Limit the DataCube to a certain date range, which can be specified in several ways:

tests/rest/datacube/test_datacube.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
import shapely
1919
import shapely.geometry
2020

21+
import openeo.processes
2122
from openeo import collection_property
2223
from openeo.api.process import Parameter
24+
from openeo.internal.graph_building import PGNode
2325
from openeo.metadata import SpatialDimension
2426
from openeo.rest import BandMathException, OpenEoClientException
2527
from openeo.rest._testing import build_capabilities
@@ -698,6 +700,69 @@ def test_filter_temporal_single_arg(s2cube: DataCube, arg, expect_failure):
698700
_ = s2cube.filter_temporal(arg)
699701

700702

703+
@pytest.mark.parametrize(
704+
"udf_factory",
705+
[
706+
(lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)),
707+
(lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)),
708+
],
709+
)
710+
def test_filter_temporal_from_udf(s2cube: DataCube, udf_factory):
711+
temporal_extent = udf_factory(data=[1, 2, 3], udf="print('hello time')", runtime="Python")
712+
cube = s2cube.filter_temporal(temporal_extent)
713+
assert get_download_graph(cube, drop_save_result=True) == {
714+
"loadcollection1": {
715+
"process_id": "load_collection",
716+
"arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None},
717+
},
718+
"runudf1": {
719+
"process_id": "run_udf",
720+
"arguments": {"data": [1, 2, 3], "udf": "print('hello time')", "runtime": "Python"},
721+
},
722+
"filtertemporal1": {
723+
"process_id": "filter_temporal",
724+
"arguments": {
725+
"data": {"from_node": "loadcollection1"},
726+
"extent": {"from_node": "runudf1"},
727+
},
728+
},
729+
}
730+
731+
732+
@pytest.mark.parametrize(
733+
"udf_factory",
734+
[
735+
(lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)),
736+
(lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)),
737+
],
738+
)
739+
def test_filter_temporal_start_end_from_udf(s2cube: DataCube, udf_factory):
740+
start = udf_factory(data=[1, 2, 3], udf="print('hello start')", runtime="Python")
741+
end = udf_factory(data=[4, 5, 6], udf="print('hello end')", runtime="Python")
742+
cube = s2cube.filter_temporal(start_date=start, end_date=end)
743+
assert get_download_graph(cube, drop_save_result=True) == {
744+
"loadcollection1": {
745+
"process_id": "load_collection",
746+
"arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None},
747+
},
748+
"runudf1": {
749+
"process_id": "run_udf",
750+
"arguments": {"data": [1, 2, 3], "udf": "print('hello start')", "runtime": "Python"},
751+
},
752+
"runudf2": {
753+
"process_id": "run_udf",
754+
"arguments": {"data": [4, 5, 6], "udf": "print('hello end')", "runtime": "Python"},
755+
},
756+
"filtertemporal1": {
757+
"process_id": "filter_temporal",
758+
"arguments": {
759+
"data": {"from_node": "loadcollection1"},
760+
"extent": [{"from_node": "runudf1"}, {"from_node": "runudf2"}],
761+
},
762+
},
763+
}
764+
765+
701766
def test_max_time(s2cube, api_version):
702767
im = s2cube.max_time()
703768
graph = _get_leaf_node(im, force_flat=True)

tests/rest/datacube/test_datacube100.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2375,6 +2375,70 @@ def test_load_collection_parameterized_extents(con100, spatial_extent, temporal_
23752375
}
23762376

23772377

2378+
@pytest.mark.parametrize(
2379+
"udf_factory",
2380+
[
2381+
(lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)),
2382+
(lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)),
2383+
],
2384+
)
2385+
def test_load_collection_extents_from_udf(con100, udf_factory):
2386+
spatial_extent = udf_factory(data=[1, 2, 3], udf="print('hello space')", runtime="Python")
2387+
temporal_extent = udf_factory(data=[4, 5, 6], udf="print('hello time')", runtime="Python")
2388+
cube = con100.load_collection("S2", spatial_extent=spatial_extent, temporal_extent=temporal_extent)
2389+
assert get_download_graph(cube, drop_save_result=True) == {
2390+
"runudf1": {
2391+
"process_id": "run_udf",
2392+
"arguments": {"data": [1, 2, 3], "udf": "print('hello space')", "runtime": "Python"},
2393+
},
2394+
"runudf2": {
2395+
"process_id": "run_udf",
2396+
"arguments": {"data": [4, 5, 6], "udf": "print('hello time')", "runtime": "Python"},
2397+
},
2398+
"loadcollection1": {
2399+
"process_id": "load_collection",
2400+
"arguments": {
2401+
"id": "S2",
2402+
"spatial_extent": {"from_node": "runudf1"},
2403+
"temporal_extent": {"from_node": "runudf2"},
2404+
},
2405+
},
2406+
}
2407+
2408+
2409+
@pytest.mark.parametrize(
2410+
"udf_factory",
2411+
[
2412+
(lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)),
2413+
(lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)),
2414+
],
2415+
)
2416+
def test_load_collection_temporal_extent_from_udf(con100, udf_factory):
2417+
temporal_extent = [
2418+
udf_factory(data=[1, 2, 3], udf="print('hello start')", runtime="Python"),
2419+
udf_factory(data=[4, 5, 6], udf="print('hello end')", runtime="Python"),
2420+
]
2421+
cube = con100.load_collection("S2", temporal_extent=temporal_extent)
2422+
assert get_download_graph(cube, drop_save_result=True) == {
2423+
"runudf1": {
2424+
"process_id": "run_udf",
2425+
"arguments": {"data": [1, 2, 3], "udf": "print('hello start')", "runtime": "Python"},
2426+
},
2427+
"runudf2": {
2428+
"process_id": "run_udf",
2429+
"arguments": {"data": [4, 5, 6], "udf": "print('hello end')", "runtime": "Python"},
2430+
},
2431+
"loadcollection1": {
2432+
"process_id": "load_collection",
2433+
"arguments": {
2434+
"id": "S2",
2435+
"spatial_extent": None,
2436+
"temporal_extent": [{"from_node": "runudf1"}, {"from_node": "runudf2"}],
2437+
},
2438+
},
2439+
}
2440+
2441+
23782442
def test_apply_dimension_temporal_cumsum_with_target(con100, test_data):
23792443
cumsum = con100.load_collection("S2").apply_dimension('cumsum', dimension="t", target_dimension="MyNewTime")
23802444
actual_graph = cumsum.flat_graph()

tests/rest/test_connection.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import shapely.geometry
1818

1919
import openeo
20+
import openeo.processes
2021
from openeo import BatchJob
2122
from openeo.api.process import Parameter
2223
from openeo.internal.graph_building import FlatGraphableMixin, PGNode
@@ -3715,6 +3716,73 @@ def test_load_stac_spatial_extent_vector_cube(self, dummy_backend):
37153716
},
37163717
}
37173718

3719+
@pytest.mark.parametrize(
3720+
"udf_factory",
3721+
[
3722+
(lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)),
3723+
(lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)),
3724+
],
3725+
)
3726+
def test_load_stac_extents_from_udf(self, dummy_backend, udf_factory):
3727+
spatial_extent = udf_factory(data=[1, 2, 3], udf="print('hello space')", runtime="Python")
3728+
temporal_extent = udf_factory(data=[4, 5, 6], udf="print('hello time')", runtime="Python")
3729+
cube = dummy_backend.connection.load_stac(
3730+
"https://stac.test/data", spatial_extent=spatial_extent, temporal_extent=temporal_extent
3731+
)
3732+
cube.execute()
3733+
assert dummy_backend.get_sync_pg() == {
3734+
"runudf1": {
3735+
"process_id": "run_udf",
3736+
"arguments": {"data": [1, 2, 3], "udf": "print('hello space')", "runtime": "Python"},
3737+
},
3738+
"runudf2": {
3739+
"process_id": "run_udf",
3740+
"arguments": {"data": [4, 5, 6], "udf": "print('hello time')", "runtime": "Python"},
3741+
},
3742+
"loadstac1": {
3743+
"process_id": "load_stac",
3744+
"arguments": {
3745+
"url": "https://stac.test/data",
3746+
"spatial_extent": {"from_node": "runudf1"},
3747+
"temporal_extent": {"from_node": "runudf2"},
3748+
},
3749+
"result": True,
3750+
},
3751+
}
3752+
3753+
@pytest.mark.parametrize(
3754+
"udf_factory",
3755+
[
3756+
(lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)),
3757+
(lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)),
3758+
],
3759+
)
3760+
def test_load_stac_temporal_extent_from_udf(self, dummy_backend, udf_factory):
3761+
temporal_extent = [
3762+
udf_factory(data=[1, 2, 3], udf="print('hello start')", runtime="Python"),
3763+
udf_factory(data=[4, 5, 6], udf="print('hello end')", runtime="Python"),
3764+
]
3765+
cube = dummy_backend.connection.load_stac("https://stac.test/data", temporal_extent=temporal_extent)
3766+
cube.execute()
3767+
assert dummy_backend.get_sync_pg() == {
3768+
"runudf1": {
3769+
"process_id": "run_udf",
3770+
"arguments": {"data": [1, 2, 3], "udf": "print('hello start')", "runtime": "Python"},
3771+
},
3772+
"runudf2": {
3773+
"process_id": "run_udf",
3774+
"arguments": {"data": [4, 5, 6], "udf": "print('hello end')", "runtime": "Python"},
3775+
},
3776+
"loadstac1": {
3777+
"process_id": "load_stac",
3778+
"arguments": {
3779+
"url": "https://stac.test/data",
3780+
"temporal_extent": [{"from_node": "runudf1"}, {"from_node": "runudf2"}],
3781+
},
3782+
"result": True,
3783+
},
3784+
}
3785+
37183786

37193787
@pytest.mark.parametrize(
37203788
"data",

0 commit comments

Comments
 (0)