2 changes: 2 additions & 0 deletions onnxscript/rewriter/ort_fusions/_core.py
@@ -29,6 +29,7 @@
fuse_rotary_embedding,
)
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
from onnxscript.rewriter.ort_fusions.sdpa_via_mha import replace_sdpa_by_mha
from onnxscript.rewriter.ort_fusions.skip_normalization import (
fuse_skip_layer_normalization,
fuse_skip_rms_normalization,
@@ -104,6 +105,7 @@ def fuse(func, **kwargs):
fusion_count["attention"] = fuse(fuse_attention)
fusion_count["gelu"] = fuse(fuse_gelu)
fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
fusion_count["sdpa_via_mha"] = fuse(replace_sdpa_by_mha)
# Finally: inline any intermediate fusion functions that were introduced but not
# consumed by other fusions, and eliminate any remaining unused nodes.
optimize(model)
12 changes: 12 additions & 0 deletions onnxscript/rewriter/ort_fusions/sdpa.py
@@ -12,6 +12,18 @@

Dim = Union[int, ir.SymbolicDim]

# This file contains a fusion rule that recognizes various patterns of scaled dot-product attention
# (SDPA) implementations and replaces them with a single SDPA op. The SDPA op is a temporary fusion
# op defined in the ai.onnxruntime._fusion domain. Subsequent fusion rules will map it into one
# of the various ops defined in ORT (MHA, GQA, or Attention), depending on the input patterns.
# The SDPA op is standard scaled dot-product attention with an optional mask input and scaling factor.
# Currently, it is restricted to query, key, and value tensors of rank 4 with the shapes:
#     Query: [batch_size, num_heads, seq_len, head_size_qk]
#     Key:   [batch_size, num_heads, seq_len_kv, head_size_qk]
#            or [batch_size, seq_len_kv, num_heads, head_size_qk]
#     Value: [batch_size, num_heads, seq_len_kv, head_size_v]
# The key_format attribute indicates which of the two key formats is used: "BHSd" or "BSHd".


class SDPA(pattern.RewriteRuleClassBase):
_scale: float | None
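For reference, here is a minimal numpy sketch of the SDPA semantics described in the comment above, with a "BHSd" key. The concrete sizes and the 1/sqrt(head_size_qk) default scale are illustrative assumptions, not taken from the fusion code:

```python
import numpy as np

# Minimal sketch of scaled dot-product attention with key_format="BHSd".
# Shape names follow the comment above; concrete sizes are arbitrary.
B, H, S, S_kv, d_qk, d_v = 2, 4, 8, 8, 16, 16
query = np.random.rand(B, H, S, d_qk).astype(np.float32)
key = np.random.rand(B, H, S_kv, d_qk).astype(np.float32)
value = np.random.rand(B, H, S_kv, d_v).astype(np.float32)

scale = 1.0 / np.sqrt(d_qk)  # assumed default scaling factor
scores = (query @ key.transpose(0, 1, 3, 2)) * scale  # [B, H, S, S_kv]
scores -= scores.max(axis=-1, keepdims=True)          # numerically stable softmax
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)
attn_output = weights @ value                         # [B, H, S, d_v]
```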
38 changes: 32 additions & 6 deletions onnxscript/rewriter/ort_fusions/sdpa_test.py
@@ -292,20 +292,41 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
return attn_output


# This tests a scenario where the key is in BSHd format instead of BHSd. This arises from an
# optimization that fuses two transposes: the one converting the key from BSHd to BHSd and the
# subsequent one converting it to BHdS before the MatMul. Hence, the first transpose below
# differs from the one in the other test cases (see the numpy check after this script).
@script()
def _unmasked_pre_div_sdpa_BSHd_key_script(query, key, value):
key_transposed = op.Transpose(key, perm=[0, 2, 3, 1]) # BSHd to BHdS
divisor = op.Constant(value_float=SQRT_SCALE_FACTOR)
scaled_query = op.Div(query, divisor)
scaled_key = op.Div(key_transposed, divisor)
attn_score = op.MatMul(scaled_query, scaled_key)
attn_weight = op.Softmax(attn_score, axis=-1)
is_nan = op.IsNaN(attn_weight)
zero = op.Constant(value_float=0.0)
adj_attn_weight = op.Where(is_nan, zero, attn_weight)
attn_output = op.MatMul(adj_attn_weight, value)
return attn_output
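
A quick numpy check of the transpose fusion described in the comment: composing the BSHd-to-BHSd permutation [0, 2, 1, 3] with the BHSd-to-BHdS permutation [0, 1, 3, 2] gives the single permutation [0, 2, 3, 1] used above. This is an independent sketch, not part of the test file:

```python
import numpy as np

B, S, H, d = 2, 8, 4, 16
key_bshd = np.random.rand(B, S, H, d)

# Two-step path: BSHd -> BHSd, then BHSd -> BHdS before the MatMul.
two_step = key_bshd.transpose(0, 2, 1, 3).transpose(0, 1, 3, 2)
# Fused path: the single transpose used in the script above.
fused = key_bshd.transpose(0, 2, 3, 1)

assert np.array_equal(two_step, fused)
```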


class SDPATestCase:
def __init__(self, script_func, *, with_mask):
def __init__(self, script_func, *, with_mask, BSHd_key=False):
self.script_func = script_func
self.with_mask = with_mask
self.BSHd_key = BSHd_key

def get_onnx_model(self):
if not hasattr(self, "_onnx_model"):
qkv_type = FLOAT[B, N, S, H]
qv_type = FLOAT[B, N, S, H]
mask_type = FLOAT[B, N, S, S]
input_types = [qkv_type, qkv_type, qkv_type]
k_type = FLOAT[B, S, N, H] if self.BSHd_key else FLOAT[B, N, S, H]
input_types = [qv_type, k_type, qv_type]
if self.with_mask:
input_types.append(mask_type)
model_proto = self.script_func.to_model_proto(
input_types=input_types, output_types=[qkv_type]
input_types=input_types, output_types=[qv_type]
)
self._onnx_model = ir.serde.deserialize_model(model_proto)
return self._onnx_model
@@ -314,7 +335,9 @@ def get_ort_inputs(self):
if not hasattr(self, "_ort_inputs"):
inputs = {
"query": numpy.random.rand(B, N, S, H).astype(numpy.float32),
"key": numpy.random.rand(B, N, S, H).astype(numpy.float32),
"key": numpy.random.rand(B, S, N, H).astype(numpy.float32)
if self.BSHd_key
else numpy.random.rand(B, N, S, H).astype(numpy.float32),
"value": numpy.random.rand(B, N, S, H).astype(numpy.float32),
}
if self.with_mask:
@@ -374,10 +397,13 @@ class TestSDPAFusion(unittest.TestCase):
"_custom_multi_scale_pre_mul_sdpa_script",
_custom_multi_scale_pre_mul_sdpa_script,
),
("pre_div_sdpa_BSHd_key", _unmasked_pre_div_sdpa_BSHd_key_script),
]
)
def test_sdpa_fusion(self, name, script_func):
test_case = SDPATestCase(script_func, with_mask="masked" in name)
test_case = SDPATestCase(
script_func, with_mask="masked" in name, BSHd_key="BSHd_key" in name
)
model = test_case.get_onnx_model()
onnxscript.optimizer.optimize(model)

26 changes: 20 additions & 6 deletions onnxscript/rewriter/ort_fusions/sdpa_via_mha.py
@@ -7,43 +7,57 @@
import onnx_ir as ir

from onnxscript.rewriter import _fusion_utils, pattern
from onnxscript.rewriter._basics import MatchFailureError

Dim = Union[int, ir.SymbolicDim]


class SDPAImplementation(pattern.RewriteRuleClassBase):
def pattern(self, op, query, key, value):
def pattern(self, op, query, key, value, key_format):
"""Pattern matches any call to SDPA. See sdpa.py for documentation on the SDPA op."""
return op.SDPA(
query,
key,
value,
key_format="BHSd",
key_format=key_format,
_allow_other_inputs=True, # Mask is optional
_outputs=["sdpa_output"],
_domain="ai.onnxruntime._fusion",
)

def check(self, context, query, key, value, sdpa_output):
def check(self, context, query, key, value, key_format, sdpa_output):
bindings: dict[str, Dim] = {}
_fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"])
_fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"])
_fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"])

if key_format.value == "BHSd":
_fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"])
elif key_format.value == "BSHd":
_fusion_utils.check_shape(bindings, key, ["B", "Skv", "H", "Dh"])
else:
raise MatchFailureError(
f"Unexpected key_format value: {key_format.value}", key_format
)

self._num_heads = bindings["H"]
if not isinstance(self._num_heads, int):
return False
self._use_mask_broadcast = True # TODO: optimize to avoid broadcast if not needed
return isinstance(self._num_heads, int)
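
The check above uses _fusion_utils.check_shape to bind symbolic dimension names consistently across the inputs. The helper below is a hypothetical stand-in for that binding idea, shown only to illustrate how a "BSHd" key still resolves the same "H" binding; it is not the actual implementation:

```python
# Illustrative stand-in for the dim-binding presumed inside check_shape.
def bind_shape(bindings, shape, names):
    for dim, name in zip(shape, names):
        if bindings.setdefault(name, dim) != dim:
            raise ValueError(f"Inconsistent binding for {name}: {bindings[name]} vs {dim}")

bindings = {}
bind_shape(bindings, (2, 4, 8, 16), ["B", "H", "S", "Dh"])    # query in BHSd
bind_shape(bindings, (2, 8, 4, 16), ["B", "Skv", "H", "Dh"])  # key in BSHd
assert bindings["H"] == 4  # num_heads must be a static int for the MHA rewrite
```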

def rewrite(self, op, query, key, value, sdpa_output):
def rewrite(self, op, query, key, value, key_format, sdpa_output):
sdpa_node = sdpa_output.producer()
scale = sdpa_node.attributes.get("scale", None)
to_3d_shape = op.Constant(value_ints=[0, 0, -1])
to_4d_shape = op.Constant(value_ints=[0, 0, self._num_heads, -1])
query_3d = op.Reshape(op.Transpose(query, perm=[0, 2, 1, 3]), to_3d_shape)
key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape)
value_3d = op.Reshape(op.Transpose(value, perm=[0, 2, 1, 3]), to_3d_shape)

if key_format.value == "BHSd":
key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape)
else: # BSHd
key_3d = op.Reshape(key, to_3d_shape)

inputs = [query_3d, key_3d, value_3d]
if len(sdpa_node.inputs) > 3:
mask = sdpa_node.inputs[3]
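The rewrite flattens each rank-4 input into the rank-3 layout MHA expects. For a "BSHd" key the Transpose can be skipped, because transposing a "BHSd" key to BSHd and then flattening yields the same tensor as flattening the BSHd key directly. A small numpy sketch of that equivalence, under assumed shapes:

```python
import numpy as np

B, H, S_kv, d = 2, 4, 8, 16
key_bhsd = np.random.rand(B, H, S_kv, d).astype(np.float32)
key_bshd = key_bhsd.transpose(0, 2, 1, 3)  # the same data, laid out as BSHd

# BHSd path (as in the rewrite): transpose to BSHd, then flatten heads into the last axis.
via_transpose = key_bhsd.transpose(0, 2, 1, 3).reshape(B, S_kv, -1)
# BSHd path: the transpose is unnecessary; reshape directly.
direct = key_bshd.reshape(B, S_kv, -1)

assert np.array_equal(via_transpose, direct)
```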