Skip to content

Commit 926e0fa

Browse files
authored
Fix UDF deserialization in column profiles to support UDFs on pyspark (#1271)
## Description UDF metrics are multi-metrics but have a different init signature, which leads to problems when deserializing them in pyspark. We need a consistent construction that preserves the UDF's submetrics on deserialization and works in pyspark, where a column profile is deserialized from bytes. ## Changes - support the multimetric init shape in the UDF metric - add tests to cover the missing deserialization path - [x] I have reviewed the [Guidelines for Contributing](CONTRIBUTING.md) and the [Code of Conduct](CODE_OF_CONDUCT.md).
1 parent ae5dc41 commit 926e0fa

File tree

7 files changed

+81
-18
lines changed

7 files changed

+81
-18
lines changed

python/tests/api/logger/test_segments.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pickle
33
import tempfile
44
from glob import glob
5+
from logging import getLogger
56
from typing import Any
67

78
import numpy as np
@@ -26,6 +27,8 @@
2627
from whylogs.core.view.dataset_profile_view import DatasetProfileView
2728
from whylogs.migration.converters import read_v0_to_view
2829

30+
TEST_LOGGER = getLogger(__name__)
31+
2932

3033
def test_single_row_segment() -> None:
3134
segment_column = "col3"
@@ -282,8 +285,8 @@ def test_multi_column_segment_serialization_roundtrip_v0(tmp_path: Any) -> None:
282285
for file_path in paths:
283286
roundtrip_profiles.append(read_v0_to_view(file_path))
284287
assert len(roundtrip_profiles) == input_rows
285-
print(roundtrip_profiles)
286-
print(roundtrip_profiles[15])
288+
TEST_LOGGER.info(roundtrip_profiles)
289+
TEST_LOGGER.info(roundtrip_profiles[15])
287290

288291
post_deserialization_view = roundtrip_profiles[15]
289292
assert post_deserialization_view is not None

python/tests/api/pyspark/experimental/test_profiler_function.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,24 @@ def resolve(self, name, why_type, column_schema):
174174
assert profile_view.get_column("2").get_metric_names() == []
175175
assert profile_view.get_column("3").get_metric_names() == []
176176

177+
def test_collect_dataset_profile_view_with_udf_schema(self, input_df):
178+
from whylogs.core.datatypes import Fractional
179+
from whylogs.experimental.core.metrics.udf_metric import (
180+
generate_udf_resolvers,
181+
register_metric_udf,
182+
)
183+
184+
def frob(x):
185+
return x * x
186+
187+
register_metric_udf(col_type=Fractional, submetric_name="square", schema_name="pyspark_test")(frob)
188+
test_schema = DeclarativeSchema(resolvers=generate_udf_resolvers(schema_name="pyspark_test"))
189+
profile_view = collect_dataset_profile_view(input_df=input_df, schema=test_schema)
190+
191+
assert isinstance(profile_view, DatasetProfileView)
192+
assert len(profile_view.get_columns()) > 0
193+
assert "udf" in profile_view.get_column("0").get_metric_names()
194+
177195
def test_map_vectors(self, embeddings_df):
178196
from pyspark.ml.functions import vector_to_array
179197

python/tests/core/view/test_dataset_profile_view.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,10 @@ def test_segmented_round_trip_metadata(tmp_path: str) -> None:
233233
segment_column = "A"
234234
results = why.log(df, schema=DatasetSchema(segments=segment_on_column(segment_column)))
235235
status = results.writer("local").write(dest=output_file)
236-
TEST_LOGGER.info("serialized segmented profile to {output_file}" f" has status: {status}")
236+
TEST_LOGGER.info(f"serialized segmented profile to {output_file} has status: {status}")
237237

238238
view = DatasetProfileView.read(output_file)
239-
TEST_LOGGER.info("round trip serialized and deserialized segmented profile" f" has metadata: {view._metadata}")
239+
TEST_LOGGER.info(f"round trip serialized and deserialized segmented profile has metadata: {view._metadata}")
240240
assert view._metadata is not None
241241
segment_tag_metadata_key = _TAG_PREFIX + segment_column
242242
assert segment_tag_metadata_key in view._metadata

python/tests/experimental/core/metrics/test_udf_metric.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from logging import getLogger
2+
13
import pandas as pd
24

35
import whylogs as why
@@ -10,6 +12,8 @@
1012
udf_metric_schema,
1113
)
1214

15+
logger = getLogger(__name__)
16+
1317

