Skip to content

Commit ae5dc41

Browse files
Add schema namespaces for UDFs (#1268)
## Description

Allows dataset and metric UDFs to be registered in specific schema namespaces.

```
@register_dataset_udf(["col1"], schema_name="bob")
def bob(x):
    return x["col1"]

@register_metric_udf("col1", schema_name="bob")
def rob(x):
    return x

@register_dataset_udf(["col1"], "add5")
def fob(x):
    return x["col1"] + 5

default_schema = udf_schema()
data = pd.DataFrame({"col1": [42, 12, 7]})
default_view = why.log(data, schema=default_schema).view()  # no bob or rob

bob_schema = udf_schema(schema_name="bob")
data = pd.DataFrame({"col1": [42, 12, 7]})  # original data frame stomped on by previous log() UDFs
bob_view = why.log(data, schema=bob_schema).view()  # bob & rob, but no add5
```

## Related

Relates to #1263

<!-- Reference related commits, issues and pull requests. Type `#` and select from the list. -->

Closes [clickup task](https://app.clickup.com/t/866ab6rqn)

- [ ] I have reviewed the [Guidelines for Contributing](CONTRIBUTING.md) and the [Code of Conduct](CODE_OF_CONDUCT.md).
1 parent 77635cc commit ae5dc41

File tree

3 files changed

+74
-34
lines changed

3 files changed

+74
-34
lines changed

python/tests/experimental/core/test_udf_schema.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,34 @@ def test_udf_track() -> None:
284284
assert prod_summary["counts/n"] == 3
285285
div_summary = results.get_column("ratio").to_summary_dict()
286286
assert div_summary["distribution/n"] == 3
287+
288+
289+
@register_dataset_udf(["schema.col1"], schema_name="bob")
290+
def bob(x):
291+
return x["schema.col1"]
292+
293+
294+
@register_metric_udf("schema.col1", schema_name="bob")
295+
def rob(x):
296+
return x
297+
298+
299+
@register_dataset_udf(["schema.col1"], "add5")
300+
def fob(x):
301+
return x["schema.col1"] + 5
302+
303+
304+
def test_schema_name() -> None:
305+
default_schema = udf_schema()
306+
data = pd.DataFrame({"schema.col1": [42, 12, 7]})
307+
default_view = why.log(data, schema=default_schema).view()
308+
assert "add5" in default_view.get_columns()
309+
assert "bob" not in default_view.get_columns()
310+
assert "udf" not in default_view.get_column("schema.col1").get_metric_names()
311+
312+
bob_schema = udf_schema(schema_name="bob")
313+
data = pd.DataFrame({"schema.col1": [42, 12, 7]})
314+
bob_view = why.log(data, schema=bob_schema).view()
315+
assert "add5" not in bob_view.get_columns()
316+
assert "bob" in bob_view.get_columns()
317+
assert "udf" in bob_view.get_column("schema.col1").get_metric_names()

python/whylogs/experimental/core/metrics/udf_metric.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,17 @@ def zero(cls, config: Optional[MetricConfig] = None) -> "UdfMetric":
177177
register_metric(UdfMetric)
178178

179179

180-
_col_name_submetrics: Dict[str, List[Tuple[str, Callable[[Any], Any]]]] = defaultdict(list)
181-
_col_name_submetric_schema: Dict[str, SubmetricSchema] = dict()
182-
_col_name_type_mapper: Dict[str, TypeMapper] = dict()
180+
_col_name_submetrics: Dict[str, Dict[str, List[Tuple[str, Callable[[Any], Any]]]]] = defaultdict(
181+
lambda: defaultdict(list)
182+
)
183+
_col_name_submetric_schema: Dict[str, Dict[str, SubmetricSchema]] = defaultdict(dict)
184+
_col_name_type_mapper: Dict[str, Dict[str, TypeMapper]] = defaultdict(dict)
183185

184-
_col_type_submetrics: Dict[DataType, List[Tuple[str, Callable[[Any], Any]]]] = defaultdict(list)
185-
_col_type_submetric_schema: Dict[DataType, SubmetricSchema] = dict()
186-
_col_type_type_mapper: Dict[DataType, TypeMapper] = dict()
186+
_col_type_submetrics: Dict[str, Dict[DataType, List[Tuple[str, Callable[[Any], Any]]]]] = defaultdict(
187+
lambda: defaultdict(list)
188+
)
189+
_col_type_submetric_schema: Dict[str, Dict[DataType, SubmetricSchema]] = defaultdict(dict)
190+
_col_type_type_mapper: Dict[str, Dict[DataType, TypeMapper]] = defaultdict(dict)
187191

188192

189193
def register_metric_udf(
@@ -193,6 +197,7 @@ def register_metric_udf(
193197
submetric_schema: Optional[SubmetricSchema] = None,
194198
type_mapper: Optional[TypeMapper] = None,
195199
namespace: Optional[str] = None,
200+
schema_name: str = "",
196201
) -> Callable[[Any], Any]:
197202
"""
198203
Decorator to easily configure UdfMetrics for your data set. Decorate your UDF
@@ -229,32 +234,32 @@ def decorator_register(func):
229234
subname = submetric_name or func.__name__
230235
subname = f"{namespace}.{subname}" if namespace else subname
231236
if col_name is not None:
232-
_col_name_submetrics[col_name].append((subname, func))
237+
_col_name_submetrics[schema_name][col_name].append((subname, func))
233238
if submetric_schema is not None:
234-
if col_name in _col_name_submetric_schema:
239+
if col_name in _col_name_submetric_schema[schema_name]:
235240
logger.warn(f"Overwriting submetric schema for column {col_name}")
236-
_col_name_submetric_schema[col_name] = submetric_schema
241+
_col_name_submetric_schema[schema_name][col_name] = submetric_schema
237242
if type_mapper is not None:
238-
if col_name in _col_name_type_mapper:
243+
if col_name in _col_name_type_mapper[schema_name]:
239244
logger.warn(f"Overwriting UdfMetric type mapper for column {col_name}")
240-
_col_name_type_mapper[col_name] = type_mapper
245+
_col_name_type_mapper[schema_name][col_name] = type_mapper
241246
else:
242-
_col_type_submetrics[col_type].append((subname, func))
247+
_col_type_submetrics[schema_name][col_type].append((subname, func))
243248
if submetric_schema is not None:
244-
if col_type in _col_type_submetric_schema:
249+
if col_type in _col_type_submetric_schema[schema_name]:
245250
logger.warn(f"Overwriting submetric schema for column type {col_type}")
246-
_col_type_submetric_schema[col_type] = submetric_schema
251+
_col_type_submetric_schema[schema_name][col_type] = submetric_schema
247252
if type_mapper is not None:
248-
if col_type in _col_type_type_mapper:
253+
if col_type in _col_type_type_mapper[schema_name]:
249254
logger.warn(f"Overwriting UdfMetric type mapper for column type {col_type}")
250-
_col_type_type_mapper[col_type] = type_mapper
255+
_col_type_type_mapper[schema_name][col_type] = type_mapper
251256

252257
return func
253258

254259
return decorator_register
255260

256261

257-
def generate_udf_resolvers() -> List[ResolverSpec]:
262+
def generate_udf_resolvers(schema_name: str = "") -> List[ResolverSpec]:
258263
"""
259264
Generates a list of ResolverSpecs that implement the UdfMetrics specified
260265
by the @register_metric_udf decorators. The result only includes the UdfMetric,
@@ -283,25 +288,25 @@ def upper(x):
283288

284289
resolvers: List[ResolverSpec] = list()
285290
udfs: Dict[str, Callable[[Any], Any]]
286-
for col_name, submetrics in _col_name_submetrics.items():
291+
for col_name, submetrics in _col_name_submetrics[schema_name].items():
287292
udfs = dict()
288293
for submetric in submetrics:
289294
udfs[submetric[0]] = submetric[1]
290295
config = UdfMetricConfig(
291296
udfs=udfs,
292-
submetric_schema=_col_name_submetric_schema.get(col_name) or default_schema(),
293-
type_mapper=_col_name_type_mapper.get(col_name) or StandardTypeMapper(),
297+
submetric_schema=_col_name_submetric_schema[schema_name].get(col_name) or default_schema(),
298+
type_mapper=_col_name_type_mapper[schema_name].get(col_name) or StandardTypeMapper(),
294299
)
295300
resolvers.append(ResolverSpec(col_name, None, [MetricSpec(UdfMetric, config)]))
296301

297-
for col_type, submetrics in _col_type_submetrics.items():
302+
for col_type, submetrics in _col_type_submetrics[schema_name].items():
298303
udfs = dict()
299304
for submetric in submetrics:
300305
udfs[submetric[0]] = submetric[1]
301306
config = UdfMetricConfig(
302307
udfs=udfs,
303-
submetric_schema=_col_type_submetric_schema.get(col_type) or default_schema(),
304-
type_mapper=_col_type_type_mapper.get(col_type) or StandardTypeMapper(),
308+
submetric_schema=_col_type_submetric_schema[schema_name].get(col_type) or default_schema(),
309+
type_mapper=_col_type_type_mapper[schema_name].get(col_type) or StandardTypeMapper(),
305310
)
306311
resolvers.append(ResolverSpec(None, col_type, [MetricSpec(UdfMetric, config)]))
307312

@@ -322,6 +327,7 @@ def udf_metric_schema(
322327
schema_based_automerge: bool = False,
323328
segments: Optional[Dict[str, SegmentationPartition]] = None,
324329
validators: Optional[Dict[str, List[Validator]]] = None,
330+
schema_name: str = "",
325331
) -> DeclarativeSchema:
326332
"""
327333
Generates a DeclarativeSchema that implements the UdfMetrics specified
@@ -347,7 +353,7 @@ def upper(x):
347353
STANDARD_RESOLVER, the default metrics are also tracked for every column.
348354
"""
349355

350-
resolvers = generate_udf_resolvers()
356+
resolvers = generate_udf_resolvers(schema_name)
351357
non_udf_resolvers = non_udf_resolvers if non_udf_resolvers is not None else UDF_BASE_RESOLVER
352358

353359
return DeclarativeSchema(

python/whylogs/experimental/core/udf_schema.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from collections import defaultdict
23
from copy import deepcopy
34
from dataclasses import dataclass, field
45
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
@@ -112,15 +113,16 @@ def _run_udfs(
112113
return new_df, new_columns
113114

114115

115-
_multicolumn_udfs: List[UdfSpec] = []
116-
_resolver_specs: List[ResolverSpec] = []
116+
_multicolumn_udfs: Dict[str, List[UdfSpec]] = defaultdict(list)
117+
_resolver_specs: Dict[str, List[ResolverSpec]] = defaultdict(list)
117118

118119

119120
def register_dataset_udf(
120121
col_names: List[str],
121122
udf_name: Optional[str] = None,
122123
metrics: Optional[List[MetricSpec]] = None,
123124
namespace: Optional[str] = None,
125+
schema_name: str = "",
124126
) -> Callable[[Any], Any]:
125127
"""
126128
Decorator to easily configure UDFs for your data set. Decorate your UDF
@@ -144,16 +146,16 @@ def decorator_register(func):
144146
global _multicolumn_udfs, _resolver_specs
145147
name = udf_name or func.__name__
146148
name = f"{namespace}.{name}" if namespace else name
147-
_multicolumn_udfs.append(UdfSpec(col_names, {name: func}))
149+
_multicolumn_udfs[schema_name].append(UdfSpec(col_names, {name: func}))
148150
if metrics:
149-
_resolver_specs.append(ResolverSpec(name, None, deepcopy(metrics)))
151+
_resolver_specs[schema_name].append(ResolverSpec(name, None, deepcopy(metrics)))
150152

151153
return func
152154

153155
return decorator_register
154156

155157

156-
def generate_udf_specs(other_udf_specs: Optional[List[UdfSpec]] = None) -> List[UdfSpec]:
158+
def generate_udf_specs(other_udf_specs: Optional[List[UdfSpec]] = None, schema_name: str = "") -> List[UdfSpec]:
157159
"""
158160
Generates a list of UdfSpecs that implement the UDFs specified
159161
by the @register_dataset_udf decorators. You can provide a list of
@@ -175,7 +177,7 @@ def add5(x):
175177
for every column.
176178
"""
177179
specs = list(other_udf_specs) if other_udf_specs else []
178-
specs += _multicolumn_udfs
180+
specs += _multicolumn_udfs[schema_name]
179181
return specs
180182

181183

@@ -189,16 +191,17 @@ def udf_schema(
189191
schema_based_automerge: bool = False,
190192
segments: Optional[Dict[str, SegmentationPartition]] = None,
191193
validators: Optional[Dict[str, List[Validator]]] = None,
194+
schema_name: str = "",
192195
) -> UdfSchema:
193196
"""
194197
Returns a UdfSchema that implements any registered UDFs, along with any
195198
other_udf_specs or resolvers passed in.
196199
"""
197200
if resolvers is not None:
198-
resolver_specs = resolvers + _resolver_specs
201+
resolver_specs = resolvers + _resolver_specs[schema_name]
199202
else:
200-
resolver_specs = UDF_BASE_RESOLVER + _resolver_specs
201-
resolver_specs += generate_udf_resolvers()
203+
resolver_specs = UDF_BASE_RESOLVER + _resolver_specs[schema_name]
204+
resolver_specs += generate_udf_resolvers(schema_name)
202205
return UdfSchema(
203206
resolver_specs,
204207
types,
@@ -208,5 +211,5 @@ def udf_schema(
208211
schema_based_automerge,
209212
segments,
210213
validators,
211-
generate_udf_specs(other_udf_specs),
214+
generate_udf_specs(other_udf_specs, schema_name),
212215
)

0 commit comments

Comments
 (0)