[webgpu] Optimize Attention by enhancing flash attention support #26715
Conversation
Cool! I'm busy trying to fix a bug with GQA #25966 and this will help with correctness checks!
Flash Attention seems to introduce significant errors in the results of my test suite. Here are the results and reproduction:
<details>
<summary>Expand to see reproduction script</summary>

```python
import numpy as np
from onnx import helper, TensorProto
import onnxruntime as ort
import dataclasses

# ==========================================
# 0. Test Harness Configuration
# ==========================================
@dataclasses.dataclass
class TestConfig:
    name: str
    batch_size: int
    seq_len: int       # Number of tokens to process NOW
    past_seq_len: int  # Number of tokens already in cache
    max_seq_len: int = 128

    @property
    def total_seq_len(self):
        return self.past_seq_len + self.seq_len


def create_session(model_def, provider):
    sess_options = ort.SessionOptions()
    sess_options.log_severity_level = 3
    try:
        return ort.InferenceSession(
            model_def.SerializeToString(),
            sess_options=sess_options,
            providers=[provider]
        )
    except Exception as e:
        print(f"⚠️ Failed to create session for {provider}: {e}")
        return None


# ==========================================
# 1. The Core Comparison Function
# ==========================================
def run_test_case(cfg: TestConfig, provider: str):
    print(f"[{provider.replace('ExecutionProvider', '')}] {cfg.name: <20} | In:{cfg.seq_len} Past:{cfg.past_seq_len} Total:{cfg.total_seq_len}", end="")

    # Constants
    NUM_HEADS = 4
    HEAD_SIZE = 32
    HIDDEN_SIZE = NUM_HEADS * HEAD_SIZE

    # Shapes
    query_shape = [cfg.batch_size, cfg.seq_len, HIDDEN_SIZE]
    kv_input_shape = [cfg.batch_size, cfg.seq_len, HIDDEN_SIZE]
    past_shape_gqa = [cfg.batch_size, NUM_HEADS, cfg.max_seq_len, HEAD_SIZE]
    past_shape_mha = [cfg.batch_size, NUM_HEADS, cfg.past_seq_len, HEAD_SIZE]

    # ----------------------------------------
    # A. Build GQA Model (Full Buffer Mode)
    # ----------------------------------------
    gqa_node = helper.make_node(
        'GroupQueryAttention',
        inputs=['query', 'key', 'value', 'past_key', 'past_value', 'seqlens_k', 'total_seq_len'],
        outputs=['output', 'present_key', 'present_value'],
        domain='com.microsoft',
        name='GQA_Node',
        do_rotary=0,
        kv_num_heads=NUM_HEADS,
        num_heads=NUM_HEADS,
        scale=1.0/np.sqrt(HEAD_SIZE),
    )
    gqa_inputs_info = [
        helper.make_tensor_value_info('query', TensorProto.FLOAT, query_shape),
        helper.make_tensor_value_info('key', TensorProto.FLOAT, kv_input_shape),
        helper.make_tensor_value_info('value', TensorProto.FLOAT, kv_input_shape),
        helper.make_tensor_value_info('past_key', TensorProto.FLOAT, past_shape_gqa),
        helper.make_tensor_value_info('past_value', TensorProto.FLOAT, past_shape_gqa),
        helper.make_tensor_value_info('seqlens_k', TensorProto.INT32, [cfg.batch_size]),
        helper.make_tensor_value_info('total_seq_len', TensorProto.INT32, []),
    ]
    gqa_graph = helper.make_graph([gqa_node], 'gqa-test', gqa_inputs_info,
                                  [helper.make_tensor_value_info('output', TensorProto.FLOAT, query_shape)])
    gqa_model = helper.make_model(gqa_graph, opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid("com.microsoft", 1)])

    # ----------------------------------------
    # B. Build MHA Model (Sliced Mode)
    # ----------------------------------------
    mha_node = helper.make_node(
        'MultiHeadAttention',
        inputs=['query', 'key', 'value', '', '', '', 'past_key', 'past_value', 'past_seq_len'],
        outputs=['output', 'present_key', 'present_value'],
        domain='com.microsoft',
        name='MHA_Node',
        num_heads=NUM_HEADS,
        unidirectional=1,
        scale=1.0/np.sqrt(HEAD_SIZE),
    )
    mha_inputs_info = [
        helper.make_tensor_value_info('query', TensorProto.FLOAT, query_shape),
        helper.make_tensor_value_info('key', TensorProto.FLOAT, kv_input_shape),
        helper.make_tensor_value_info('value', TensorProto.FLOAT, kv_input_shape),
        helper.make_tensor_value_info('past_key', TensorProto.FLOAT, past_shape_mha),
        helper.make_tensor_value_info('past_value', TensorProto.FLOAT, past_shape_mha),
        helper.make_tensor_value_info('past_seq_len', TensorProto.INT32, [1]),
    ]
    mha_graph = helper.make_graph([mha_node], 'mha-test', mha_inputs_info,
                                  [helper.make_tensor_value_info('output', TensorProto.FLOAT, query_shape)])
    mha_model = helper.make_model(mha_graph, opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid("com.microsoft", 1)])

    # ----------------------------------------
    # C. Data Generation
    # ----------------------------------------
    np.random.seed(42 + cfg.seq_len + cfg.past_seq_len)
    query = np.random.rand(*query_shape).astype(np.float32)
    key = np.random.rand(*kv_input_shape).astype(np.float32)
    value = np.random.rand(*kv_input_shape).astype(np.float32)
    past_key_full = np.random.rand(*past_shape_gqa).astype(np.float32)
    past_value_full = np.random.rand(*past_shape_gqa).astype(np.float32)
    past_key_sliced = past_key_full[:, :, :cfg.past_seq_len, :]
    past_value_sliced = past_value_full[:, :, :cfg.past_seq_len, :]
    total_seq_len_scalar = np.array(cfg.total_seq_len, dtype=np.int32)
    past_seq_len_scalar = np.array([cfg.past_seq_len], dtype=np.int32)
    seqlens_k_arr = np.array([cfg.total_seq_len - 1] * cfg.batch_size, dtype=np.int32)

    # ----------------------------------------
    # D. Execution
    # ----------------------------------------
    # Create Sessions
    sess_gqa = create_session(gqa_model, provider)
    sess_mha = create_session(mha_model, provider)
    if sess_gqa is None or sess_mha is None:
        print(" ... SKIP (Session Init Failed)")
        return False

    # Run GQA
    res_gqa = sess_gqa.run(['output'], {
        'query': query, 'key': key, 'value': value,
        'past_key': past_key_full, 'past_value': past_value_full,
        'seqlens_k': seqlens_k_arr,
        'total_seq_len': total_seq_len_scalar
    })

    # Run MHA
    res_mha = sess_mha.run(['output'], {
        'query': query, 'key': key, 'value': value,
        'past_key': past_key_sliced, 'past_value': past_value_sliced,
        'past_seq_len': past_seq_len_scalar
    })

    # ----------------------------------------
    # E. Validation
    # ----------------------------------------
    out_gqa = res_gqa[0]
    out_mha = res_mha[0]
    diff = np.abs(out_gqa - out_mha)
    max_diff = diff.max()
    base_tol = 1e-5
    if max_diff < base_tol:
        print(f" -> ✅ PASS (Diff: {max_diff:.2e})")
        return True
    else:
        print(f" -> ❌ FAIL (Diff: {max_diff:.5f})")
        return False


# ==========================================
# 2. Main Execution Loop
# ==========================================
test_scenarios = [
    TestConfig(name="Prefill_ColdStart", batch_size=1, seq_len=16, past_seq_len=0),
    TestConfig(name="Decode_Early", batch_size=1, seq_len=1, past_seq_len=16),
    TestConfig(name="Decode_Deep", batch_size=1, seq_len=1, past_seq_len=64),
    TestConfig(name="Speculative_Dec", batch_size=1, seq_len=4, past_seq_len=20),
    TestConfig(name="Batch_Prefill", batch_size=4, seq_len=16, past_seq_len=0),
    TestConfig(name="Batch_Decode", batch_size=4, seq_len=1, past_seq_len=32),
]

# Detect Providers
available = ort.get_available_providers()
target_providers = ['CPUExecutionProvider']

# Check for WebGPU
if 'WebGpuExecutionProvider' in available:
    target_providers.append('WebGpuExecutionProvider')
else:
    print("⚠️ WebGpuExecutionProvider not found in this environment. Skipping GPU tests.")

print(f"Testing on Providers: {target_providers}")
print("=================================================================================")

for provider in target_providers:
    print(f"\n--- Testing Provider: {provider} ---")
    all_passed = True
    for config in test_scenarios:
        if not run_test_case(config, provider):
            all_passed = False
    if all_passed:
        print(f"🎉 {provider}: ALL SCENARIOS PASSED.")
    else:
        print(f"⚠️ {provider}: FAILURES DETECTED.")
```

</details>
I can reproduce them. Will investigate the reason and fix them in this PR. Thanks for reporting.
Thanks so much! :) I kept doing a deep-dive, and if you set you can get a very small minimal reproduction. What happens is that only half of the values in the last dimension are correctly calculated. As you can see, the first half is correct, but the second half is all zeros.
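For reference, a minimal sketch of how one might check this pattern, assuming the GQA/MHA outputs collected by the script above (the helper name and the split along the flattened last dimension are my own, not part of the original report):

```python
import numpy as np

def compare_halves(out_gqa: np.ndarray, out_mha: np.ndarray) -> None:
    """Hypothetical helper: compare the two halves of the last dimension separately."""
    half = out_gqa.shape[-1] // 2
    print("first-half max diff: ", np.abs(out_gqa[..., :half] - out_mha[..., :half]).max())
    print("second-half max diff:", np.abs(out_gqa[..., half:] - out_mha[..., half:]).max())
    print("second half of GQA output is all zeros:",
          np.allclose(out_gqa[..., half:], 0.0))

# e.g. called with the outputs collected inside run_test_case:
# compare_halves(res_gqa[0], res_mha[0])
```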
Amazing! Thanks so much @qjia7 |
xenova left a comment:
Tested on a bunch of other cases too; all tests pass!
I kept digging and designing test cases which pass on CPU, but fail on WebGPU.

<details>
<summary>Full reproduction</summary>

```python
from typing import Optional
import numpy as np
from onnx import helper, TensorProto
import onnxruntime as ort
import dataclasses

# ==========================================
# 0. Test Harness Configuration
# ==========================================
@dataclasses.dataclass
class TestConfig:
    name: str
    batch_size: int
    seq_len: int       # Number of tokens to process NOW
    past_seq_len: int  # Number of tokens already in cache
    max_seq_len: int = 128
    num_heads: int = 2
    kv_num_heads: int = 2
    head_size: int = 8
    # New parameters for extended testing
    do_rotary: int = 0
    rotary_interleaved: int = 0
    local_window_size: int = -1
    softcap: float = 0.0
    use_rotary_cache: bool = False
    custom_scale: Optional[float] = None

    @property
    def total_seq_len(self):
        return self.past_seq_len + self.seq_len

    @property
    def is_gqa_specific(self):
        return (self.do_rotary > 0 or
                self.local_window_size != -1 or
                self.softcap > 0.0 or
                self.use_rotary_cache)


def create_session(model_def, provider):
    sess_options = ort.SessionOptions()
    # sess_options.log_severity_level = 0
    try:
        return ort.InferenceSession(
            model_def.SerializeToString(),
            sess_options=sess_options,
            providers=[provider]
        )
    except Exception as e:
        print(f"⚠️ Failed to create session for {provider}: {e}")
        return None


# ==========================================
# 1. The Core Comparison Function
# ==========================================
def run_test_case(cfg: TestConfig, providers: list[str]):
    print(f"{cfg.name: <20} | In:{cfg.seq_len} Past:{cfg.past_seq_len} Total:{cfg.total_seq_len} H:{cfg.num_heads} KV:{cfg.kv_num_heads}", end="")

    # Constants
    NUM_HEADS = cfg.num_heads
    KV_NUM_HEADS = cfg.kv_num_heads
    HEAD_SIZE = cfg.head_size
    HIDDEN_SIZE = NUM_HEADS * HEAD_SIZE
    KV_HIDDEN_SIZE = KV_NUM_HEADS * HEAD_SIZE
    SCALE = cfg.custom_scale if cfg.custom_scale is not None else 1.0/np.sqrt(HEAD_SIZE)

    # Shapes
    query_shape = [cfg.batch_size, cfg.seq_len, HIDDEN_SIZE]
    kv_input_shape = [cfg.batch_size, cfg.seq_len, KV_HIDDEN_SIZE]
    kv_input_shape_mha = [cfg.batch_size, cfg.seq_len, HIDDEN_SIZE]
    past_shape_gqa = [cfg.batch_size, KV_NUM_HEADS, cfg.past_seq_len, HEAD_SIZE]
    past_shape_mha = [cfg.batch_size, NUM_HEADS, cfg.past_seq_len, HEAD_SIZE]

    # Rotary Cache Shape
    cache_shape = [cfg.max_seq_len, HEAD_SIZE // 2] if cfg.use_rotary_cache else []

    # ----------------------------------------
    # A. Build GQA Model (Full Buffer Mode)
    # ----------------------------------------
    gqa_inputs = ['query', 'key', 'value', 'past_key', 'past_value', 'seqlens_k', 'total_seq_len']
    if cfg.use_rotary_cache:
        gqa_inputs.extend(['cos_cache', 'sin_cache'])
    gqa_node = helper.make_node(
        'GroupQueryAttention',
        inputs=gqa_inputs,
        outputs=['output', 'present_key', 'present_value'],
        domain='com.microsoft',
        name='GQA_Node',
        do_rotary=cfg.do_rotary,
        kv_num_heads=KV_NUM_HEADS,
        num_heads=NUM_HEADS,
        scale=SCALE,
        rotary_interleaved=cfg.rotary_interleaved,
        local_window_size=cfg.local_window_size,
        softcap=cfg.softcap,
    )
    gqa_inputs_info = [
        helper.make_tensor_value_info('query', TensorProto.FLOAT, query_shape),
        helper.make_tensor_value_info('key', TensorProto.FLOAT, kv_input_shape),
        helper.make_tensor_value_info('value', TensorProto.FLOAT, kv_input_shape),
        helper.make_tensor_value_info('past_key', TensorProto.FLOAT, past_shape_gqa),
        helper.make_tensor_value_info('past_value', TensorProto.FLOAT, past_shape_gqa),
        helper.make_tensor_value_info('seqlens_k', TensorProto.INT32, [cfg.batch_size]),
        helper.make_tensor_value_info('total_seq_len', TensorProto.INT32, []),
    ]
    if cfg.use_rotary_cache:
        gqa_inputs_info.extend([
            helper.make_tensor_value_info('cos_cache', TensorProto.FLOAT, cache_shape),
            helper.make_tensor_value_info('sin_cache', TensorProto.FLOAT, cache_shape),
        ])
    gqa_graph = helper.make_graph([gqa_node], 'gqa-test', gqa_inputs_info,
                                  [helper.make_tensor_value_info('output', TensorProto.FLOAT, query_shape)])
    gqa_model = helper.make_model(gqa_graph, opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid("com.microsoft", 1)])

    # ----------------------------------------
    # B. Build MHA Model (Sliced Mode)
    # ----------------------------------------
    mha_node = helper.make_node(
        'MultiHeadAttention',
        inputs=['query', 'key', 'value', '', '', '', 'past_key', 'past_value', 'past_seq_len'],
        outputs=['output', 'present_key', 'present_value'],
        domain='com.microsoft',
        name='MHA_Node',
        num_heads=NUM_HEADS,
        unidirectional=1,
        scale=SCALE,
    )
    mha_inputs_info = [
        helper.make_tensor_value_info('query', TensorProto.FLOAT, query_shape),
        helper.make_tensor_value_info('key', TensorProto.FLOAT, kv_input_shape_mha),
        helper.make_tensor_value_info('value', TensorProto.FLOAT, kv_input_shape_mha),
        helper.make_tensor_value_info('past_key', TensorProto.FLOAT, past_shape_mha),
        helper.make_tensor_value_info('past_value', TensorProto.FLOAT, past_shape_mha),
        helper.make_tensor_value_info('past_seq_len', TensorProto.INT32, [1]),
    ]
    mha_graph = helper.make_graph([mha_node], 'mha-test', mha_inputs_info,
                                  [helper.make_tensor_value_info('output', TensorProto.FLOAT, query_shape)])
    mha_model = helper.make_model(mha_graph, opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid("com.microsoft", 1)])

    # ----------------------------------------
    # C. Data Generation
    # ----------------------------------------
    np.random.seed(42 + cfg.seq_len + cfg.past_seq_len)
    query = np.random.rand(*query_shape).astype(np.float32)
    key = np.random.rand(*kv_input_shape).astype(np.float32)
    value = np.random.rand(*kv_input_shape).astype(np.float32)
    past_key_full = np.random.rand(*past_shape_gqa).astype(np.float32)
    past_value_full = np.random.rand(*past_shape_gqa).astype(np.float32)
    cos_cache = np.random.rand(*cache_shape).astype(np.float32) if cfg.use_rotary_cache else None
    sin_cache = np.random.rand(*cache_shape).astype(np.float32) if cfg.use_rotary_cache else None

    # Prepare MHA inputs (Repeat KV if necessary)
    n_rep = NUM_HEADS // KV_NUM_HEADS

    def repeat_kv(x, n_rep):
        if n_rep == 1:
            return x
        if x.ndim == 3:  # [B, S, H_kv] -> [B, S, H_q]
            b, s, h_kv = x.shape
            head_size = h_kv // KV_NUM_HEADS
            x = x.reshape(b, s, KV_NUM_HEADS, head_size)
            x = np.repeat(x, n_rep, axis=2)
            return x.reshape(b, s, NUM_HEADS * head_size)
        elif x.ndim == 4:  # [B, H_kv, S, D] -> [B, H_q, S, D]
            return np.repeat(x, n_rep, axis=1)
        return x

    key_mha = repeat_kv(key, n_rep)
    value_mha = repeat_kv(value, n_rep)
    past_key_sliced = past_key_full[:, :, :cfg.past_seq_len, :]
    past_value_sliced = past_value_full[:, :, :cfg.past_seq_len, :]
    past_key_mha = repeat_kv(past_key_sliced, n_rep)
    past_value_mha = repeat_kv(past_value_sliced, n_rep)
    total_seq_len_scalar = np.array(cfg.total_seq_len, dtype=np.int32)
    past_seq_len_scalar = np.array([cfg.past_seq_len], dtype=np.int32)
    seqlens_k_arr = np.array([cfg.total_seq_len - 1] * cfg.batch_size, dtype=np.int32)

    # ----------------------------------------
    # D. Execution
    # ----------------------------------------
    results = {}
    for provider in providers:
        # Create Sessions
        sess_gqa = create_session(gqa_model, provider)
        sess_mha = create_session(mha_model, provider)
        if sess_gqa is None or sess_mha is None:
            print(f" ... SKIP ({provider} Init Failed)")
            continue

        # Run GQA
        try:
            feed_gqa = {
                'query': query, 'key': key, 'value': value,
                'past_key': past_key_full, 'past_value': past_value_full,
                'seqlens_k': seqlens_k_arr,
                'total_seq_len': total_seq_len_scalar
            }
            if cfg.use_rotary_cache:
                feed_gqa['cos_cache'] = cos_cache
                feed_gqa['sin_cache'] = sin_cache
            res_gqa = sess_gqa.run(['output'], feed_gqa)
            results[f"{provider}_GQA"] = res_gqa[0]
        except Exception as e:
            print(f" ... ERR ({provider} GQA: {e})")

        # Run MHA
        if not cfg.is_gqa_specific:
            try:
                res_mha = sess_mha.run(['output'], {
                    'query': query, 'key': key_mha, 'value': value_mha,
                    'past_key': past_key_mha, 'past_value': past_value_mha,
                    'past_seq_len': past_seq_len_scalar
                })
                results[f"{provider}_MHA"] = res_mha[0]
            except Exception as e:
                print(f" ... ERR ({provider} MHA: {e})")

    # ----------------------------------------
    # E. Validation
    # ----------------------------------------
    if not results:
        print(" -> ⚠️ NO RESULTS")
        return False

    # Determine baseline
    if cfg.is_gqa_specific:
        baseline_key = "CPUExecutionProvider_GQA"
    else:
        baseline_key = "CPUExecutionProvider_MHA"
    if baseline_key not in results:
        # Fallback to the first available key if preferred baseline is missing
        if results:
            baseline_key = list(results.keys())[0]
    baseline = results[baseline_key]

    passed = True
    max_diff_global = 0.0
    failures = []
    for key, val in results.items():
        if key == baseline_key:
            continue
        diff = np.abs(baseline - val)
        max_diff = diff.max()
        max_diff_global = max(max_diff_global, max_diff)
        base_tol = 1e-4  # Slightly relaxed for cross-device
        if max_diff > base_tol:
            passed = False
            failures.append(f"{key} (Diff: {max_diff:.2e})")

    if passed:
        print(f" -> ✅ PASS (Max Diff vs {baseline_key}: {max_diff_global:.2e})")
        return True
    else:
        print(f" -> ❌ FAIL vs {baseline_key}: {', '.join(failures)}")
        return False


# ==========================================
# 2. Main Execution Loop
# ==========================================
test_scenarios = [
    TestConfig(name="Rotary Interleaved", batch_size=1, seq_len=4, past_seq_len=0, max_seq_len=128, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, rotary_interleaved=1, local_window_size=-1, softcap=0.0, use_rotary_cache=True, custom_scale=0.25),
    TestConfig(name="Rotary_Window", batch_size=1, seq_len=16, past_seq_len=0, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, local_window_size=4),
    TestConfig(name="All_Features", batch_size=1, seq_len=8, past_seq_len=4, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, local_window_size=4, softcap=50.0, custom_scale=1.0),
    TestConfig(name="Rotary_Interleaved", batch_size=1, seq_len=4, past_seq_len=0, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, rotary_interleaved=1),
    TestConfig(name="Rotary_Half", batch_size=1, seq_len=4, past_seq_len=0, num_heads=4, kv_num_heads=2, head_size=16, do_rotary=1, rotary_interleaved=0),
]

# Detect Providers
available = ort.get_available_providers()
target_providers = ['CPUExecutionProvider']

# Check for WebGPU
if 'WebGpuExecutionProvider' in available:
    target_providers.append('WebGpuExecutionProvider')
else:
    print("⚠️ WebGpuExecutionProvider not found in this environment. Skipping GPU tests.")

print(f"Testing on Providers: {target_providers}")
print("=================================================================================")

all_passed = True
for config in test_scenarios:
    if not run_test_case(config, target_providers):
        all_passed = False

if all_passed:
    print("\n🎉 ALL SCENARIOS PASSED ACROSS ALL PROVIDERS.")
else:
    print("\n⚠️ FAILURES DETECTED.")
```

</details>

The first produces incorrect results (always fails when
Reproduced locally. It seems that q_rotary and k_rotary are not calculated correctly. Will fix them in a separate PR. Are the last 4 tests not using the rotary cache? At least for WebGPU, when do_rotary is true, cos_cache and sin_cache are required.
Thanks!
I believe so, yes. However, on CPU it produces an output, while on WebGPU it runs into a segmentation fault (we should either throw an error or produce the same results as CPU, imo).
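For anyone adapting the reproduction above to satisfy that requirement, a sketch of how the caches could be supplied (the helper name is my own; shapes assumed to be `[max_seq_len, head_size // 2]`, matching the `cache_shape` used in the script, with standard RoPE cos/sin tables instead of random data):

```python
import numpy as np

# Assumed standard RoPE tables of shape [max_seq_len, head_size // 2].
def make_rotary_caches(max_seq_len: int, head_size: int, base: float = 10000.0):
    inv_freq = 1.0 / (base ** (np.arange(0, head_size, 2) / head_size))  # [head_size / 2]
    positions = np.arange(max_seq_len)[:, None]                          # [max_seq_len, 1]
    angles = positions * inv_freq[None, :]                               # [max_seq_len, head_size / 2]
    return np.cos(angles).astype(np.float32), np.sin(angles).astype(np.float32)

cos_cache, sin_cache = make_rotary_caches(max_seq_len=128, head_size=16)
# feed_gqa['cos_cache'], feed_gqa['sin_cache'] = cos_cache, sin_cache
```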
This pull request improves the WebGPU BERT attention implementation by enhancing FlashAttention support, generalizing tensor layout handling, and increasing batch size flexibility. The changes focus on supporting both BSNH and BNSH tensor layouts, enabling FlashAttention for multi-batch scenarios, and ensuring correct broadcasting and dispatch sizing for attention bias and batch dimensions.
Key improvements include:
**FlashAttention Support & Generalization:**

- Added support for both BSNH and BNSH tensor layouts by introducing the `q_BNSH` parameter and updating shader code, program classes, and kernel logic to handle either layout correctly. This includes changes in the WGSL template and C++ logic for offset calculations and program instantiation.
- Updated the `CanApplyFlashAttention` and `ApplyFlashAttention` logic to allow multi-batch operation by removing the restriction to batch size 1 and ensuring present key/value tensors are always created for FlashAttention.

**Batch & Bias Handling:**

- Modified dispatch group size calculations and uniform variables throughout the FlashAttention pipeline to properly account for batch size, ensuring correct parallelization for multi-batch scenarios.
- Added logic to extract and pass attention bias dimensions as uniforms for correct broadcasting in both the compute and shader code.

**Other Enhancements:**

- Improved handling of QKV format detection and generalized code to support more format variants in `CopyKVCache`.
- Updated includes and dependencies to ensure all necessary headers for FlashAttention are present.
These changes collectively make the WebGPU BERT attention implementation more robust, flexible, and performant across different tensor layouts and batch sizes.
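As a rough illustration of what the layout generalization and multi-batch dispatch amount to (a sketch in Python-style index math, not the actual WGSL/C++ from this PR; the function and variable names are placeholders):

```python
# Sketch only: how a flat element offset into Q depends on the layout flag.
B, S, N, H = 2, 16, 4, 32  # example values: batch, seq_len, num_heads, head_size

def q_offset(b: int, s: int, n: int, h: int, q_BNSH: bool) -> int:
    if q_BNSH:  # Q stored as [B, N, S, H] (BNSH)
        return ((b * N + n) * S + s) * H + h
    else:       # Q stored as [B, S, N, H] (BSNH)
        return ((b * S + s) * N + n) * H + h

# Multi-batch dispatch sketch: parallelize over (batch, head) pairs rather than
# assuming batch_size == 1, so every batch gets its own set of workgroups.
num_head_batch_groups = B * N
```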
Benchmark: phi-4-mm-vision.onnx, before vs. after.