1418
def test_udf_metric() -> None:
1519
config = UdfMetricConfig(
@@ -39,6 +43,37 @@ def test_udf_metric() -> None:
3943
assert "foo:frequent_items/frequent_strings" in summary
4044

4145

46+
def test_udf_metric_from_to_protobuf() -> None:
47+
config = UdfMetricConfig(
48+
udfs={
49+
"fortytwo": lambda x: 42,
50+
"foo": lambda x: "bar",
51+
},
52+
)
53+
metric = UdfMetric.zero(config)
54+
metric.columnar_update(PreprocessedColumn.apply([0]))
55+
udf_metric_message = metric.to_protobuf()
56+
logger.debug(f"serialized msg: {udf_metric_message}")
57+
deserialied_udf_metric = UdfMetric.from_protobuf(udf_metric_message)
58+
summary = deserialied_udf_metric.to_summary_dict()
59+
logger.debug(f"serialized summary {summary}")
60+
assert summary["fortytwo:counts/n"] == 1
61+
assert summary["fortytwo:types/integral"] == 1
62+
assert summary["fortytwo:types/string"] == 0
63+
assert summary["fortytwo:cardinality/est"] == 1
64+
assert summary["fortytwo:distribution/n"] == 1
65+
assert summary["fortytwo:distribution/mean"] == 42
66+
assert summary["fortytwo:ints/max"] == 42
67+
assert summary["fortytwo:ints/min"] == 42
68+
assert "fortytwo:frequent_items/frequent_strings" in summary
69+
70+
assert summary["foo:counts/n"] == 1
71+
assert summary["foo:types/integral"] == 0
72+
assert summary["foo:types/string"] == 1
73+
assert summary["foo:cardinality/est"] == 1
74+
assert "foo:frequent_items/frequent_strings" in summary
75+
76+
4277
def test_udf_throws() -> None:
4378
n = 0
4479

python/tests/experimental/core/test_udf_schema.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ def test_multicolumn_udf_pandas() -> None:
135135
assert "counts/n" in sqr_summary
136136
div_summary = results.get_column("ratio").to_summary_dict()
137137
assert div_summary["distribution/n"] == 3
138-
assert len(results.get_column("ratio").get_metrics()) == 2 # Integral -> counts plus registered distribution
138+
# Integral -> counts plus registered distribution
139+
assert results.get_column("ratio").get_metric("counts") is not None
140+
assert results.get_column("ratio").get_metric("distribution") is not None
139141

140142

141143
def test_multicolumn_udf_row() -> None:
@@ -172,7 +174,9 @@ def test_multicolumn_udf_row() -> None:
172174
assert "counts/n" in sqr_summary
173175
div_summary = results.get_column("ratio").to_summary_dict()
174176
assert div_summary["distribution/n"] == 1
175-
assert len(results.get_column("ratio").get_metrics()) == 2 # Integral -> counts plus registered distribution
177+
# Integral -> counts plus registered distribution
178+
assert results.get_column("ratio").get_metric("counts") is not None
179+
assert results.get_column("ratio").get_metric("distribution") is not None
176180

177181

178182
n: int = 0

python/whylogs/core/view/column_profile_view.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ def from_protobuf(cls, msg: ColumnMessage) -> "ColumnProfileView":
129129

130130
c_key = full_path[len(metric_name) + 1 :]
131131
metric_components[c_key] = c_msg
132-
133132
for m_name, metric_components in metric_messages.items():
134133
m_enum = StandardMetric.__members__.get(m_name)
135134
if m_enum is None:
@@ -142,9 +141,11 @@ def from_protobuf(cls, msg: ColumnMessage) -> "ColumnProfileView":
142141

143142
m_msg = MetricMessage(metric_components=metric_components)
144143
try:
145-
result_metrics[m_name] = metric_class.from_protobuf(m_msg)
146-
except: # noqa
147-
raise DeserializationError(f"Failed to deserialize metric: {m_name}")
144+
deserialized_metric = metric_class.from_protobuf(m_msg)
145+
result_metrics[m_name] = deserialized_metric
146+
except Exception as error: # noqa
147+
raise DeserializationError(f"Failed to deserialize metric: {m_name}:{error}")
148+
148149
return ColumnProfileView(metrics=result_metrics)
149150

150151
@classmethod

python/whylogs/experimental/core/metrics/udf_metric.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,15 @@ class UdfMetric(MultiMetric):
103103

104104
def __init__(
105105
self,
106-
udfs: Dict[str, Callable[[Any], Any]],
106+
submetrics: Dict[str, Dict[str, Metric]],
107+
udfs: Optional[Dict[str, Callable[[Any], Any]]] = None,
107108
# discover these with resolver submetrics: Dict[str, Dict[str, Metric]], # feature name -> (namespace -> metric)
108109
submetric_schema: Optional[SubmetricSchema] = None,
109110
type_mapper: Optional[TypeMapper] = None,
110111
fi_disabled: bool = False,
111112
):
112-
super().__init__(dict()) # submetrics)
113-
self._udfs = udfs
113+
super().__init__(submetrics)
114+
self._udfs = udfs or dict()
114115
self._submetric_schema = submetric_schema or default_schema()
115116
self._type_mapper = type_mapper or StandardTypeMapper()
116117
self._fi_disabled = fi_disabled
@@ -120,7 +121,7 @@ def namespace(self) -> str:
120121
return "udf"
121122

122123
def merge(self, other: "UdfMetric") -> "UdfMetric":
123-
merged = UdfMetric(self._udfs, self._submetric_schema, self._type_mapper, self._fi_disabled)
124+
merged = UdfMetric(self.submetrics, self._udfs, self._submetric_schema, self._type_mapper, self._fi_disabled)
124125
merged.submetrics = self.merge_submetrics(other)
125126
return merged
126127

@@ -166,10 +167,11 @@ def zero(cls, config: Optional[MetricConfig] = None) -> "UdfMetric":
166167
config = UdfMetricConfig()
167168

168169
return UdfMetric(
169-
config.udfs,
170-
config.submetric_schema,
171-
config.type_mapper,
172-
config.fi_disabled,
170+
dict(),
171+
udfs=config.udfs,
172+
submetric_schema=config.submetric_schema,
173+
type_mapper=config.type_mapper,
174+
fi_disabled=config.fi_disabled,
173175
)
174176

175177

0 commit comments

Comments
 (0)