-
Notifications
You must be signed in to change notification settings - Fork 64
Open
Labels
Description
Steps to reproduce (on an NVIDIA B200 GPU):
CUDA_LAUNCH_BLOCKING=1 HELION_AUTOTUNE_RANDOM_SEED=3530777511 python benchmarks/run.py --op grouped_gemm --metrics speedup,accuracy --latency-measure-mode profiler --only helion --only-match-mode prefix-with-baseline --baseline aten_grouped_mm
Error:
Caught exception on backend helion_grouped_gemm_jagged_persistent_tritonbench, terminating early with partial results
Traceback (most recent call last):
File "/home/willfeng/local/helion/helion/autotuner/base_search.py", line 213, in benchmark_function
output = fn(*self.args) # make sure the kernel is compiled
^^^^^^^^^^^^^^
File "/tmp/torchinductor_willfeng/gr/cgrbt33rpwzlcggscxbbfkh62acuatbdqcbdm3cqpxgjur3flxc7.py", line 147, in grouped_gemm_jagged_persistent
_launcher(_helion_grouped_gemm_jagged_persistent, (_NUM_SM,), group_offsets, A_packed, B, out, num_workers, _NUM_SM, 64, 64, _BLOCK_SIZE_5, num_warps=8, num_stages=3)
File "/home/willfeng/local/helion/helion/runtime/__init__.py", line 66, in default_launcher
return triton_kernel.run(
^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/triton/runtime/jit.py", line 757, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
File "/home/willfeng/local/pytorch-nightly/triton/backends/nvidia/driver.py", line 712, in __call__
self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
RuntimeError: Triton Error [CUDA]: misaligned address
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/willfeng/local/helion/benchmarks/tritonbench/tritonbench/utils/triton_op.py", line 1115, in run
y_vals: Dict[str, BenchmarkOperatorMetrics] = functools.reduce(
^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion/benchmarks/tritonbench/tritonbench/utils/triton_op.py", line 1098, in _reduce_benchmarks
acc[bm_name] = self._do_bench(
^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion/benchmarks/tritonbench/tritonbench/utils/triton_op.py", line 1582, in _do_bench
metrics.latency = do_bench_wrapper(
^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion/benchmarks/tritonbench/tritonbench/components/do_bench/run.py", line 492, in do_bench_wrapper
raise e
File "/home/willfeng/local/helion/benchmarks/tritonbench/tritonbench/components/do_bench/run.py", line 482, in do_bench_wrapper
times=bench_fn(
^^^^^^^^^
File "/home/willfeng/local/helion/benchmarks/tritonbench/tritonbench/components/do_bench/run.py", line 268, in _do_bench_profiler
estimate_ms = benchmarker.benchmark_gpu(fn, estimation_iters=5, benchmark_iters=10)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/torch/_inductor/runtime/benchmarking.py", line 39, in wrapper
return fn(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/pytorch-nightly/torch/_inductor/runtime/benchmarking.py", line 250, in benchmark_gpu
_callable()
File "/home/willfeng/local/helion/examples/grouped_gemm.py", line 289, in inner
return grouped_gemm_jagged_persistent(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion/helion/runtime/kernel.py", line 292, in __call__
return self.bind(args)(*args)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion/helion/runtime/kernel.py", line 626, in __call__
self.autotune(args)
File "/home/willfeng/local/helion/helion/runtime/kernel.py", line 511, in autotune
config = self.settings.autotuner_fn(self, args, **kwargs).autotune()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion/helion/autotuner/base_cache.py", line 178, in autotune
config = self.autotuner.autotune()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion/helion/autotuner/base_search.py", line 363, in autotune
best = self._autotune()
^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion/helion/autotuner/pattern_search.py", line 63, in _autotune
self.parallel_benchmark_population(self.population, desc="Initial population")
File "/home/willfeng/local/helion/helion/autotuner/base_search.py", line 517, in parallel_benchmark_population
self.parallel_benchmark([m.config for m in members], desc=desc),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion/helion/autotuner/base_search.py", line 345, in parallel_benchmark
results.append((config, fn, self.benchmark_function(config, fn)))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/helion/helion/autotuner/base_search.py", line 239, in benchmark_function
raise exc.TritonError(
helion.exc.TritonError: Error running generated Triton program:
@helion.kernel(config=helion.Config(block_sizes=[64, 64, 16], indexing='tensor_descriptor', load_eviction_policies=['', 'first', 'last', ''], num_stages=3, num_warps=8,
pid_type='persistent_blocked', range_flattens=[None, True, True, True], range_multi_buffers=[True, None, None, True], range_num_stages=[4, 1, 1, 0], range_unroll_factors=[0, 3, 0, 4],
static_ranges=[False]), static_shapes=True)
RuntimeError: Triton Error [CUDA]: misaligned address
Generated Triton code:
from __future__ import annotations
import torch
import helion
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher
# Helion-generated Triton kernel for the persistent, jagged grouped GEMM.
# NOTE(review): the indentation of this dump was lost when it was pasted, so
# the nesting described in the comments below is inferred from data flow
# (which names each statement reads) -- confirm against the generated file.
# NOTE(review): every load/store below uses plain pointer arithmetic
# (tl.load / tl.store on base + offsets); no tensor-descriptor API call is
# visible in this kernel.
@triton.jit
def _helion_grouped_gemm_jagged_persistent(group_offsets, A_packed, B, out, num_workers, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_5:
tl.constexpr):
# Split `num_workers` virtual program ids into contiguous chunks of
# ceil(num_workers / _NUM_SM); this program instance handles
# [start_pid, end_pid), clamped to the total.
total_pids = num_workers
block_size = tl.cdiv(total_pids, _NUM_SM)
start_pid = tl.program_id(0) * block_size
end_pid = tl.minimum(start_pid + block_size, total_pids)
for virtual_pid in tl.range(start_pid, end_pid, disallow_acc_multi_buffer=False):
pid_0 = virtual_pid
offset_2 = pid_0
# Iterate over group indices 0..3 (the trip count 4 is a literal here).
for offset_3 in tl.range(0, 4, loop_unroll_factor=3, flatten=True):
# Scalar loads of the offsets at index offset_3 and offset_3 + 1.
group_start = tl.load(group_offsets + offset_3 * 1, None)
add = 1 + offset_3
group_end = tl.load(group_offsets + add * 1, None, eviction_policy='evict_first')
# v_0 is the extent of this group; the guard below skips empty groups.
v_0 = group_end - group_start
v_1 = tl.full([], 0, tl.int32)
v_2 = v_0 > v_1
if v_2:
v_0_copy = v_0
group_start_copy = group_start
group_end_copy = group_end
v_0_copy_0 = v_0_copy
group_start_copy_0 = group_start_copy
group_end_copy_0 = group_end_copy
# v_8 = floor((v_0 + _BLOCK_SIZE_0 - 1) / _BLOCK_SIZE_0): the number of
# _BLOCK_SIZE_0-sized tiles needed to cover the group's extent. The nested
# tl.where spells out floor division for operands of differing sign.
_BLOCK_SIZE_0_ = _BLOCK_SIZE_0
v_3 = tl.cast(v_0_copy_0, tl.int64)
v_4 = v_3 + _BLOCK_SIZE_0_
v_5 = tl.full([], 1, tl.int32)
v_6 = v_4 - v_5
_BLOCK_SIZE_0__1 = _BLOCK_SIZE_0
v_7 = tl.cast(v_6, tl.int64)
v_8 = tl.where((v_7 < 0) != (_BLOCK_SIZE_0__1 < 0), tl.where(v_7 % _BLOCK_SIZE_0__1 != 0, v_7 // _BLOCK_SIZE_0__1 - 1, v_7 // _BLOCK_SIZE_0__1), v_7 // _BLOCK_SIZE_0__1)
# floordiv = floor((127 + _BLOCK_SIZE_1) / _BLOCK_SIZE_1), i.e. the number of
# _BLOCK_SIZE_1-sized tiles covering a length of 128. `add_1` and `sub_1` are
# not read anywhere else in this kernel.
add_1 = 128 + _BLOCK_SIZE_1
sub_1 = 127 + _BLOCK_SIZE_1
floordiv = triton_helpers.div_floor_integer(127 + _BLOCK_SIZE_1, _BLOCK_SIZE_1)
# v_10 = total number of output tiles for this group.
v_9 = tl.cast(v_8, tl.int64)
v_10 = v_9 * floordiv
for offset_4 in tl.range(0, v_10.to(tl.int32), flatten=True):
v_10_copy = v_10
v_8_copy = v_8
group_start_copy_0_copy = group_start_copy_0
group_end_copy_0_copy = group_end_copy_0
v_10_copy_0 = v_10_copy
v_8_copy_0 = v_8_copy
group_start_copy_0_copy_0 = group_start_copy_0_copy
group_end_copy_0_copy_0 = group_end_copy_0_copy
# add_2 = this virtual pid's tile index within the group, strided by
# num_workers; `mul` is computed but the product is re-evaluated inline.
# The guard below only proceeds when add_2 < v_10 (tile index in range).
mul = num_workers * offset_4
add_2 = offset_2 + num_workers * offset_4
v_11 = tl.cast(v_10_copy_0, tl.int64)
v_12 = v_11 > add_2
if v_12:
v_8_copy_0_copy = v_8_copy_0
group_start_copy_0_copy_0_copy = group_start_copy_0_copy_0
group_end_copy_0_copy_0_copy = group_end_copy_0_copy_0
v_8_copy_0_copy_0 = v_8_copy_0_copy
group_start_copy_0_copy_0_copy_0 = group_start_copy_0_copy_0_copy
group_end_copy_0_copy_0_copy_0 = group_end_copy_0_copy_0_copy
# v_22 = add_2 mod v_8, with the remainder shifted by the divisor when it is
# non-zero and its sign differs from the divisor's.
v_13 = tl.cast(v_8_copy_0_copy_0, tl.int64)
v_14 = add_2 % v_13
v_15 = tl.full([], 0, tl.int32)
v_16 = v_14 != v_15
v_17 = libdevice.signbit(v_14) != 0 if v_14.dtype is tl.float32 else v_14 < 0
v_18 = libdevice.signbit(v_13) != 0 if v_13.dtype is tl.float32 else v_13 < 0
v_19 = v_17 != v_18
v_20 = v_16 & v_19
v_21 = v_14 + v_13
v_22 = tl.where(v_20, v_21, v_14)
# v_24 = floor(add_2 / v_8), again written out as explicit floor division.
v_23 = tl.cast(v_8_copy_0_copy_0, tl.int64)
v_24 = tl.where((add_2 < 0) != (v_23 < 0), tl.where(add_2 % v_23 != 0, add_2 // v_23 - 1, add_2 // v_23), add_2 // v_23)
# v_27 = group_start + v_22 * _BLOCK_SIZE_0 (first row index of the tile);
# v_29 = v_24 * _BLOCK_SIZE_1 (first column index of the tile).
_BLOCK_SIZE_0__2 = _BLOCK_SIZE_0
v_25 = tl.cast(v_22, tl.int64)
v_26 = v_25 * _BLOCK_SIZE_0__2
v_27 = group_start_copy_0_copy_0_copy_0 + v_26
_BLOCK_SIZE_1_ = _BLOCK_SIZE_1
v_28 = tl.cast(v_24, tl.int64)
v_29 = v_28 * _BLOCK_SIZE_1_
# Per-element index vectors (v_31, v_33) and their bounds masks:
# v_35 keeps indices below group_end, v_37 keeps indices below 128.
iota = tl.arange(0, _BLOCK_SIZE_0)
v_30 = v_27[None]
v_31 = v_30 + iota
iota_1 = tl.arange(0, _BLOCK_SIZE_1)
v_32 = v_29[None]
v_33 = v_32 + iota_1
v_34 = group_end_copy_0_copy_0_copy_0[None]
v_35 = v_31 < v_34
v_36 = tl.full([], 128, tl.int32)
v_37 = v_33 < v_36
# float32 accumulator for one [_BLOCK_SIZE_0, _BLOCK_SIZE_1] output tile.
acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
# Reduction loop over 0..128 in steps of _BLOCK_SIZE_5.
# NOTE(review): the literal 128 serves as the loop bound, the column bound and
# the stride multiplier for A_packed, B and out -- presumably K == N == 128 for
# the failing input, baked in as compile-time constants; confirm.
for offset_5 in tl.range(0, 128, _BLOCK_SIZE_5, loop_unroll_factor=4, disallow_acc_multi_buffer=False, flatten=True):
indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_5).to(tl.int32)
v_31_copy = v_31
v_35_copy = v_35
v_33_copy = v_33
v_37_copy = v_37
acc_copy = acc
v_31_copy_0 = v_31_copy
v_35_copy_0 = v_35_copy
v_33_copy_0 = v_33_copy
v_37_copy_0 = v_37_copy
acc_copy_0 = acc_copy
# Masked load of the A tile (masked by v_35, zero-filled elsewhere).
subscript = v_35_copy_0[:, None]
a_blk = tl.load(A_packed + (v_31_copy_0[:, None] * 128 + indices_5[None, :] * 1), subscript, other=0, eviction_policy='evict_last')
# Masked load of the B tile (masked by v_37, zero-filled elsewhere).
subscript_1 = v_37_copy_0[None, :]
b_blk = tl.load(B + (indices_5[:, None] * 128 + v_33_copy_0[None, :] * 1), subscript_1, other=0)
# Both operands are cast to bfloat16; the product is accumulated in float32.
acc = tl.dot(tl.cast(a_blk, tl.bfloat16), tl.cast(b_blk, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
# Epilogue (reads the un-copied v_35/v_37, so presumably located after the
# reduction loop): combine both masks, cast the accumulator to bfloat16 and
# store the tile into `out`.
subscript_2 = v_35[:, None]
subscript_3 = v_37[None, :]
v_38 = subscript_2 & subscript_3
v_39 = tl.cast(acc, tl.bfloat16)
tl.store(out + (v_31[:, None] * 128 + v_33[None, :] * 1), v_39, v_38)
# Host-side wrapper emitted alongside the Triton kernel: it sizes the work,
# allocates the output and launches the kernel through `_launcher`.
# NOTE(review): indentation of this dump was lost when it was pasted.
def grouped_gemm_jagged_persistent(A_packed: torch.Tensor, B: torch.Tensor, group_offsets: torch.Tensor, *, _launcher=_default_launcher):
"""
Persistent grouped GEMM with dynamic tile metadata computation.
This variant computes tile assignments dynamically in the kernel,
similar to TritonBench's WS variant.
Args:
A_packed: Packed A, concatenated by rows across groups, ``[sum(M_i), K]``.
B: Shared weight matrix, ``[K, N]``.
group_offsets: Row offsets delimiting each group within ``A_packed``.
Returns:
Output tensor of shape ``[sum(M_i), N]``.
"""
device = A_packed.device
# Worker count comes from the device properties: `gpu_subslice_count` on XPU,
# `multi_processor_count` otherwise (queried through torch.cuda).
if device.type == 'xpu':
num_workers = torch.xpu.get_device_properties(device.index).gpu_subslice_count
else:
num_workers = torch.cuda.get_device_properties(device.index).multi_processor_count
total_M, K = A_packed.shape
K2, N = B.shape
# The inner dimensions of A_packed and B must agree.
assert K == K2
# Zero-initialised output in the promoted dtype of the two inputs, allocated
# on A_packed's device.
out = torch.zeros(total_M, N, dtype=torch.promote_types(A_packed.dtype, B.dtype), device=A_packed.device)
# `G` is computed here but not read again within this function.
G = group_offsets.size(0) - 1
_NUM_SM = helion.runtime.get_num_sm(A_packed.device)
_BLOCK_SIZE_5 = 16
# Launch arguments: `(_NUM_SM,)` is presumably the 1-D launch grid (confirm
# against the launcher); 64, 64 and _BLOCK_SIZE_5 fill the kernel's
# _BLOCK_SIZE_0, _BLOCK_SIZE_1 and _BLOCK_SIZE_5 constexpr parameters.
_launcher(_helion_grouped_gemm_jagged_persistent, (_NUM_SM,), group_offsets, A_packed, B, out, num_workers, _NUM_SM, 64, 64, _BLOCK_SIZE_5, num_warps=8, num_stages=3)
return out
Failing input: --input-id 0 --num-inputs 1 --input-sample-mode first-k
x_val
-------
[tritonbench] Output result csv to /tmp/tmp9nkjc86d.csv
Initial population 88% ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━━━━━━━━━━━━━━ 88/100 17.4 configs/s