
[webgpu] MultiHeadAttention fails when attention_bias is used #26766

@xenova

Description

Describe the issue

When the attention_bias input is provided to the com.microsoft.MultiHeadAttention node, the WebGPU execution provider produces incorrect results, while the CPU execution provider matches a PyTorch reference implementation.
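
For context, the expected behaviour (which both the CPU execution provider and the PyTorch reference below reproduce) is that attention_bias is added to the scaled QK^T scores before the softmax. A minimal NumPy sketch of that expectation, with illustrative names rather than the operator's internal implementation:

import numpy as np

def expected_attention_probs(q, k, attention_bias, head_dim):
    # Scaled dot-product scores with the additive bias applied before softmax.
    scores = q @ k.swapaxes(-2, -1) / np.sqrt(head_dim) + attention_bias
    exp = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)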

To reproduce

import onnxruntime as ort
import numpy as np
from onnx import helper, TensorProto
import torch
import math

np.random.seed(42)

sequence_length = 4
hidden_size = 4
num_heads = 2

# PyTorch reference: attention_bias is added to the scaled QK^T scores before softmax.
def torch_mha(q, k, v, mask, num_heads):
    B, L, H = q.shape
    head_dim = H // num_heads

    q = torch.from_numpy(q).view(B, L, num_heads, head_dim).transpose(1, 2)
    k = torch.from_numpy(k).view(B, L, num_heads, head_dim).transpose(1, 2)
    v = torch.from_numpy(v).view(B, L, num_heads, head_dim).transpose(1, 2)

    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)

    if mask is not None:
        mask = torch.from_numpy(mask)
        scores = scores + mask

    attn = torch.softmax(scores, dim=-1)
    output = torch.matmul(attn, v)

    output = output.transpose(1, 2).contiguous().view(B, L, H)
    return output.numpy()

def run_test(use_bias):
    print(f"{'='*10} Testing {'WITH' if use_bias else 'WITHOUT'} Attention Bias {'='*10}")

    input_specs = [
        ("query", [1, sequence_length, hidden_size]),
        ("key", [1, sequence_length, hidden_size]),
        ("value", [1, sequence_length, hidden_size]),
    ]
    if use_bias:
        input_specs.append(("attention_bias", [1, 1, sequence_length, sequence_length]))

    inputs_info = [helper.make_tensor_value_info(n, TensorProto.FLOAT, s) for n, s in input_specs]
    output_info = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, sequence_length, hidden_size])]

    node = helper.make_node(
        "MultiHeadAttention",
        inputs=[
            "query", "key", "value",
            "", "",                                 # bias, key_padding_mask
            "attention_bias" if use_bias else "",   # attention_bias
        ],
        outputs=["output"],
        num_heads=num_heads,
        domain="com.microsoft",
    )

    graph = helper.make_graph([node], "MHA_Test", inputs_info, output_info)
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)])
    model_bytes = model.SerializeToString()

    inputs = {name: np.random.randn(*shape).astype(np.float32) for name, shape in input_specs}

    providers = ["CPUExecutionProvider", "WebGpuExecutionProvider"]
    results = {}
    for provider in providers:
        session = ort.InferenceSession(model_bytes, providers=[provider])
        results[provider] = session.run(None, inputs)[0]

    mask = inputs["attention_bias"] if use_bias else None
    base_result = torch_mha(
        inputs["query"], inputs["key"], inputs["value"], mask, num_heads
    )

    for provider, result in results.items():
        diff = np.abs(base_result - result)
        max_diff = diff.max()
        print(f"PyTorch vs {provider}: Max Diff = {max_diff:.6f}")
        if max_diff > 1e-3:
            print("Differences:")
            print(diff)
    print()

run_test(use_bias=False)
run_test(use_bias=True)

produces

========== Testing WITHOUT Attention Bias ==========
PyTorch vs CPUExecutionProvider: Max Diff = 0.000000
PyTorch vs WebGpuExecutionProvider: Max Diff = 0.000000

========== Testing WITH Attention Bias ==========
PyTorch vs CPUExecutionProvider: Max Diff = 0.000000
PyTorch vs WebGpuExecutionProvider: Max Diff = 0.705346
Differences:
[[[0.0000000e+00 0.0000000e+00 1.8283391e-01 1.4414594e-01]
  [0.0000000e+00 2.9802322e-08 9.6328795e-02 1.6142820e-01]
  [0.0000000e+00 1.4901161e-08 3.5293663e-01 3.9857832e-01]
  [0.0000000e+00 0.0000000e+00 5.4670848e-02 7.0534587e-01]]]

Urgency

high

Platform

Mac

OS Version

Sequoia 15.6

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

main

ONNX Runtime API

Python

Architecture

ARM64

Execution Provider

Other / Unknown

Execution Provider Library Version

No response

Labels

ep:WebGPU, platform:web
