Skip to content

Commit 971f9bb

Browse files
authored
Merge metadata props in rewriter (#2682)
Introduce basic infrastructure for merging metadata props (for use in rewriter/optimizer etc.) A basic version added to rewriter. TODO: * Allow user control over this: should this be configurable at the level of a RewriteRuleSet? Or, perhaps at a global level (given that ORT fusions uses a number of rewrite-rule-sets for various reasons)? * This [line](https://github.com/microsoft/onnxscript/blob/1a27df145b7ec03da7d316a38c2cb005cf0a45b7/onnxscript/rewriter/ort_fusions/_core.py#L148) should also be factored out or made user-controllable in some fashion. Otherwise, the metadata gets lost anyway. --------- Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 8845fb2 commit 971f9bb

File tree

3 files changed

+116
-0
lines changed

3 files changed

+116
-0
lines changed

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Sequence, TypeVar, Union
66

77
__all__ = [
8+
"merge_metadata",
89
"pattern",
910
"rewrite",
1011
"RewritePass",
@@ -31,6 +32,7 @@
3132
RewriteRule,
3233
RewriteRuleClassBase,
3334
RewriteRuleSet,
35+
merge_metadata,
3436
)
3537
from onnxscript.rewriter.rules.common import (
3638
_basic_rules,

onnxscript/rewriter/_rewrite_rule.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import onnxscript.rewriter._ir_utils as _ir_utils
1919
import onnxscript.rewriter._matcher as _matcher
2020
import onnxscript.rewriter._pattern_ir as _pattern_ir
21+
import onnxscript.utils.metadata_merger as metadata_merger
2122
from onnxscript import ir
2223
from onnxscript.ir import _tape, convenience
2324

@@ -614,6 +615,15 @@ def _get_new_overload(model: ir.Model, domain: str, name: str) -> str:
614615
overload += 1
615616

616617

618+
_default_metadata_merger: metadata_merger.MetadataMerger = metadata_merger.MetadataMerger(
619+
{RULE_NAME_TAG: metadata_merger.comma_separator_merger}
620+
)
621+
622+
# TODO(rama): Generalize this to support custom metadata mergers. For now, we just allow
623+
# enabling/disabling the default merger.
624+
merge_metadata: bool = True
625+
626+
617627
class RewriteRuleSet:
618628
def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None:
619629
if not rules:
@@ -740,6 +750,11 @@ def _apply_to_graph_or_function(
740750
delta.new_outputs,
741751
)
742752

753+
if merge_metadata:
754+
_default_metadata_merger.copy_merged_metadata(
755+
delta.match.nodes, delta.new_nodes
756+
)
757+
743758
count += 1
744759
break
745760

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Merging metadata_props"""
4+
5+
from __future__ import annotations
6+
7+
from typing import Callable, Iterable
8+
9+
import onnx_ir as ir
10+
11+
# Utilities for merging metadata properties, represented as strings.
12+
# The merging-logic will take care of special cases like missing metadata or
13+
# empty string metadata, and so the functions defined below need not handle
14+
# special cases like empty string. (This does assume that an empty string is
15+
# the same as no metadata, which is a reasonable assumption for most metadata.)
16+
17+
StringMerger = Callable[[str, str], str]
18+
19+
20+
def overwrite(_: str, new: str) -> str:
21+
return new
22+
23+
24+
def join(separator: str) -> StringMerger:
25+
"""Creates a StringMerger that joins two strings with the given separator.
26+
27+
Args:
28+
separator (str): The separator to use when joining the strings.
29+
30+
Returns:
31+
StringMerger: A function that joins two strings with the specified separator.
32+
"""
33+
34+
def merger(first: str, second: str) -> str:
35+
return f"{first}{separator}{second}"
36+
37+
return merger
38+
39+
40+
comma_separator_merger = join(", ")
41+
42+
43+
class MetadataMerger:
44+
"""Merges metadata properties using specified merging logic.
45+
46+
Attributes:
47+
mergers: A mapping from metadata property keys to their corresponding merging functions.
48+
default: The default merging function to use when a specific key does not have a defined merger.
49+
If None, the first value is used. (Specify `overwrite` to always use the second value.)
50+
"""
51+
52+
def __init__(
53+
self, mergers: dict[str, StringMerger], default: StringMerger | None = None
54+
) -> None:
55+
self.mergers = mergers
56+
self.default = default
57+
58+
def update_dict(self, updated: dict[str, str], updates: dict[str, str]) -> None:
59+
"""Updates the first metadata property dictionary with values from the second.
60+
61+
Args:
62+
updated: The metadata dictionary to be updated.
63+
updates: The updates metadata dictionary.
64+
"""
65+
for key, new_value in updates.items():
66+
if new_value == "":
67+
continue
68+
if (key in updated) and ((updated_value := updated[key]) != ""):
69+
merger = self.mergers.get(key, self.default)
70+
if merger is not None:
71+
updated[key] = merger(updated_value, new_value)
72+
else:
73+
updated[key] = new_value
74+
75+
def copy_merged_metadata(
76+
self, from_nodes: Iterable[ir.Node], to: ir.Node | Iterable[ir.Node]
77+
) -> None:
78+
"""Merges metadata from multiple nodes and assigns it to one or more target nodes.
79+
80+
Args:
81+
from_nodes: The source nodes from which to merge metadata.
82+
to: The target node(s) to which the merged metadata will be assigned.
83+
"""
84+
if isinstance(to, ir.Node):
85+
updated = to.metadata_props
86+
for node in from_nodes:
87+
self.update_dict(updated, node.metadata_props)
88+
elif len(to) == 1:
89+
# Handle single node in iterable case
90+
target_node = next(iter(to))
91+
updated = target_node.metadata_props
92+
for node in from_nodes:
93+
self.update_dict(updated, node.metadata_props)
94+
else:
95+
merged_metadata: dict[str, str] = {}
96+
for node in from_nodes:
97+
self.update_dict(merged_metadata, node.metadata_props)
98+
for target_node in to:
99+
self.update_dict(target_node.metadata_props, merged_metadata)

0 commit comments

Comments
 (0)