[Benchmark CI] misaligned address error on grouped_gemm kernel #908

@yf225

Description

Repro on B200:
CUDA_LAUNCH_BLOCKING=1 HELION_AUTOTUNE_RANDOM_SEED=3530777511 python benchmarks/run.py --op grouped_gemm --metrics speedup,accuracy --latency-measure-mode profiler --only helion --only-match-mode prefix-with-baseline --baseline aten_grouped_mm
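For a tighter repro outside the harness, the kernel can also be called directly with the same HELION_AUTOTUNE_RANDOM_SEED exported. The sketch below is an assumption, not the CI input: shapes are inferred from the generated code further down (static_shapes=True, K = N = 128, 4 groups, bf16), while the per-group row counts, offset dtype, and import path are guesses:

import torch
from examples.grouped_gemm import grouped_gemm_jagged_persistent  # import path assumed (helion repo root)

device = "cuda"
sizes = [256, 512, 128, 384]  # illustrative per-group row counts (assumed)
offsets = [0]
for s in sizes:
    offsets.append(offsets[-1] + s)  # cumulative row offsets delimiting each group
group_offsets = torch.tensor(offsets, device=device, dtype=torch.int32)  # dtype assumed
A_packed = torch.randn(sum(sizes), 128, device=device, dtype=torch.bfloat16)
B = torch.randn(128, 128, device=device, dtype=torch.bfloat16)
out = grouped_gemm_jagged_persistent(A_packed, B, group_offsets)  # triggers autotuning; returns [sum(M_i), 128]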

Error:

Caught exception on backend helion_grouped_gemm_jagged_persistent_tritonbench, terminating early with partial results
Traceback (most recent call last):
  File "/home/willfeng/local/helion/helion/autotuner/base_search.py", line 213, in benchmark_function
    output = fn(*self.args)  # make sure the kernel is compiled
             ^^^^^^^^^^^^^^
  File "/tmp/torchinductor_willfeng/gr/cgrbt33rpwzlcggscxbbfkh62acuatbdqcbdm3cqpxgjur3flxc7.py", line 147, in grouped_gemm_jagged_persistent
    _launcher(_helion_grouped_gemm_jagged_persistent, (_NUM_SM,), group_offsets, A_packed, B, out, num_workers, _NUM_SM, 64, 64, _BLOCK_SIZE_5, num_warps=8, num_stages=3)
  File "/home/willfeng/local/helion/helion/runtime/__init__.py", line 66, in default_launcher
    return triton_kernel.run(
           ^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/triton/runtime/jit.py", line 757, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/home/willfeng/local/pytorch-nightly/triton/backends/nvidia/driver.py", line 712, in __call__
    self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
RuntimeError: Triton Error [CUDA]: misaligned address

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/willfeng/local/helion/benchmarks/tritonbench/tritonbench/utils/triton_op.py", line 1115, in run
    y_vals: Dict[str, BenchmarkOperatorMetrics] = functools.reduce(
                                                  ^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/benchmarks/tritonbench/tritonbench/utils/triton_op.py", line 1098, in _reduce_benchmarks
    acc[bm_name] = self._do_bench(
                   ^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/benchmarks/tritonbench/tritonbench/utils/triton_op.py", line 1582, in _do_bench
    metrics.latency = do_bench_wrapper(
                      ^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/benchmarks/tritonbench/tritonbench/components/do_bench/run.py", line 492, in do_bench_wrapper
    raise e
  File "/home/willfeng/local/helion/benchmarks/tritonbench/tritonbench/components/do_bench/run.py", line 482, in do_bench_wrapper
    times=bench_fn(
          ^^^^^^^^^
  File "/home/willfeng/local/helion/benchmarks/tritonbench/tritonbench/components/do_bench/run.py", line 268, in _do_bench_profiler
    estimate_ms = benchmarker.benchmark_gpu(fn, estimation_iters=5, benchmark_iters=10)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_inductor/runtime/benchmarking.py", line 39, in wrapper
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_inductor/runtime/benchmarking.py", line 250, in benchmark_gpu
    _callable()
  File "/home/willfeng/local/helion/examples/grouped_gemm.py", line 289, in inner
    return grouped_gemm_jagged_persistent(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/runtime/kernel.py", line 292, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/runtime/kernel.py", line 626, in __call__
    self.autotune(args)
  File "/home/willfeng/local/helion/helion/runtime/kernel.py", line 511, in autotune
    config = self.settings.autotuner_fn(self, args, **kwargs).autotune()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/autotuner/base_cache.py", line 178, in autotune
    config = self.autotuner.autotune()
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/autotuner/base_search.py", line 363, in autotune
    best = self._autotune()
           ^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/autotuner/pattern_search.py", line 63, in _autotune
    self.parallel_benchmark_population(self.population, desc="Initial population")
  File "/home/willfeng/local/helion/helion/autotuner/base_search.py", line 517, in parallel_benchmark_population
    self.parallel_benchmark([m.config for m in members], desc=desc),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/autotuner/base_search.py", line 345, in parallel_benchmark
    results.append((config, fn, self.benchmark_function(config, fn)))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/autotuner/base_search.py", line 239, in benchmark_function
    raise exc.TritonError(
helion.exc.TritonError: Error running generated Triton program:
@helion.kernel(config=helion.Config(block_sizes=[64, 64, 16], indexing='tensor_descriptor', load_eviction_policies=['', 'first', 'last', ''], num_stages=3, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, True, True, True], range_multi_buffers=[True, None, None, True], range_num_stages=[4, 1, 1, 0], range_unroll_factors=[0, 3, 0, 4], static_ranges=[False]), static_shapes=True)
RuntimeError: Triton Error [CUDA]: misaligned address

Generated Triton code:
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_grouped_gemm_jagged_persistent(group_offsets, A_packed, B, out, num_workers, _NUM_SM: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_5: tl.constexpr):
    total_pids = num_workers
    block_size = tl.cdiv(total_pids, _NUM_SM)
    start_pid = tl.program_id(0) * block_size
    end_pid = tl.minimum(start_pid + block_size, total_pids)
    for virtual_pid in tl.range(start_pid, end_pid, disallow_acc_multi_buffer=False):
        pid_0 = virtual_pid
        offset_2 = pid_0
        for offset_3 in tl.range(0, 4, loop_unroll_factor=3, flatten=True):
            group_start = tl.load(group_offsets + offset_3 * 1, None)
            add = 1 + offset_3
            group_end = tl.load(group_offsets + add * 1, None, eviction_policy='evict_first')
            v_0 = group_end - group_start
            v_1 = tl.full([], 0, tl.int32)
            v_2 = v_0 > v_1
            if v_2:
                v_0_copy = v_0
                group_start_copy = group_start
                group_end_copy = group_end
                v_0_copy_0 = v_0_copy
                group_start_copy_0 = group_start_copy
                group_end_copy_0 = group_end_copy
                _BLOCK_SIZE_0_ = _BLOCK_SIZE_0
                v_3 = tl.cast(v_0_copy_0, tl.int64)
                v_4 = v_3 + _BLOCK_SIZE_0_
                v_5 = tl.full([], 1, tl.int32)
                v_6 = v_4 - v_5
                _BLOCK_SIZE_0__1 = _BLOCK_SIZE_0
                v_7 = tl.cast(v_6, tl.int64)
                v_8 = tl.where((v_7 < 0) != (_BLOCK_SIZE_0__1 < 0), tl.where(v_7 % _BLOCK_SIZE_0__1 != 0, v_7 // _BLOCK_SIZE_0__1 - 1, v_7 // _BLOCK_SIZE_0__1), v_7 // _BLOCK_SIZE_0__1)
                add_1 = 128 + _BLOCK_SIZE_1
                sub_1 = 127 + _BLOCK_SIZE_1
                floordiv = triton_helpers.div_floor_integer(127 + _BLOCK_SIZE_1, _BLOCK_SIZE_1)
                v_9 = tl.cast(v_8, tl.int64)
                v_10 = v_9 * floordiv
                for offset_4 in tl.range(0, v_10.to(tl.int32), flatten=True):
                    v_10_copy = v_10
                    v_8_copy = v_8
                    group_start_copy_0_copy = group_start_copy_0
                    group_end_copy_0_copy = group_end_copy_0
                    v_10_copy_0 = v_10_copy
                    v_8_copy_0 = v_8_copy
                    group_start_copy_0_copy_0 = group_start_copy_0_copy
                    group_end_copy_0_copy_0 = group_end_copy_0_copy
                    mul = num_workers * offset_4
                    add_2 = offset_2 + num_workers * offset_4
                    v_11 = tl.cast(v_10_copy_0, tl.int64)
                    v_12 = v_11 > add_2
                    if v_12:
                        v_8_copy_0_copy = v_8_copy_0
                        group_start_copy_0_copy_0_copy = group_start_copy_0_copy_0
                        group_end_copy_0_copy_0_copy = group_end_copy_0_copy_0
                        v_8_copy_0_copy_0 = v_8_copy_0_copy
                        group_start_copy_0_copy_0_copy_0 = group_start_copy_0_copy_0_copy
                        group_end_copy_0_copy_0_copy_0 = group_end_copy_0_copy_0_copy
                        v_13 = tl.cast(v_8_copy_0_copy_0, tl.int64)
                        v_14 = add_2 % v_13
                        v_15 = tl.full([], 0, tl.int32)
                        v_16 = v_14 != v_15
                        v_17 = libdevice.signbit(v_14) != 0 if v_14.dtype is tl.float32 else v_14 < 0
                        v_18 = libdevice.signbit(v_13) != 0 if v_13.dtype is tl.float32 else v_13 < 0
                        v_19 = v_17 != v_18
                        v_20 = v_16 & v_19
                        v_21 = v_14 + v_13
                        v_22 = tl.where(v_20, v_21, v_14)
                        v_23 = tl.cast(v_8_copy_0_copy_0, tl.int64)
                        v_24 = tl.where((add_2 < 0) != (v_23 < 0), tl.where(add_2 % v_23 != 0, add_2 // v_23 - 1, add_2 // v_23), add_2 // v_23)
                        _BLOCK_SIZE_0__2 = _BLOCK_SIZE_0
                        v_25 = tl.cast(v_22, tl.int64)
                        v_26 = v_25 * _BLOCK_SIZE_0__2
                        v_27 = group_start_copy_0_copy_0_copy_0 + v_26
                        _BLOCK_SIZE_1_ = _BLOCK_SIZE_1
                        v_28 = tl.cast(v_24, tl.int64)
                        v_29 = v_28 * _BLOCK_SIZE_1_
                        iota = tl.arange(0, _BLOCK_SIZE_0)
                        v_30 = v_27[None]
                        v_31 = v_30 + iota
                        iota_1 = tl.arange(0, _BLOCK_SIZE_1)
                        v_32 = v_29[None]
                        v_33 = v_32 + iota_1
                        v_34 = group_end_copy_0_copy_0_copy_0[None]
                        v_35 = v_31 < v_34
                        v_36 = tl.full([], 128, tl.int32)
                        v_37 = v_33 < v_36
                        acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
                        for offset_5 in tl.range(0, 128, _BLOCK_SIZE_5, loop_unroll_factor=4, disallow_acc_multi_buffer=False, flatten=True):
                            indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_5).to(tl.int32)
                            v_31_copy = v_31
                            v_35_copy = v_35
                            v_33_copy = v_33
                            v_37_copy = v_37
                            acc_copy = acc
                            v_31_copy_0 = v_31_copy
                            v_35_copy_0 = v_35_copy
                            v_33_copy_0 = v_33_copy
                            v_37_copy_0 = v_37_copy
                            acc_copy_0 = acc_copy
                            subscript = v_35_copy_0[:, None]
                            a_blk = tl.load(A_packed + (v_31_copy_0[:, None] * 128 + indices_5[None, :] * 1), subscript, other=0, eviction_policy='evict_last')
                            subscript_1 = v_37_copy_0[None, :]
                            b_blk = tl.load(B + (indices_5[:, None] * 128 + v_33_copy_0[None, :] * 1), subscript_1, other=0)
                            acc = tl.dot(tl.cast(a_blk, tl.bfloat16), tl.cast(b_blk, tl.bfloat16), acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
                        subscript_2 = v_35[:, None]
                        subscript_3 = v_37[None, :]
                        v_38 = subscript_2 & subscript_3
                        v_39 = tl.cast(acc, tl.bfloat16)
                        tl.store(out + (v_31[:, None] * 128 + v_33[None, :] * 1), v_39, v_38)

def grouped_gemm_jagged_persistent(A_packed: torch.Tensor, B: torch.Tensor, group_offsets: torch.Tensor, *, _launcher=_default_launcher):
    """
    Persistent grouped GEMM with dynamic tile metadata computation.

    This variant computes tile assignments dynamically in the kernel,
    similar to TritonBench's WS variant.

    Args:
        A_packed: Packed A, concatenated by rows across groups, ``[sum(M_i), K]``.
        B: Shared weight matrix, ``[K, N]``.
        group_offsets: Row offsets delimiting each group within ``A_packed``.

    Returns:
        Output tensor of shape ``[sum(M_i), N]``.
    """
    device = A_packed.device
    if device.type == 'xpu':
        num_workers = torch.xpu.get_device_properties(device.index).gpu_subslice_count
    else:
        num_workers = torch.cuda.get_device_properties(device.index).multi_processor_count
    total_M, K = A_packed.shape
    K2, N = B.shape
    assert K == K2
    out = torch.zeros(total_M, N, dtype=torch.promote_types(A_packed.dtype, B.dtype), device=A_packed.device)
    G = group_offsets.size(0) - 1
    _NUM_SM = helion.runtime.get_num_sm(A_packed.device)
    _BLOCK_SIZE_5 = 16
    _launcher(_helion_grouped_gemm_jagged_persistent, (_NUM_SM,), group_offsets, A_packed, B, out, num_workers, _NUM_SM, 64, 64, _BLOCK_SIZE_5, num_warps=8, num_stages=3)
    return out
Failing input: --input-id 0 --num-inputs 1 --input-sample-mode first-k
  x_val
-------
[tritonbench] Output result csv to /tmp/tmp9nkjc86d.csv
Initial population  88% (88/100, 17.4 configs/s)
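To replay the failing candidate without rerunning the autotuner, the config echoed in the TritonError above can be pinned at kernel definition time. A minimal sketch, reusing the decorator form printed in the error:

import helion

failing_config = helion.Config(
    block_sizes=[64, 64, 16],
    indexing='tensor_descriptor',
    load_eviction_policies=['', 'first', 'last', ''],
    num_stages=3,
    num_warps=8,
    pid_type='persistent_blocked',
    range_flattens=[None, True, True, True],
    range_multi_buffers=[True, None, None, True],
    range_num_stages=[4, 1, 1, 0],
    range_unroll_factors=[0, 3, 0, 4],
    static_ranges=[False],
)

# Applied to the kernel definition in examples/grouped_gemm.py:
# @helion.kernel(config=failing_config, static_shapes=True)
# def grouped_gemm_jagged_persistent(...): ...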
