Discover a working default config automatically #1064

@mengluy0125

Description

The default config doesn't work for all kernels. For example, computing the autotuner baseline for this layer_norm_bwd kernel fails at compile time:

Traceback (most recent call last):
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/ea30a5744022cbb2/pytorch/tritonbench/__run__/run-inplace#link-tree/helion/autotuner/base_search.py", line 177, in _compute_baseline
    baseline_output = self.kernel.compile_config(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/tmp/torchinductor_mengluy/ch/cchgyarub4aaqcrlbikzi4o3cwdpm5q2rmbj6vuwk4whj272etuh.py", line 131, in layer_norm_bwd
    _launcher(_helion_layer_norm_bwd, (triton.cdiv(4608, _BLOCK_SIZE_0),), weight, x, grad_out, mean, rstd, grad_x, grad_weight_blocks, grad_bias_blocks, _BLOCK_SIZE_0, _RDIM_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/ea30a5744022cbb2/pytorch/tritonbench/__run__/run-inplace#link-tree/helion/runtime/__init__.py", line 66, in default_launcher
    return triton_kernel.run(
           ^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/ea30a5744022cbb2/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/runtime/jit.py", line 732, in run
    kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/ea30a5744022cbb2/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/runtime/jit.py", line 862, in _do_compile
    kernel = self.compile(src, target=target, options=options.__dict__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/ea30a5744022cbb2/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/compiler/compiler.py", line 332, in compile
    module = src.make_ir(target, options, codegen_fns, module_map, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/ea30a5744022cbb2/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/compiler/compiler.py", line 79, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 30:30:
    # src[layer_norm.py:126-142]: ...
    for offset_3 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_2):
        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
        mask_2 = indices_3 < tile_end
        v_0_copy = v_0
        grad_w_acc_copy = grad_w_acc
        grad_b_acc_copy = grad_b_acc
        v_0_copy_0 = v_0_copy
        grad_w_acc_copy_0 = grad_w_acc_copy
        grad_b_acc_copy_0 = grad_b_acc_copy
        # src[layer_norm.py:127]: x_mb = x[mb, :].to(torch.float32)
        load_1 = tl.load(x + (indices_3[:, None] * 40920 + indices_4[None, :] * 1), mask_2[:, None] & mask_1[None, :], other=0)
                              ^
ValueError('numel (2097152) exceeds triton maximum tensor numel (1048576)')
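
My read of the failure (an inference from the generated code above, not confirmed against Helion internals): with static_shapes=True, the reduction dimension N = 40920 is padded up to the next power of two, 65536 (_RDIM_SIZE_1), so the inner-loop load materializes a [_BLOCK_SIZE_2, 65536] tile of 32 * 65536 = 2097152 elements, which exceeds Triton's per-tensor cap of 2**20 = 1048576. A minimal sketch of the arithmetic, taking the cap value from the error message:

import triton  # only for next_power_of_2

TRITON_MAX_TENSOR_NUMEL = 1 << 20  # 1048576, taken from the error above

def tile_numel(block: int, n: int) -> int:
    # The reduction dim is padded to a power of two, so the 2D tile loaded
    # in the inner loop holds block * next_power_of_2(n) elements.
    return block * triton.next_power_of_2(n)

print(tile_numel(32, 40920))  # 32 * 65536 = 2097152 -> triggers the ValueError
print(tile_numel(16, 40920))  # 16 * 65536 = 1048576 -> at the cap ("exceeds" suggests the check is strict)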

To Reproduce

# === HELION KERNEL REPRO ===
import helion
import helion.language as hl
import torch
from torch._dynamo.testing import rand_strided

@helion.kernel(
    config=helion.Config(
        block_sizes=[32, 32],
        indexing=['pointer', 'pointer', 'pointer', 'pointer', 'pointer', 'pointer', 'pointer', 'pointer'],
        load_eviction_policies=['', '', '', '', ''],
        num_stages=1,
        num_warps=4,
        pid_type='flat',
        range_flattens=[None, None],
        range_multi_buffers=[None, None],
        range_num_stages=[0, 0],
        range_unroll_factors=[0, 0],
        range_warp_specializes=[],
    ),
    static_shapes=True,
)
def layer_norm_bwd(
    grad_out: torch.Tensor,
    x: torch.Tensor,
    mean: torch.Tensor,
    rstd: torch.Tensor,
    weight: torch.Tensor,
    compute_bias_grad: hl.constexpr = True,  # type: ignore[valid-type]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """
    Compute gradients for weight (dW) and optionally bias (dB) parameters.

    This kernel performs reduction across the batch dimension (M) to accumulate
    gradients for each feature dimension's weight and bias parameters.

    Args:
        grad_out: Gradient w.r.t. layer norm output [M, N]
        x: Original input tensor [M, N]
        mean: Per-sample mean computed in forward pass [M]
        rstd: Per-sample reciprocal standard deviation from forward pass [M]
        weight: Weight parameter (used only for dtype/device info) [N]
        compute_bias_grad: Whether to compute bias gradient (default: True)

    Returns:
        (grad_x, grad_weight, grad_bias): Gradients for input, weight, and bias (if computed)
            grad_bias is None if compute_bias_grad is False
    """

    m_block = hl.register_block_size(x.size(0))
    n = hl.specialize(x.size(1))

    grad_x = torch.empty_like(x)
    num_blocks = (x.size(0) + m_block - 1) // m_block
    grad_weight_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
    grad_bias_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)

    for mb_cta in hl.tile(x.size(0), block_size=m_block):
        grad_w_acc = weight.new_zeros(n, dtype=torch.float32)
        if compute_bias_grad:
            grad_b_acc = weight.new_zeros(n, dtype=torch.float32)
        weight_cta = weight[None, :].to(torch.float32)
        for mb in hl.tile(mb_cta.begin, mb_cta.end):
            x_mb = x[mb, :].to(torch.float32)
            dy_mb = grad_out[mb, :].to(torch.float32)
            mean_mb = mean[mb].to(torch.float32)
            rstd_mb = rstd[mb].to(torch.float32)

            x_hat = (x_mb - mean_mb[:, None]) * rstd_mb[:, None]

            grad_w_acc += torch.sum(dy_mb * x_hat, dim=0)
            if compute_bias_grad:
                grad_b_acc += torch.sum(dy_mb, dim=0)  # pyright: ignore[reportPossiblyUnboundVariable]

            wdy = weight_cta * dy_mb
            c1 = torch.sum(x_hat * wdy, dim=-1) / n
            c2 = torch.sum(wdy, dim=-1) / n
            dx = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_mb[:, None]
            grad_x[mb, :] = dx.to(x.dtype)

        grad_weight_blocks[mb_cta.id, :] = grad_w_acc
        if compute_bias_grad:
            grad_bias_blocks[mb_cta.id, :] = grad_b_acc  # type: ignore[index]

    grad_weight = grad_weight_blocks.sum(0).to(weight.dtype)
    if compute_bias_grad:
        grad_bias = grad_bias_blocks.sum(0).to(weight.dtype)
        return grad_x, grad_weight, grad_bias
    return grad_x, grad_weight, None

def helion_repro_caller():
    torch.manual_seed(0)
    grad_out = rand_strided((4608, 40920), (40920, 1), dtype=torch.bfloat16, device='cuda:0')
    x = rand_strided((4608, 40920), (40920, 1), dtype=torch.bfloat16, device='cuda:0')
    x.requires_grad_(True)
    mean = rand_strided((4608,), (1,), dtype=torch.float32, device='cuda:0')
    rstd = rand_strided((4608,), (1,), dtype=torch.float32, device='cuda:0')
    weight = rand_strided((40920,), (1,), dtype=torch.bfloat16, device='cuda:0')
    weight.requires_grad_(True)
    compute_bias_grad = True
    return layer_norm_bwd(grad_out, x, mean, rstd, weight, compute_bias_grad)

helion_repro_caller()
# === END HELION KERNEL REPRO ===
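
Toward the "discover a working default config automatically" part: a cheap pre-flight check could shrink the block sizes until the largest tile fits under the numel cap, before any compile is attempted. The helper below is a hypothetical sketch, not Helion's actual autotuner API; it assumes only the 2**20 cap from the error above and triton.next_power_of_2:

import triton

TRITON_MAX_TENSOR_NUMEL = 1 << 20  # assumed from the error message above

def shrink_block_sizes(block_sizes: list[int], n: int) -> list[int]:
    # Hypothetical helper: halve the largest block size until the biggest
    # tile (block size x padded reduction dim) no longer exceeds the cap.
    rdim = triton.next_power_of_2(n)
    sizes = list(block_sizes)
    while max(sizes) * rdim > TRITON_MAX_TENSOR_NUMEL and max(sizes) > 1:
        i = sizes.index(max(sizes))
        sizes[i] //= 2
    return sizes

# Starting from the failing default [32, 32] with N = 40920:
print(shrink_block_sizes([32, 32], 40920))  # -> [16, 16]

A config that passes this check would still need the usual accuracy and performance validation; the check only rules out configs that cannot compile at all.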

Additional context

Internal Paste P2017639080
