diff --git a/tinyquery/evaluator_test.py b/tinyquery/evaluator_test.py index 0801b8f..bf21293 100644 --- a/tinyquery/evaluator_test.py +++ b/tinyquery/evaluator_test.py @@ -1339,6 +1339,11 @@ def test_other_timestamp_functions(self): self.make_context([ ('f0_', tq_types.INT, [1262304000000000])])) + def test_replace(self): + self.assert_query_result( + "SELECT REPLACE(str, 'o', 'e') FROM string_table_with_null", + self.make_context([('f0_', tq_types.STRING, ["helle", "werld", None])])) + def test_first(self): # Test over the equivalent of a GROUP BY self.assert_query_result( diff --git a/tinyquery/runtime.py b/tinyquery/runtime.py index 5ad8922..6d7fa8c 100644 --- a/tinyquery/runtime.py +++ b/tinyquery/runtime.py @@ -1011,6 +1011,21 @@ def apply(*args): values=values) +class ReplaceFunction(ScalarFunction): + def check_types(self, *arg_types): + if any(arg_type != tq_types.STRING for arg_type in arg_types): + raise TypeError('REPLACE only takes string arguments.') + return tq_types.STRING + + def _evaluate(self, num_rows, values, old, new): + values = [value.replace(old, new) if value is not None else None + for value, old, new in zip(values.values, + old.values, + new.values)] + return context.Column(tq_types.STRING, tq_modes.NULLABLE, + values=values) + + class JSONExtractFunction(ScalarFunction): """Extract from a JSON string based on a JSONPath expression. @@ -1320,6 +1335,7 @@ def _evaluate(self, num_rows, json_expressions, json_paths): lambda dt: dt.year, return_type=tq_types.INT), TimestampFunction()), + 'replace': ReplaceFunction(), 'json_extract': JSONExtractFunction(), 'json_extract_scalar': JSONExtractFunction(scalar=True), }