Labels
ep:WebGPU, platform:web
Description
Describe the issue
When the attention_bias input is supplied to the MultiHeadAttention (com.microsoft) contrib node, the WebGPU Execution Provider produces incorrect results.
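For context, the expected semantics (matching the PyTorch reference used in the repro below) are simply to add the bias to the scaled attention scores before the softmax, broadcasting over the batch and head dimensions. A minimal NumPy sketch, with shapes assumed as in the repro and an illustrative function name:

import numpy as np

# Sketch of the intended attention_bias semantics (assumption based on the
# com.microsoft MultiHeadAttention contrib op): the bias is added to the
# scaled scores before softmax, broadcasting from [1, 1, L, L] over
# batch and heads.
def attention_with_bias(q, k, v, bias, head_dim):
    # q, k, v: [B, num_heads, L, head_dim]; bias: [1, 1, L, L]
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)  # [B, H, L, L]
    scores = scores + bias
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)                 # softmax over keys
    return weights @ v                                        # [B, H, L, head_dim]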
To reproduce
import onnxruntime as ort
import numpy as np
from onnx import helper, TensorProto
import torch
import math

np.random.seed(42)

sequence_length = 4
hidden_size = 4
num_heads = 2


def torch_mha(q, k, v, mask, num_heads):
    B, L, H = q.shape
    head_dim = H // num_heads
    q = torch.from_numpy(q).view(B, L, num_heads, head_dim).transpose(1, 2)
    k = torch.from_numpy(k).view(B, L, num_heads, head_dim).transpose(1, 2)
    v = torch.from_numpy(v).view(B, L, num_heads, head_dim).transpose(1, 2)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
    if mask is not None:
        mask = torch.from_numpy(mask)
        scores = scores + mask
    attn = torch.softmax(scores, dim=-1)
    output = torch.matmul(attn, v)
    output = output.transpose(1, 2).contiguous().view(B, L, H)
    return output.numpy()


def run_test(use_bias):
    print(f"{'='*10} Testing {'WITH' if use_bias else 'WITHOUT'} Attention Bias {'='*10}")
    input_specs = [
        ("query", [1, sequence_length, hidden_size]),
        ("key", [1, sequence_length, hidden_size]),
        ("value", [1, sequence_length, hidden_size]),
    ]
    if use_bias:
        input_specs.append(("attention_bias", [1, 1, sequence_length, sequence_length]))
    inputs_info = [helper.make_tensor_value_info(n, TensorProto.FLOAT, s) for n, s in input_specs]
    output_info = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, sequence_length, hidden_size])]
    node = helper.make_node(
        "MultiHeadAttention",
        inputs=[
            "query", "key", "value",
            "", "",  # bias, key_padding_mask
            "attention_bias" if use_bias else "",  # attention_bias
        ],
        outputs=["output"],
        num_heads=num_heads,
        domain="com.microsoft",
    )
    graph = helper.make_graph([node], "MHA_Test", inputs_info, output_info)
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)])
    model_bytes = model.SerializeToString()
    inputs = {name: np.random.randn(*shape).astype(np.float32) for name, shape in input_specs}
    providers = ["CPUExecutionProvider", "WebGpuExecutionProvider"]
    results = {}
    for provider in providers:
        session = ort.InferenceSession(model_bytes, providers=[provider])
        results[provider] = session.run(None, inputs)[0]
    mask = inputs["attention_bias"] if use_bias else None
    base_result = torch_mha(
        inputs["query"], inputs["key"], inputs["value"], mask, num_heads
    )
    for provider, result in results.items():
        diff = np.abs(base_result - result)
        max_diff = diff.max()
        print(f"PyTorch vs {provider}: Max Diff = {max_diff:.6f}")
        if max_diff > 1e-3:
            print("Differences:")
            print(diff)
    print()


run_test(use_bias=False)
run_test(use_bias=True)

which produces:
========== Testing WITHOUT Attention Bias ==========
PyTorch vs CPUExecutionProvider: Max Diff = 0.000000
PyTorch vs WebGpuExecutionProvider: Max Diff = 0.000000
========== Testing WITH Attention Bias ==========
PyTorch vs CPUExecutionProvider: Max Diff = 0.000000
PyTorch vs WebGpuExecutionProvider: Max Diff = 0.705346
Differences:
[[[0.0000000e+00 0.0000000e+00 1.8283391e-01 1.4414594e-01]
[0.0000000e+00 2.9802322e-08 9.6328795e-02 1.6142820e-01]
[0.0000000e+00 1.4901161e-08 3.5293663e-01 3.9857832e-01]
[0.0000000e+00 0.0000000e+00 5.4670848e-02 7.0534587e-01]]]
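Note that the nonzero differences sit entirely in the last two output columns, i.e. the second head (hidden_size=4, num_heads=2, so head_dim=2), while the first head matches exactly. One possible reading, offered as a hypothesis rather than a confirmed root cause, is that the WebGPU kernel mishandles broadcasting the [1, 1, sequence_length, sequence_length] bias across the num_heads dimension. A hypothetical cross-check (not run here) is to materialize the head dimension in run_test above:

# Hypothetical diagnostic, replacing the attention_bias entry in input_specs:
# declare the bias per head instead of relying on a size-1 num_heads
# dimension. If WebGPU then matches CPU and PyTorch (torch_mha broadcasts
# either shape), the bug is likely in the kernel's broadcast handling.
if use_bias:
    input_specs.append(
        ("attention_bias", [1, num_heads, sequence_length, sequence_length])
    )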
Urgency
high
Platform
Mac
OS Version
Sequoia 15.6
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
main
ONNX Runtime API
Python
Architecture
ARM64
Execution Provider
Other / Unknown
Execution Provider Library Version
No response