Skip to content

Commit d42e6f9

Browse files
committed
Issue #391 avoid multiple "result" modes with MultiResult flattening
Make GraphFlattener multi-input-sensitive, and be more careful about accidental mutation in this context
1 parent eb59bd0 commit d42e6f9

File tree

2 files changed

+106
-12
lines changed

2 files changed

+106
-12
lines changed

openeo/internal/graph_building.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import abc
1212
import collections
13+
import copy
1314
import json
1415
import sys
1516
from contextlib import nullcontext
@@ -322,20 +323,29 @@ def generate(self, process_id: str):
322323

323324
class GraphFlattener(ProcessGraphVisitor):
324325

325-
def __init__(self, node_id_generator: FlatGraphNodeIdGenerator = None):
326+
def __init__(self, node_id_generator: FlatGraphNodeIdGenerator = None, multi_input_mode: bool = False):
326327
super().__init__()
327328
self._node_id_generator = node_id_generator or FlatGraphNodeIdGenerator()
328329
self._last_node_id = None
329330
self._flattened: Dict[str, dict] = {}
330331
self._argument_stack = []
331332
self._node_cache = {}
333+
self._multi_input_mode = multi_input_mode
332334

333335
def flatten(self, node: PGNode) -> Dict[str, dict]:
334336
"""Consume given nested process graph and return flat dict representation"""
337+
if self._flattened and not self._multi_input_mode:
338+
raise RuntimeError("Flattening multiple graphs, but not in multi-input mode")
335339
self.accept_node(node)
336340
assert len(self._argument_stack) == 0
337-
self._flattened[self._last_node_id]["result"] = True
338-
return self._flattened
341+
return self.flattened(set_result_flag=not self._multi_input_mode)
342+
343+
def flattened(self, set_result_flag: bool = True) -> Dict[str, dict]:
344+
flat_graph = copy.deepcopy(self._flattened)
345+
if set_result_flag:
346+
# TODO #583 an "end" node is not necessarily a "result" node
347+
flat_graph[self._last_node_id]["result"] = True
348+
return flat_graph
339349

340350
def accept_node(self, node: PGNode):
341351
# Process reused nodes only first time and remember node id.
@@ -450,14 +460,13 @@ def __init__(self, leaves: List[FlatGraphableMixin]):
450460
self._leaves = leaves
451461

452462
def flat_graph(self) -> Dict[str, dict]:
453-
result = {}
454-
flattener = GraphFlattener()
463+
flattener = GraphFlattener(multi_input_mode=True)
455464
for leaf in self._leaves:
456465
if isinstance(leaf, PGNode):
457-
result = flattener.flatten(leaf)
466+
flattener.flatten(leaf)
458467
elif isinstance(leaf, _FromNodeMixin):
459-
result = flattener.flatten(leaf.from_node())
468+
flattener.flatten(leaf.from_node())
460469
else:
461-
raise ValueError(leaf)
470+
raise ValueError(f"Unsupported type {type(leaf)}")
462471

463-
return result
472+
return flattener.flattened(set_result_flag=True)

