Description
The default config doesn't work for all kernels. For example:
```
Traceback (most recent call last):
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/ea30a5744022cbb2/pytorch/tritonbench/__run__/run-inplace#link-tree/helion/autotuner/base_search.py", line 177, in _compute_baseline
    baseline_output = self.kernel.compile_config(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/tmp/torchinductor_mengluy/ch/cchgyarub4aaqcrlbikzi4o3cwdpm5q2rmbj6vuwk4whj272etuh.py", line 131, in layer_norm_bwd
    _launcher(_helion_layer_norm_bwd, (triton.cdiv(4608, _BLOCK_SIZE_0),), weight, x, grad_out, mean, rstd, grad_x, grad_weight_blocks, grad_bias_blocks, _BLOCK_SIZE_0, _RDIM_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/ea30a5744022cbb2/pytorch/tritonbench/__run__/run-inplace#link-tree/helion/runtime/__init__.py", line 66, in default_launcher
    return triton_kernel.run(
           ^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/ea30a5744022cbb2/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/runtime/jit.py", line 732, in run
    kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/ea30a5744022cbb2/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/runtime/jit.py", line 862, in _do_compile
    kernel = self.compile(src, target=target, options=options.__dict__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/ea30a5744022cbb2/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/compiler/compiler.py", line 332, in compile
    module = src.make_ir(target, options, codegen_fns, module_map, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/ea30a5744022cbb2/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/compiler/compiler.py", line 79, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 30:30:
    # src[layer_norm.py:126-142]: ...
    for offset_3 in tl.range(offset_0.to(tl.int32), tile_end.to(tl.int32), _BLOCK_SIZE_2):
        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
        mask_2 = indices_3 < tile_end
        v_0_copy = v_0
        grad_w_acc_copy = grad_w_acc
        grad_b_acc_copy = grad_b_acc
        v_0_copy_0 = v_0_copy
        grad_w_acc_copy_0 = grad_w_acc_copy
        grad_b_acc_copy_0 = grad_b_acc_copy
        # src[layer_norm.py:127]: x_mb = x[mb, :].to(torch.float32)
        load_1 = tl.load(x + (indices_3[:, None] * 40920 + indices_4[None, :] * 1), mask_2[:, None] & mask_1[None, :], other=0)
                 ^
ValueError('numel (2097152) exceeds triton maximum tensor numel (1048576)')
```
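The numbers line up with the config: with static_shapes=True, n = 40920 is specialized and the reduction dimension appears padded to the next power of two (the generated _RDIM_SIZE_1), so the flagged load materializes a [_BLOCK_SIZE_2, _RDIM_SIZE_1] tile. A minimal sketch of that arithmetic (my reading of the generated kernel; the padding rule is an assumption):

```python
# Sketch of where numel = 2097152 comes from; assumes the reduction dim is
# padded to the next power of two, as the generated _RDIM_SIZE_1 suggests.
import math

n = 40920                                      # x.size(1) in the repro
_RDIM_SIZE_1 = 1 << math.ceil(math.log2(n))    # 65536
_BLOCK_SIZE_2 = 32                             # inner tile from block_sizes=[32, 32]

numel = _BLOCK_SIZE_2 * _RDIM_SIZE_1           # 32 * 65536 = 2097152
TRITON_MAX_TENSOR_NUMEL = 1 << 20              # 1048576, the cap in the error
print(numel, numel > TRITON_MAX_TENSOR_NUMEL)  # 2097152 True
```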
To Reproduce
```python
# === HELION KERNEL REPRO ===
import helion
import helion.language as hl
import torch
from torch._dynamo.testing import rand_strided


@helion.kernel(config=helion.Config(block_sizes=[32, 32], indexing=['pointer', 'pointer', 'pointer', 'pointer', 'pointer', 'pointer', 'pointer', 'pointer'], load_eviction_policies=['', '', '', '', ''], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[]), static_shapes=True)
def layer_norm_bwd(
    grad_out: torch.Tensor,
    x: torch.Tensor,
    mean: torch.Tensor,
    rstd: torch.Tensor,
    weight: torch.Tensor,
    compute_bias_grad: hl.constexpr = True,  # type: ignore[valid-type]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """
    Compute gradients for weight (dW) and optionally bias (dB) parameters.

    This kernel performs a reduction across the batch dimension (M) to accumulate
    gradients for each feature dimension's weight and bias parameters.

    Args:
        grad_out: Gradient w.r.t. layer norm output [M, N]
        x: Original input tensor [M, N]
        mean: Per-sample mean computed in forward pass [M]
        rstd: Per-sample reciprocal standard deviation from forward pass [M]
        weight: Weight parameter (used only for dtype/device info) [N]
        compute_bias_grad: Whether to compute bias gradient (default: True)

    Returns:
        (grad_x, grad_weight, grad_bias): Gradients for input, weight, and bias (if computed)
        grad_bias is None if compute_bias_grad is False
    """
    m_block = hl.register_block_size(x.size(0))
    n = hl.specialize(x.size(1))
    grad_x = torch.empty_like(x)
    num_blocks = (x.size(0) + m_block - 1) // m_block
    grad_weight_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
    grad_bias_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
    for mb_cta in hl.tile(x.size(0), block_size=m_block):
        grad_w_acc = weight.new_zeros(n, dtype=torch.float32)
        if compute_bias_grad:
            grad_b_acc = weight.new_zeros(n, dtype=torch.float32)
        weight_cta = weight[None, :].to(torch.float32)
        for mb in hl.tile(mb_cta.begin, mb_cta.end):
            x_mb = x[mb, :].to(torch.float32)
            dy_mb = grad_out[mb, :].to(torch.float32)
            mean_mb = mean[mb].to(torch.float32)
            rstd_mb = rstd[mb].to(torch.float32)
            x_hat = (x_mb - mean_mb[:, None]) * rstd_mb[:, None]
            grad_w_acc += torch.sum(dy_mb * x_hat, dim=0)
            if compute_bias_grad:
                grad_b_acc += torch.sum(dy_mb, dim=0)  # pyright: ignore[reportPossiblyUnboundVariable]
            wdy = weight_cta * dy_mb
            c1 = torch.sum(x_hat * wdy, dim=-1) / n
            c2 = torch.sum(wdy, dim=-1) / n
            dx = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_mb[:, None]
            grad_x[mb, :] = dx.to(x.dtype)
        grad_weight_blocks[mb_cta.id, :] = grad_w_acc
        if compute_bias_grad:
            grad_bias_blocks[mb_cta.id, :] = grad_b_acc  # type: ignore[index]
    grad_weight = grad_weight_blocks.sum(0).to(weight.dtype)
    if compute_bias_grad:
        grad_bias = grad_bias_blocks.sum(0).to(weight.dtype)
        return grad_x, grad_weight, grad_bias
    return grad_x, grad_weight, None


def helion_repro_caller():
    torch.manual_seed(0)
    grad_out = rand_strided((4608, 40920), (40920, 1), dtype=torch.bfloat16, device='cuda:0')
    x = rand_strided((4608, 40920), (40920, 1), dtype=torch.bfloat16, device='cuda:0')
    x.requires_grad_(True)
    mean = rand_strided((4608,), (1,), dtype=torch.float32, device='cuda:0')
    rstd = rand_strided((4608,), (1,), dtype=torch.float32, device='cuda:0')
    weight = rand_strided((40920,), (1,), dtype=torch.bfloat16, device='cuda:0')
    weight.requires_grad_(True)
    compute_bias_grad = True
    return layer_norm_bwd(grad_out, x, mean, rstd, weight, compute_bias_grad)


helion_repro_caller()
# === END HELION KERNEL REPRO ===
```
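An untested workaround sketch: shrinking the inner m-tile should keep the padded load under the cap. This assumes block_sizes[1] maps to _BLOCK_SIZE_2 in the generated code:

```python
# Hypothetical workaround, not verified: shrink the inner m-tile so that
# block * padded-n stays under Triton's 2**20 numel cap.
# Assumes block_sizes[1] is the inner tile (_BLOCK_SIZE_2).
import helion

safe_config = helion.Config(
    block_sizes=[32, 8],  # 8 * 65536 = 524288 <= 1048576
    num_stages=1,
    num_warps=4,
)
```

Either way, the default config presumably needs to account for the specialized n when choosing block sizes, rather than using a fixed [32, 32].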
Additional context
Internal Paste P2017639080