Fix Nested Metrics Handling in CompileMetrics #21761
@@ -175,6 +175,40 @@ def variables(self):
        return vars

    def build(self, y_true, y_pred):
        # Handle nested structures using tree utilities similar to CompileLoss
        if tree.is_nested(y_true) and tree.is_nested(y_pred):
            try:
                # Align the metrics configuration with the structure of y_pred
                if self._has_nested_structure(self._user_metrics) or self._has_nested_structure(self._user_weighted_metrics):
                    self._flat_metrics = self._build_nested_metrics(
                        self._user_metrics, y_true, y_pred, "metrics"
                    )
                    self._flat_weighted_metrics = self._build_nested_metrics(
                        self._user_weighted_metrics, y_true, y_pred, "weighted_metrics"
                    )
                else:
                    self._build_with_flat_structure(y_true, y_pred)
            except (ValueError, TypeError):
                self._build_with_flat_structure(y_true, y_pred)
        else:
            self._build_with_flat_structure(y_true, y_pred)
        self.built = True

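For orientation (not part of the diff), a rough illustration of the two metric configurations the dispatch above distinguishes; the key names are arbitrary examples:

# Illustrative only. A metrics dict mirroring a nested
# y_pred = {"a": ..., "b": {"c": ..., "d": ...}} makes
# _has_nested_structure(...) return True, so _build_nested_metrics is used:
nested_metrics = {"a": ["mae"], "b": {"c": ["mse"], "d": ["mae"]}}

# A flat per-output list keeps the existing _build_with_flat_structure path:
flat_metrics = [["mae"], ["mse"]]
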
    def _has_nested_structure(self, obj):
        """Helper method to check if object has nested dict/list structure."""
        if obj is None:
            return False
        if isinstance(obj, dict):
            for value in obj.values():
                if isinstance(value, (dict, list)):
                    return True
        elif isinstance(obj, list):
            for item in obj:
                if isinstance(item, (dict, list)):
                    return True
        return False

    def _build_with_flat_structure(self, y_true, y_pred):
        num_outputs = 1  # default
        # Resolve output names. If y_pred is a dict, prefer its keys.
        if isinstance(y_pred, dict):

@@ -219,7 +253,193 @@ def build(self, y_true, y_pred):
            y_pred,
            argument_name="weighted_metrics",
        )
        self.built = True

    def _build_nested_metrics(self, metrics_config, y_true, y_pred, argument_name):
        """Build metrics for nested structures following y_pred structure."""
        if metrics_config is None:
            # If metrics_config is None, create None placeholders for each output
            return self._build_flat_placeholders(y_true, y_pred)

        if (isinstance(metrics_config, dict) and
                isinstance(y_pred, dict) and
                set(metrics_config.keys()).issubset(set(y_pred.keys())) and
                not any(tree.is_nested(v) for v in y_pred.values())):
            return self._build_metrics_set_for_nested(metrics_config, y_true, y_pred, argument_name)

        # Handle metrics configuration with tree structure similar to y_pred
        def build_recursive_metrics(metrics_cfg, yt, yp, path=(), is_nested_path=False):
            """Recursively build metrics for nested structures."""
            if isinstance(metrics_cfg, dict) and isinstance(yp, dict):
                # Both metrics and predictions are dicts, process recursively
                flat_metrics = []
                for key in yp.keys():
                    current_path = path + (key,)
                    if key in metrics_cfg:
                        if isinstance(yp[key], dict) and isinstance(metrics_cfg[key], dict):
                            flat_metrics.extend(build_recursive_metrics(metrics_cfg[key], yt[key], yp[key], current_path, True))
                        elif isinstance(yp[key], (list, tuple)) and isinstance(metrics_cfg[key], (list, tuple)):
                            flat_metrics.extend(build_recursive_metrics(metrics_cfg[key], yt[key], yp[key], current_path, True))
                        else:
                            output_name = "_".join(map(str, current_path)) if is_nested_path else None
                            flat_metrics.append(self._build_single_output_metrics(metrics_cfg[key], yt[key], yp[key], argument_name, output_name=output_name))
                    else:
                        flat_metrics.append(None)
                return flat_metrics
            elif isinstance(metrics_cfg, (list, tuple)) and isinstance(yp, (list, tuple)):
                flat_metrics = []
                for i, (m_cfg, y_t_elem, y_p_elem) in enumerate(zip(metrics_cfg, yt, yp)):

Contributor comment: When handling lists/tuples in … (see the related sketch after this method).

                    current_path = path + (i,)
                    if isinstance(y_p_elem, (dict, list, tuple)) and isinstance(m_cfg, (dict, list, tuple)):
                        flat_metrics.extend(build_recursive_metrics(m_cfg, y_t_elem, y_p_elem, current_path, True))
                    else:
                        output_name = "_".join(map(str, current_path)) if is_nested_path else None
                        flat_metrics.append(self._build_single_output_metrics(m_cfg, y_t_elem, y_p_elem, argument_name, output_name=output_name))
                return flat_metrics
            else:
                output_name = "_".join(map(str, path)) if path and is_nested_path else None
                return [self._build_single_output_metrics(metrics_cfg, yt, yp, argument_name, output_name=output_name)]

        # For truly complex nested structures, use recursive approach
        return build_recursive_metrics(metrics_config, y_true, y_pred)

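The reviewer remark above is cut off, so the following only guesses at the concern: unlike the flat-list branch of _build_metrics_set_for_nested, this recursive list/tuple branch zips metrics_cfg against yt/yp without validating lengths, and zip silently drops trailing entries. A standalone sketch of that failure mode:

# Sketch only: zip() truncates to the shorter sequence, so a metrics list
# with fewer entries than the outputs is partially ignored instead of
# raising an error.
metrics_cfg = [["mae"], ["mse"]]        # two metric sublists
outputs = ["out_1", "out_2", "out_3"]   # three model outputs
pairs = list(zip(metrics_cfg, outputs))
assert len(pairs) == 2  # "out_3" silently receives no metrics
# A length check before zipping (as the flat-list branch already performs)
# would surface the mismatch explicitly.
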
    def _build_single_output_metrics(self, metric_config, y_true, y_pred, argument_name, output_name=None):
        """Build metrics for a single output."""
        if metric_config is None:
            return None
        elif not isinstance(metric_config, list):
            metric_config = [metric_config]
        if not all(is_function_like(m) for m in metric_config):
            raise ValueError(
                f"All entries in the sublists of the "
                f"`{argument_name}` structure should be metric objects. "
                f"Found the following with unknown types: {metric_config}"
            )
        return MetricsList(
            [
                get_metric(m, y_true, y_pred)
                for m in metric_config
                if m is not None
            ],
            output_name=output_name,
        )

    def _build_flat_placeholders(self, y_true, y_pred):
        """Create None placeholders for each output when config is None."""
        flat_y_pred = tree.flatten(y_pred)
        return [None] * len(flat_y_pred)

    def _build_metrics_set_for_nested(self, metrics, y_true, y_pred, argument_name):
        """Alternative method to build metrics when we detect nested structures."""
        flat_y_pred = tree.flatten(y_pred)
        flat_y_true = tree.flatten(y_true)

        if isinstance(y_pred, dict):
            flat_output_names = tree.flatten(y_pred)
            output_names = self._flatten_dict_keys(y_pred)
        else:
            output_names = [None] * len(flat_y_pred) if self.output_names is None else self.output_names

        # If metrics is a flat dict that should map to the outputs
        if isinstance(metrics, dict):
            flat_metrics = []
            if isinstance(y_pred, dict):
                # Map metrics dict to y_pred dict keys
                for idx, (name, yt, yp) in enumerate(zip(y_pred.keys(), flat_y_true, flat_y_pred)):
                    if name in metrics:
                        metric_list = metrics[name]
                        if not isinstance(metric_list, list):
                            metric_list = [metric_list]
                        if not all(is_function_like(e) for e in metric_list):
                            raise ValueError(
                                f"All entries in the sublists of the "
                                f"`{argument_name}` dict should be metric objects. "
                                f"At key '{name}', found the following with unknown types: {metric_list}"
                            )
                        flat_metrics.append(
                            MetricsList(
                                [
                                    get_metric(m, yt, yp)
                                    for m in metric_list
                                    if m is not None
                                ],
                                output_name=name,
                            )
                        )
                    else:
                        flat_metrics.append(None)
            else:
                return self._build_metrics_set(metrics, len(flat_y_pred), output_names, flat_y_true, flat_y_pred, argument_name)
        elif isinstance(metrics, (list, tuple)):
            # Handle list/tuple case for nested outputs
            if len(metrics) != len(flat_y_pred):
                raise ValueError(
                    f"For a model with multiple outputs, "
                    f"when providing the `{argument_name}` argument as a "
                    f"list, it should have as many entries as the model has "
                    f"outputs. Received:\n{argument_name}={metrics}\nof "
                    f"length {len(metrics)} whereas the model has "
                    f"{len(flat_y_pred)} outputs."
                )
            flat_metrics = []
            for idx, (mls, yt, yp) in enumerate(zip(metrics, flat_y_true, flat_y_pred)):
                if not isinstance(mls, list):
                    mls = [mls]
                name = output_names[idx] if output_names and idx < len(output_names) else None
                if not all(is_function_like(e) for e in mls):
                    raise ValueError(
                        f"All entries in the sublists of the "
                        f"`{argument_name}` list should be metric objects. "
                        f"Found the following sublist with unknown types: {mls}"
                    )
                flat_metrics.append(
                    MetricsList(
                        [
                            get_metric(m, yt, yp)
                            for m in mls
                            if m is not None
                        ],
                        output_name=name,
                    )
                )
        else:
            # Handle single metric applied to all outputs
            flat_metrics = []
            for idx, (yt, yp) in enumerate(zip(flat_y_true, flat_y_pred)):
                name = output_names[idx] if output_names and idx < len(output_names) else None
                if metrics is None:
                    flat_metrics.append(None)
                else:
                    if not is_function_like(metrics):
                        raise ValueError(
                            f"Expected all entries in the `{argument_name}` list "
                            f"to be metric objects. Received instead:\n"
                            f"{argument_name}={metrics}"
                        )
                    flat_metrics.append(
                        MetricsList(
                            [get_metric(metrics, yt, yp)],
                            output_name=name,
                        )
                    )

        return flat_metrics

    def _flatten_dict_keys(self, d):
        """Flatten dict to get key names in order."""
        if isinstance(d, dict):
            return list(d.keys())
        elif isinstance(d, (list, tuple)):
            result = []
            for item in d:
                if isinstance(item, dict):
                    result.extend(list(item.keys()))
                else:
                    result.append(None)
            return result
        else:
            return [None]

    def _build_metrics_set(
        self, metrics, num_outputs, output_names, y_true, y_pred, argument_name

@@ -538,6 +538,101 @@ def test_struct_loss_namedtuple(self):
        value = compile_loss(y_true, y_pred)
        self.assertAllClose(value, 1.07666, atol=1e-5)

    def test_nested_dict_metrics(self):
        import numpy as np
        from keras.src import Input
        from keras.src import Model
        from keras.src import layers
        from keras.src import metrics as metrics_module

        # Create test data matching the nested structure
        y_true = {
            'a': np.random.rand(10, 10),
            'b': {
                'c': np.random.rand(10, 10),
                'd': np.random.rand(10, 10)
            }
        }
        y_pred = {
            'a': np.random.rand(10, 10),
            'b': {
                'c': np.random.rand(10, 10),
                'd': np.random.rand(10, 10)
            }
        }

        # Test compiling with nested dict metrics
        compile_metrics = CompileMetrics(
            metrics={
                'a': [metrics_module.MeanSquaredError()],
                'b': {
                    'c': [metrics_module.MeanSquaredError(), metrics_module.MeanAbsoluteError()],
                    'd': [metrics_module.MeanSquaredError()]
                }
            },
            weighted_metrics=None,
        )

        # Build the metrics
        compile_metrics.build(y_true, y_pred)

        # Update state and get results
        compile_metrics.update_state(y_true, y_pred, sample_weight=None)
        result = compile_metrics.result()

        # Check that expected metrics are present. The actual names might be
        # different based on how MetricsList handles output names.
        expected_metric_names = []
        for key in result.keys():
            if 'mean_squared_error' in key or 'mean_absolute_error' in key:
                expected_metric_names.append(key)

        # At least some metrics should be computed
        self.assertGreater(len(expected_metric_names), 0)

Contributor comment on lines +585 to +591: The assertions in this test are too weak. It only checks that … A suggested change was attached. (A sketch of stronger assertions follows this test.)

        # Test with symbolic tensors as well
        y_true_symb = tree.map_structure(lambda _: backend.KerasTensor((10, 10)), y_true)
        y_pred_symb = tree.map_structure(lambda _: backend.KerasTensor((10, 10)), y_pred)
        compile_metrics_symbolic = CompileMetrics(
            metrics={
                'a': [metrics_module.MeanSquaredError()],
                'b': {
                    'c': [metrics_module.MeanSquaredError(), metrics_module.MeanAbsoluteError()],
                    'd': [metrics_module.MeanSquaredError()]
                }
            },
            weighted_metrics=None,
        )
        compile_metrics_symbolic.build(y_true_symb, y_pred_symb)
        self.assertTrue(compile_metrics_symbolic.built)

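The reviewer's suggested change is not reproduced above. As one possible strengthening, and only as a sketch (the "b_c_mean_squared_error" result key is an assumption about how MetricsList names nested outputs), the reported values could be compared against metrics computed directly on the same data:

# Sketch of stronger assertions; key names are assumed, not taken from the PR.
import numpy as np

from keras.src import metrics as metrics_module
from keras.src.trainers.compile_utils import CompileMetrics

y_true = {
    "a": np.random.rand(10, 10),
    "b": {"c": np.random.rand(10, 10), "d": np.random.rand(10, 10)},
}
y_pred = {
    "a": np.random.rand(10, 10),
    "b": {"c": np.random.rand(10, 10), "d": np.random.rand(10, 10)},
}

cm = CompileMetrics(
    metrics={
        "a": [metrics_module.MeanSquaredError()],
        "b": {
            "c": [metrics_module.MeanSquaredError()],
            "d": [metrics_module.MeanSquaredError()],
        },
    },
    weighted_metrics=None,
)
cm.build(y_true, y_pred)
cm.update_state(y_true, y_pred, sample_weight=None)
result = cm.result()

# Reference value computed directly on the same nested entry.
reference = metrics_module.MeanSquaredError()
reference.update_state(y_true["b"]["c"], y_pred["b"]["c"])

# "b_c_mean_squared_error" is an assumed key; adjust to the name actually
# produced by MetricsList for the ("b", "c") path.
np.testing.assert_allclose(
    float(result["b_c_mean_squared_error"]),
    float(reference.result()),
    atol=1e-6,
)
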
def test_nested_dict_metrics():
    import numpy as np
    from keras.src import layers
    from keras.src import Input
    from keras.src import Model

    X = np.random.rand(100, 32)
    Y1 = np.random.rand(100, 10)
    Y2 = np.random.rand(100, 10)
    Y3 = np.random.rand(100, 10)

    def create_model():
        x = Input(shape=(32,))
        y1 = layers.Dense(10)(x)
        y2 = layers.Dense(10)(x)
        y3 = layers.Dense(10)(x)
        return Model(inputs=x, outputs={'a': y1, 'b': {'c': y2, 'd': y3}})

    model = create_model()
    model.compile(
        optimizer='adam',
        loss={'a': 'mse', 'b': {'c': 'mse', 'd': 'mse'}},
        metrics={'a': ['mae'], 'b': {'c': 'mse', 'd': 'mae'}},
    )
    model.train_on_batch(X, {'a': Y1, 'b': {'c': Y2, 'd': Y3}})

Contributor comment on lines +610 to +634: This test function … Additionally, having a standalone function with the same name as a class method in the same file is confusing. It would be better to integrate this into a test class and give it a more descriptive name, such as … (A sketch of that refactor follows.)

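The reviewer's exact suggested name is not shown above. A minimal sketch of folding the standalone function into the test class under a hypothetical name (test_nested_dict_metrics_end_to_end is an assumption), with at least one assertion on the training logs:

    def test_nested_dict_metrics_end_to_end(self):
        # Hypothetical name and assertion; an in-class variant of the
        # standalone function above.
        import numpy as np

        from keras.src import Input
        from keras.src import Model
        from keras.src import layers

        x_data = np.random.rand(100, 32)
        y_data = {
            "a": np.random.rand(100, 10),
            "b": {"c": np.random.rand(100, 10), "d": np.random.rand(100, 10)},
        }

        inputs = Input(shape=(32,))
        outputs = {
            "a": layers.Dense(10)(inputs),
            "b": {"c": layers.Dense(10)(inputs), "d": layers.Dense(10)(inputs)},
        }
        model = Model(inputs=inputs, outputs=outputs)
        model.compile(
            optimizer="adam",
            loss={"a": "mse", "b": {"c": "mse", "d": "mse"}},
            metrics={"a": ["mae"], "b": {"c": "mse", "d": "mae"}},
        )
        logs = model.train_on_batch(x_data, y_data, return_dict=True)
        self.assertIn("loss", logs)
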
    def test_struct_loss_invalid_path(self):
        y_true = {
            "a": {

Contributor comment: Why not use tree.is_nested?
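
One possible reading of this question, sketched under the assumption that it targets the per-value isinstance checks in _has_nested_structure: tree.is_nested could classify the child values instead, which would also pick up tuples and namedtuples that the current (dict, list) check misses.

    def _has_nested_structure(self, obj):
        """Return True if any immediate child of `obj` is itself a structure."""
        if obj is None:
            return False
        if isinstance(obj, dict):
            children = obj.values()
        elif isinstance(obj, (list, tuple)):
            children = obj
        else:
            return False
        # tree.is_nested treats dicts, lists, and tuples as nested structures.
        return any(tree.is_nested(child) for child in children)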