tests/internal/test_graphbuilding.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from openeo.api.process import Parameter
99
from openeo.internal.graph_building import (
1010
FlatGraphNodeIdGenerator,
11+
GraphFlattener,
1112
MultiResult,
1213
PGNode,
1314
PGNodeGraphUnflattener,
@@ -146,6 +147,91 @@ def test_flat_graph_key_generate():
146147
assert g.generate("foo") == "foo3"
147148

148149

150+
class TestGraphFlattener:
151+
def test_simple(self):
152+
node = PGNode("foo", bar="meh")
153+
flattener = GraphFlattener()
154+
assert flattener.flatten(node) == {"foo1": {"process_id": "foo", "arguments": {"bar": "meh"}, "result": True}}
155+
156+
def test_chain(self):
157+
a = PGNode("a", bar="meh")
158+
b = PGNode("b", a=a)
159+
c = PGNode("c", a=a, b=b)
160+
flattener = GraphFlattener()
161+
assert flattener.flatten(c) == {
162+
"a1": {"process_id": "a", "arguments": {"bar": "meh"}},
163+
"b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}},
164+
"c1": {
165+
"process_id": "c",
166+
"arguments": {"a": {"from_node": "a1"}, "b": {"from_node": "b1"}},
167+
"result": True,
168+
},
169+
}
170+
171+
def test_no_multi_input_mode(self):
172+
a = PGNode("a")
173+
b = PGNode("b", a=a)
174+
flattener = GraphFlattener()
175+
flat_graph = flattener.flatten(a)
176+
assert flat_graph == {"a1": {"process_id": "a", "arguments": {}, "result": True}}
177+
with pytest.raises(RuntimeError, match="not in multi-input mode"):
178+
flattener.flatten(b)
179+
assert flat_graph == {"a1": {"process_id": "a", "arguments": {}, "result": True}}
180+
181+
def test_multi_input_mode(self):
182+
a = PGNode("a")
183+
b = PGNode("b", a=a)
184+
c = PGNode("c", a=a)
185+
flattener = GraphFlattener(multi_input_mode=True)
186+
# Flatten b
187+
assert flattener.flatten(b) == {
188+
"a1": {"process_id": "a", "arguments": {}},
189+
"b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}},
190+
}
191+
assert flattener.flattened() == {
192+
"a1": {"process_id": "a", "arguments": {}},
193+
"b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}, "result": True},
194+
}
195+
# Flatten c
196+
assert flattener.flatten(c) == {
197+
"a1": {"process_id": "a", "arguments": {}},
198+
"b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}},
199+
"c1": {"process_id": "c", "arguments": {"a": {"from_node": "a1"}}},
200+
}
201+
assert flattener.flattened() == {
202+
"a1": {"process_id": "a", "arguments": {}},
203+
"b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}},
204+
"c1": {"process_id": "c", "arguments": {"a": {"from_node": "a1"}}, "result": True},
205+
}
206+
207+
def test_multi_input_mode_mutation(self):
208+
"""Verify that previously produced flat graphs are not silently mutated"""
209+
a = PGNode("a")
210+
b = PGNode("b", a=a)
211+
flattener = GraphFlattener(multi_input_mode=True)
212+
a_flat = flattener.flatten(a)
213+
assert a_flat == {
214+
"a1": {"process_id": "a", "arguments": {}},
215+
}
216+
b_flat = flattener.flatten(b)
217+
assert b_flat == {
218+
"a1": {"process_id": "a", "arguments": {}},
219+
"b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}},
220+
}
221+
assert flattener.flattened() == {
222+
"a1": {"process_id": "a", "arguments": {}},
223+
"b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}, "result": True},
224+
}
225+
# Original graphs are not mutated silently
226+
assert a_flat == {
227+
"a1": {"process_id": "a", "arguments": {}},
228+
}
229+
assert b_flat == {
230+
"a1": {"process_id": "a", "arguments": {}},
231+
"b1": {"process_id": "b", "arguments": {"a": {"from_node": "a1"}}},
232+
}
233+
234+
149235
def test_build_and_flatten_simple():
150236
node = PGNode("foo")
151237
assert node.flat_graph() == {"foo1": {"process_id": "foo", "arguments": {}, "result": True}}
@@ -421,14 +507,14 @@ class TestMultiResult:
421507
def test_simple(self):
422508
multi = MultiResult([PGNode("foo"), PGNode("bar")])
423509
assert multi.flat_graph() == {
424-
"foo1": {"process_id": "foo", "arguments": {}, "result": True},
510+
"foo1": {"process_id": "foo", "arguments": {}},
425511
"bar1": {"process_id": "bar", "arguments": {}, "result": True},
426512
}
427513

428514
def test_simple_duplicates(self):
429515
multi = MultiResult([PGNode("foo"), PGNode("foo")])
430516
assert multi.flat_graph() == {
431-
"foo1": {"process_id": "foo", "arguments": {}, "result": True},
517+
"foo1": {"process_id": "foo", "arguments": {}},
432518
"foo2": {"process_id": "foo", "arguments": {}, "result": True},
433519
}
434520

@@ -442,7 +528,6 @@ def test_multi_save_result_same_root(self):
442528
"saveresult1": {
443529
"process_id": "save_result",
444530
"arguments": {"data": {"from_node": "loadcollection1"}, "format": "GTiff", "options": {}},
445-
"result": True,
446531
},
447532
"saveresult2": {
448533
"process_id": "save_result",

0 commit comments

Comments
 (0)