Skip to content

Commit e68f70f

Browse files
committed
TestDataLoader: add verify_dict_keys feature (on by default)
1 parent f262916 commit e68f70f

File tree

2 files changed

+73
-10
lines changed

2 files changed

+73
-10
lines changed

openeo/testing/io.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
from pathlib import Path
44
from typing import Callable, Optional, Union
55

6+
from openeo.util import repr_truncate
7+
8+
9+
class PreprocessError(ValueError):
    """Raised when a ``preprocess`` dict key (needle) is not found in the text being preprocessed."""
611

712
class TestDataLoader:
813
"""
@@ -36,20 +41,28 @@ def get_path(self, filename: Union[str, Path]) -> Path:
3641
def load_bytes(self, filename: Union[str, Path]) -> bytes:
    """Return the raw bytes of the file that ``self.get_path(filename)`` resolves to."""
    path = self.get_path(filename)
    return path.read_bytes()
3843

39-
def _get_preprocess(self, preprocess: Union[None, dict, Callable[[str], str]]) -> Callable[[str], str]:
44+
def _get_preprocess(
45+
self,
46+
preprocess: Union[None, dict, Callable[[str], str]],
47+
verify_dict_keys: bool = True,
48+
) -> Callable[[str], str]:
4049
"""Normalize preprocess argument to a callable"""
4150
if preprocess is None:
4251
return lambda x: x
4352
elif isinstance(preprocess, dict):
4453

4554
def replace(text: str) -> str:
    """Apply every ``needle -> replacement`` pair of the ``preprocess`` dict to ``text``.

    ``preprocess`` and ``verify_dict_keys`` are read from the enclosing scope.
    Pairs are applied one after another in dict iteration order, so each
    needle is looked up in the text as already modified by earlier pairs.

    :param text: the text to run the replacements on
    :return: the text after all replacements
    :raises PreprocessError: when ``verify_dict_keys`` is enabled and a needle
        does not occur in the current text
    :raises ValueError: when a needle is neither a ``str`` nor a ``re.Pattern``
    """
    for needle, replacement in preprocess.items():
        if isinstance(needle, re.Pattern):
            if verify_dict_keys and not needle.search(text):
                # Same truncation width as the plain-string branch below
                # (was 265 here, which looks like a transposed 256).
                raise PreprocessError(f"{needle!r} not found in {repr_truncate(text, width=256)}")
            # ``re.sub`` style replacement: supports group references like "\1".
            text = needle.sub(repl=replacement, string=text)
        elif isinstance(needle, str):
            if verify_dict_keys and needle not in text:
                raise PreprocessError(f"{needle!r} not found in {repr_truncate(text, width=256)}")
            text = text.replace(needle, replacement)
        else:
            # Unsupported key type in the preprocess dict.
            raise ValueError(needle)
    return text
5467

5568
return replace
@@ -61,6 +74,7 @@ def load_text(
6174
filename: Union[str, Path],
6275
*,
6376
preprocess: Union[None, dict, Callable[[str], str]] = None,
77+
verify_dict_keys: bool = True,
6478
encoding: str = "utf8",
6579
) -> str:
6680
"""
@@ -74,17 +88,20 @@ def load_text(
7488
Needle can be a simple string that will be replaced with the replacement value,
7589
or it can be a ``re.Pattern`` that will be used in ``re.sub()`` style
7690
(which supports group references, e.g. "\1" for first group in match)
91+
:param verify_dict_keys: when ``preprocess`` is specified as dict:
92+
whether to verify that the keys actually exist in the text before replacing them.
7793
:param encoding: Encoding to use when reading the file
7894
"""
7995
text = self.get_path(filename).read_text(encoding=encoding)
80-
text = self._get_preprocess(preprocess)(text)
96+
text = self._get_preprocess(preprocess, verify_dict_keys=verify_dict_keys)(text)
8197
return text
8298

8399
def load_json(
84100
self,
85101
filename: Union[str, Path],
86102
*,
87103
preprocess: Union[None, dict, Callable[[str], str]] = None,
104+
verify_dict_keys: bool = True,
88105
) -> dict:
89106
"""
90107
Load data from a JSON file, optionally with some text based preprocessing
@@ -97,7 +114,9 @@ def load_json(
97114
Needle can be a simple string that will be replaced with the replacement value,
98115
or it can be a ``re.Pattern`` that will be used in ``re.sub()`` style
99116
(which supports group references, e.g. "\1" for first group in match)
117+
:param verify_dict_keys: when ``preprocess`` is specified as dict:
118+
whether to verify that the keys actually exist in the text before replacing them.
100119
"""
101120
raw = self.get_path(filename).read_text(encoding="utf8")
102-
raw = self._get_preprocess(preprocess)(raw)
121+
raw = self._get_preprocess(preprocess, verify_dict_keys=verify_dict_keys)(raw)
103122
return json.loads(raw)

tests/testing/test_io.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66

7-
from openeo.testing.io import TestDataLoader
7+
from openeo.testing.io import PreprocessError, TestDataLoader
88

99

1010
class TestTestDataLoader:
@@ -40,6 +40,23 @@ def test_load_text(self, tmp_path, preprocess, expected):
4040
loader = TestDataLoader(root=tmp_path)
4141
assert loader.load_text("hello.txt", preprocess=preprocess) == expected
4242

43+
@pytest.mark.parametrize(
    ["preprocess", "expected"],
    [
        ({"apple": "banana"}, "'apple' not found in 'Hello, World!'"),
        ({re.compile("ap+le"): "banana"}, "re.compile('ap+le') not found in 'Hello, World!'"),
    ],
)
def test_load_text_missing_key(self, tmp_path, preprocess, expected):
    """A needle that is absent from the text raises, unless ``verify_dict_keys=False``."""
    text_file = tmp_path / "hello.txt"
    text_file.write_text("Hello, World!", encoding="utf8")
    loader = TestDataLoader(root=tmp_path)

    # Default behavior: key verification is on, so the missing needle is reported.
    expected_pattern = re.escape(expected)
    with pytest.raises(PreprocessError, match=expected_pattern):
        loader.load_text("hello.txt", preprocess=preprocess)

    # With verification off the text is returned unchanged.
    unverified = loader.load_text("hello.txt", preprocess=preprocess, verify_dict_keys=False)
    assert unverified == "Hello, World!"
59+
4360
@pytest.mark.parametrize(
4461
["preprocess", "expected"],
4562
[
@@ -73,3 +90,30 @@ def test_load_json(self, tmp_path, preprocess, expected):
7390

7491
loader = TestDataLoader(root=tmp_path)
7592
assert loader.load_json("data.json", preprocess=preprocess) == expected
93+
94+
@pytest.mark.parametrize(
    ["preprocess", "expected"],
    [
        (
            {"apple": "banana"},
            '\'apple\' not found in \'{"salutation": "Hello", "target": "World"}\'',
        ),
        (
            {re.compile("ap+le"): "banana"},
            're.compile(\'ap+le\') not found in \'{"salutation": "Hello", "target": "World"}\'',
        ),
    ],
)
def test_load_json_missing_key(self, tmp_path, preprocess, expected):
    """A needle that is absent from the raw JSON text raises, unless ``verify_dict_keys=False``."""
    document = {"salutation": "Hello", "target": "World"}
    json_file = tmp_path / "data.json"
    json_file.write_text(json.dumps(document), encoding="utf8")
    loader = TestDataLoader(root=tmp_path)

    # Default behavior: key verification is on, so the missing needle is reported.
    expected_pattern = re.escape(expected)
    with pytest.raises(PreprocessError, match=expected_pattern):
        loader.load_json("data.json", preprocess=preprocess)

    # With verification off the JSON document is loaded unchanged.
    unverified = loader.load_json("data.json", preprocess=preprocess, verify_dict_keys=False)
    assert unverified == {
        "salutation": "Hello",
        "target": "World",
    }

0 commit comments

Comments
 (0)