Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onnxscript/rewriter/ort_fusions/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
fuse_rotary_embedding,
)
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
from onnxscript.rewriter.ort_fusions.sdpa_via_mha import replace_sdpa_by_mha
from onnxscript.rewriter.ort_fusions.skip_normalization import (
fuse_skip_layer_normalization,
fuse_skip_rms_normalization,
Expand Down Expand Up @@ -104,6 +105,7 @@ def fuse(func, **kwargs):
fusion_count["attention"] = fuse(fuse_attention)
fusion_count["gelu"] = fuse(fuse_gelu)
fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
fusion_count["sdpa_via_mha"] = fuse(replace_sdpa_by_mha)
# Finally: inline any intermediate fusion functions introduced that were not
# consumed by other fusions, and eliminate any remaining unused nodes.
optimize(model)
Expand Down
38 changes: 32 additions & 6 deletions onnxscript/rewriter/ort_fusions/sdpa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,20 +292,41 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
return attn_output


# This tests a scenario where the key is in BSHd format instead of BHSd. This
# happens when an optimization fuses two transposes into one: the transpose
# from BSHd to BHSd and the subsequent transpose from BHSd to BHdS before the
# MatMul. Hence, the first transpose below differs from the other test cases.
@script()
def _unmasked_pre_div_sdpa_BSHd_key_script(query, key, value):
    # Unmasked attention pattern in which query and key are each divided by the
    # same constant before the score MatMul, and the key input is transposed
    # with a single perm=[0, 2, 3, 1] rather than a plain last-two-axes swap.
    key_transposed = op.Transpose(key, perm=[0, 2, 3, 1]) # BSHd to BHdS
    # SQRT_SCALE_FACTOR is defined elsewhere in the file; presumably the square
    # root of the overall attention scale, since it is applied to both
    # operands -- TODO confirm against its definition.
    divisor = op.Constant(value_float=SQRT_SCALE_FACTOR)
    scaled_query = op.Div(query, divisor)
    scaled_key = op.Div(key_transposed, divisor)
    attn_score = op.MatMul(scaled_query, scaled_key)
    # Softmax is taken over the last axis of the score tensor.
    attn_weight = op.Softmax(attn_score, axis=-1)
    # Replace any NaN entries in the softmax output with 0.0.
    is_nan = op.IsNaN(attn_weight)
    zero = op.Constant(value_float=0.0)
    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
    attn_output = op.MatMul(adj_attn_weight, value)
    return attn_output


class SDPATestCase:
def __init__(self, script_func, *, with_mask):
def __init__(self, script_func, *, with_mask, BSHd_key=False):
self.script_func = script_func
self.with_mask = with_mask
self.BSHd_key = BSHd_key

def get_onnx_model(self):
if not hasattr(self, "_onnx_model"):
qkv_type = FLOAT[B, N, S, H]
qv_type = FLOAT[B, N, S, H]
mask_type = FLOAT[B, N, S, S]
input_types = [qkv_type, qkv_type, qkv_type]
k_type = FLOAT[B, S, N, H] if self.BSHd_key else FLOAT[B, N, S, H]
input_types = [qv_type, k_type, qv_type]
if self.with_mask:
input_types.append(mask_type)
model_proto = self.script_func.to_model_proto(
input_types=input_types, output_types=[qkv_type]
input_types=input_types, output_types=[qv_type]
)
self._onnx_model = ir.serde.deserialize_model(model_proto)
return self._onnx_model
Expand All @@ -314,7 +335,9 @@ def get_ort_inputs(self):
if not hasattr(self, "_ort_inputs"):
inputs = {
"query": numpy.random.rand(B, N, S, H).astype(numpy.float32),
"key": numpy.random.rand(B, N, S, H).astype(numpy.float32),
"key": numpy.random.rand(B, S, N, H).astype(numpy.float32)
if self.BSHd_key
else numpy.random.rand(B, N, S, H).astype(numpy.float32),
"value": numpy.random.rand(B, N, S, H).astype(numpy.float32),
}
if self.with_mask:
Expand Down Expand Up @@ -374,10 +397,13 @@ class TestSDPAFusion(unittest.TestCase):
"_custom_multi_scale_pre_mul_sdpa_script",
_custom_multi_scale_pre_mul_sdpa_script,
),
("pre_div_sdpa_BSHd_key", _unmasked_pre_div_sdpa_BSHd_key_script),
]
)
def test_sdpa_fusion(self, name, script_func):
test_case = SDPATestCase(script_func, with_mask="masked" in name)
test_case = SDPATestCase(
script_func, with_mask="masked" in name, BSHd_key="BSHd_key" in name
)
model = test_case.get_onnx_model()
onnxscript.optimizer.optimize(model)

Expand Down
25 changes: 19 additions & 6 deletions onnxscript/rewriter/ort_fusions/sdpa_via_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,56 @@
import onnx_ir as ir

from onnxscript.rewriter import _fusion_utils, pattern
from onnxscript.rewriter._basics import MatchFailureError

Dim = Union[int, ir.SymbolicDim]


class SDPAImplementation(pattern.RewriteRuleClassBase):
def pattern(self, op, query, key, value):
def pattern(self, op, query, key, value, key_format):
return op.SDPA(
query,
key,
value,
key_format="BHSd",
key_format=key_format,
_allow_other_inputs=True, # Mask is optional
_outputs=["sdpa_output"],
_domain="ai.onnxruntime._fusion",
)

def check(self, context, query, key, value, sdpa_output):
def check(self, context, query, key, value, key_format, sdpa_output):
bindings: dict[str, Dim] = {}
_fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"])
_fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"])
_fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"])

if key_format.value == "BHSd":
_fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"])
elif key_format.value == "BSHd":
_fusion_utils.check_shape(bindings, key, ["B", "Skv", "H", "Dh"])
else:
raise MatchFailureError(
f"Unexpected key_format value: {key_format.value}", key_format
)

self._num_heads = bindings["H"]
if not isinstance(self._num_heads, int):
return False
self._use_mask_broadcast = True # TODO: optimize to avoid broadcast if not needed
return isinstance(self._num_heads, int)

def rewrite(self, op, query, key, value, sdpa_output):
def rewrite(self, op, query, key, value, key_format, sdpa_output):
sdpa_node = sdpa_output.producer()
scale = sdpa_node.attributes.get("scale", None)
to_3d_shape = op.Constant(value_ints=[0, 0, -1])
to_4d_shape = op.Constant(value_ints=[0, 0, self._num_heads, -1])
query_3d = op.Reshape(op.Transpose(query, perm=[0, 2, 1, 3]), to_3d_shape)
key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape)
value_3d = op.Reshape(op.Transpose(value, perm=[0, 2, 1, 3]), to_3d_shape)

if key_format.value == "BHSd":
key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape)
else: # BSHd
key_3d = op.Reshape(key, to_3d_shape)

inputs = [query_3d, key_3d, value_3d]
if len(sdpa_node.inputs) > 3:
mask = sdpa_node.inputs[3]
Expand Down
Loading