@@ -177,13 +177,17 @@ def zero(cls, config: Optional[MetricConfig] = None) -> "UdfMetric":
177177register_metric (UdfMetric )
178178
179179
180- _col_name_submetrics : Dict [str , List [Tuple [str , Callable [[Any ], Any ]]]] = defaultdict (list )
181- _col_name_submetric_schema : Dict [str , SubmetricSchema ] = dict ()
182- _col_name_type_mapper : Dict [str , TypeMapper ] = dict ()
180+ _col_name_submetrics : Dict [str , Dict [str , List [Tuple [str , Callable [[Any ], Any ]]]]] = defaultdict (
181+ lambda : defaultdict (list )
182+ )
183+ _col_name_submetric_schema : Dict [str , Dict [str , SubmetricSchema ]] = defaultdict (dict )
184+ _col_name_type_mapper : Dict [str , Dict [str , TypeMapper ]] = defaultdict (dict )
183185
184- _col_type_submetrics : Dict [DataType , List [Tuple [str , Callable [[Any ], Any ]]]] = defaultdict (list )
185- _col_type_submetric_schema : Dict [DataType , SubmetricSchema ] = dict ()
186- _col_type_type_mapper : Dict [DataType , TypeMapper ] = dict ()
186+ _col_type_submetrics : Dict [str , Dict [DataType , List [Tuple [str , Callable [[Any ], Any ]]]]] = defaultdict (
187+ lambda : defaultdict (list )
188+ )
189+ _col_type_submetric_schema : Dict [str , Dict [DataType , SubmetricSchema ]] = defaultdict (dict )
190+ _col_type_type_mapper : Dict [str , Dict [DataType , TypeMapper ]] = defaultdict (dict )
187191
188192
189193def register_metric_udf (
@@ -193,6 +197,7 @@ def register_metric_udf(
193197 submetric_schema : Optional [SubmetricSchema ] = None ,
194198 type_mapper : Optional [TypeMapper ] = None ,
195199 namespace : Optional [str ] = None ,
200+ schema_name : str = "" ,
196201) -> Callable [[Any ], Any ]:
197202 """
198203 Decorator to easily configure UdfMetrics for your data set. Decorate your UDF
@@ -229,32 +234,32 @@ def decorator_register(func):
229234 subname = submetric_name or func .__name__
230235 subname = f"{ namespace } .{ subname } " if namespace else subname
231236 if col_name is not None :
232- _col_name_submetrics [col_name ].append ((subname , func ))
237+ _col_name_submetrics [schema_name ][ col_name ].append ((subname , func ))
233238 if submetric_schema is not None :
234- if col_name in _col_name_submetric_schema :
239+ if col_name in _col_name_submetric_schema [ schema_name ] :
235240 logger .warn (f"Overwriting submetric schema for column { col_name } " )
236- _col_name_submetric_schema [col_name ] = submetric_schema
241+ _col_name_submetric_schema [schema_name ][ col_name ] = submetric_schema
237242 if type_mapper is not None :
238- if col_name in _col_name_type_mapper :
243+ if col_name in _col_name_type_mapper [ schema_name ] :
239244 logger .warn (f"Overwriting UdfMetric type mapper for column { col_name } " )
240- _col_name_type_mapper [col_name ] = type_mapper
245+ _col_name_type_mapper [schema_name ][ col_name ] = type_mapper
241246 else :
242- _col_type_submetrics [col_type ].append ((subname , func ))
247+ _col_type_submetrics [schema_name ][ col_type ].append ((subname , func ))
243248 if submetric_schema is not None :
244- if col_type in _col_type_submetric_schema :
249+ if col_type in _col_type_submetric_schema [ schema_name ] :
245250 logger .warn (f"Overwriting submetric schema for column type { col_type } " )
246- _col_type_submetric_schema [col_type ] = submetric_schema
251+ _col_type_submetric_schema [schema_name ][ col_type ] = submetric_schema
247252 if type_mapper is not None :
248- if col_type in _col_type_type_mapper :
253+ if col_type in _col_type_type_mapper [ schema_name ] :
249254 logger .warn (f"Overwriting UdfMetric type mapper for column type { col_type } " )
250- _col_type_type_mapper [col_type ] = type_mapper
255+ _col_type_type_mapper [schema_name ][ col_type ] = type_mapper
251256
252257 return func
253258
254259 return decorator_register
255260
256261
257- def generate_udf_resolvers () -> List [ResolverSpec ]:
262+ def generate_udf_resolvers (schema_name : str = "" ) -> List [ResolverSpec ]:
258263 """
259264 Generates a list of ResolverSpecs that implement the UdfMetrics specified
260265 by the @register_metric_udf decorators. The result only includes the UdfMetric,
@@ -283,25 +288,25 @@ def upper(x):
283288
284289 resolvers : List [ResolverSpec ] = list ()
285290 udfs : Dict [str , Callable [[Any ], Any ]]
286- for col_name , submetrics in _col_name_submetrics .items ():
291+ for col_name , submetrics in _col_name_submetrics [ schema_name ] .items ():
287292 udfs = dict ()
288293 for submetric in submetrics :
289294 udfs [submetric [0 ]] = submetric [1 ]
290295 config = UdfMetricConfig (
291296 udfs = udfs ,
292- submetric_schema = _col_name_submetric_schema .get (col_name ) or default_schema (),
293- type_mapper = _col_name_type_mapper .get (col_name ) or StandardTypeMapper (),
297+ submetric_schema = _col_name_submetric_schema [ schema_name ] .get (col_name ) or default_schema (),
298+ type_mapper = _col_name_type_mapper [ schema_name ] .get (col_name ) or StandardTypeMapper (),
294299 )
295300 resolvers .append (ResolverSpec (col_name , None , [MetricSpec (UdfMetric , config )]))
296301
297- for col_type , submetrics in _col_type_submetrics .items ():
302+ for col_type , submetrics in _col_type_submetrics [ schema_name ] .items ():
298303 udfs = dict ()
299304 for submetric in submetrics :
300305 udfs [submetric [0 ]] = submetric [1 ]
301306 config = UdfMetricConfig (
302307 udfs = udfs ,
303- submetric_schema = _col_type_submetric_schema .get (col_type ) or default_schema (),
304- type_mapper = _col_type_type_mapper .get (col_type ) or StandardTypeMapper (),
308+ submetric_schema = _col_type_submetric_schema [ schema_name ] .get (col_type ) or default_schema (),
309+ type_mapper = _col_type_type_mapper [ schema_name ] .get (col_type ) or StandardTypeMapper (),
305310 )
306311 resolvers .append (ResolverSpec (None , col_type , [MetricSpec (UdfMetric , config )]))
307312
@@ -322,6 +327,7 @@ def udf_metric_schema(
322327 schema_based_automerge : bool = False ,
323328 segments : Optional [Dict [str , SegmentationPartition ]] = None ,
324329 validators : Optional [Dict [str , List [Validator ]]] = None ,
330+ schema_name : str = "" ,
325331) -> DeclarativeSchema :
326332 """
327333 Generates a DeclarativeSchema that implement the UdfMetrics specified
@@ -347,7 +353,7 @@ def upper(x):
347353 STANDARD_RESOLVER, the default metrics are also tracked for every column.
348354 """
349355
350- resolvers = generate_udf_resolvers ()
356+ resolvers = generate_udf_resolvers (schema_name )
351357 non_udf_resolvers = non_udf_resolvers if non_udf_resolvers is not None else UDF_BASE_RESOLVER
352358
353359 return DeclarativeSchema (
0 commit comments