
Skip invalid hardcoded configs #1081

@mengluy0125

Description

Is your feature request related to a problem? Please describe.
When a configuration is hardcoded for a specific tensor shape, it should not be applied to other shapes. Using such a config with an incompatible shape can fail at launch time, for example with an out-of-resources error when the config's block sizes exceed the hardware's shared-memory limit. Currently, Helion simply raises that error. Ideally, any config that is invalid for a new tensor shape should be skipped, and if no valid config remains, the system should automatically re-autotune to determine an appropriate configuration.
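
A minimal sketch of the desired selection logic (hypothetical names: estimate_shared_mem and autotune stand in for Helion's real resource-estimation and autotuning machinery, which this issue does not spell out):

def pick_config(cached_configs, shape, estimate_shared_mem, hw_limit, autotune):
    """Return the first cached config that fits this shape, else re-autotune."""
    for config in cached_configs:
        # Skip any config whose estimated footprint exceeds the hardware
        # limit for this shape instead of launching it and crashing.
        if estimate_shared_mem(config, shape) <= hw_limit:
            return config
    # No hardcoded config is valid for this shape: fall back to autotuning.
    return autotune(shape)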

An internal example:

WARNING:helion.runtime.kernel:# === HELION KERNEL REPRO ===
import helion
import helion.language as hl
import torch
from torch import Tensor
from torch._dynamo.testing import rand_strided

@helion.kernel(
    config=helion.Config(
        block_sizes=[64],
        indexing=['pointer', 'tensor_descriptor', 'pointer', 'pointer'],
        load_eviction_policies=['last', 'last'],
        num_stages=1,
        num_warps=4,
        pid_type='flat',
        range_flattens=[None],
        range_multi_buffers=[None],
        range_num_stages=[0],
        range_unroll_factors=[0],
        range_warp_specializes=[None],
    ),
    static_shapes=True,
)
def se_block_fwd(x: Tensor, w: Tensor) -> tuple[Tensor, Tensor]:
    """
    Performs 2 * x * sigmoid(x @ w), we consider a special case where n is small
    Args:
        x: 2D tensor of shape [m, n].
        w: 2D tensor of shape [n, n].
    Returns:
        out: Resulting matrix of shape [m, n].
        s: sigmoid(x @ w) of shape [m, n].
    """
    m, n = x.size()

    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    s = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        x_tile = x[tile_m, :]  # [tile_m, n]
        # No tiling for n - load entire n at once since it's small
        sigmoid_result = torch.sigmoid(
            x_tile @ w[:, :]
        )  # [tile_m, n] @ [n, n] = [tile_m, n]
        s[tile_m, :] = sigmoid_result
        # Compute output: 2 * x * sigmoid, cast to input dtype
        acc = 2.0 * x_tile * sigmoid_result
        out[tile_m, :] = acc.to(x.dtype)

    return out, s

def helion_repro_caller():
    torch.manual_seed(0)
    x = rand_strided((1152000, 384), (384, 1), dtype=torch.bfloat16, device='cuda:0')
    x.requires_grad_(True)
    w = rand_strided((384, 384), (384, 1), dtype=torch.bfloat16, device='cuda:0')
    return se_block_fwd(x, w)

helion_repro_caller()
# === END HELION KERNEL REPRO ===
WARNING:tritonbench.utils.triton_op:Caught exception on backend helion_se_block, terminating early with partial results
Traceback (most recent call last):
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/utils/triton_op.py", line 1156, in run
    y_vals: Dict[str, BenchmarkOperatorMetrics] = functools.reduce(
                                                  ^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/utils/triton_op.py", line 1137, in _reduce_benchmarks
    acc[bm_name] = self._do_bench(
                   ^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/utils/triton_op.py", line 1714, in _do_bench
    metrics.latency = do_bench_wrapper(
                      ^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/components/do_bench/run.py", line 540, in do_bench_wrapper
    raise e
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/components/do_bench/run.py", line 530, in do_bench_wrapper
    times=bench_fn(
          ^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/components/do_bench/run.py", line 292, in _do_bench_profiler
    estimate_ms = triton.testing.do_bench(
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/testing.py", line 149, in do_bench
    fn()
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/operators/fb/tlx_fused_se_block/operator.py", line 84, in inner
    return helion_se_block(x, weight)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/ads_mkl/ops/helion/se_block.py", line 280, in helion_se_block
    return se_block(x, w)
           ^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_library/custom_ops.py", line 687, in __call__
    return self._opoverload(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_ops.py", line 836, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_library/autograd.py", line 110, in autograd_impl
    result = Generated.apply(*args, Metadata(keyset, keyword_only_args))  # type: ignore[attr-defined]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/autograd/function.py", line 583, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_library/autograd.py", line 51, in forward
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_ops.py", line 843, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_library/custom_ops.py", line 345, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_dynamo/eval_frame.py", line 1129, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_library/custom_ops.py", line 380, in wrapped_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/ads_mkl/ops/helion/se_block.py", line 199, in se_block
    return se_block_fwd(x, w)
           ^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/helion/runtime/kernel.py", line 293, in __call__
    return self.bind(args)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/helion/runtime/kernel.py", line 656, in __call__
    return self._run(*args)
           ^^^^^^^^^^^^^^^^
  File "/var/tmp/torchinductor_mengluy/g3/cg3dhjpaa35ewgetazncbjnvun5s5gzjfawrcxllbv4cw6ptngoe.py", line 66, in se_block_fwd
    _launcher(_helion_se_block_fwd, (triton.cdiv(1152000, _BLOCK_SIZE_0),), x, w, s, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1)
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/helion/runtime/__init__.py", line 73, in default_launcher
    return triton_kernel.run(
           ^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/runtime/jit.py", line 755, in run
    launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/compiler/compiler.py", line 527, in launch_metadata
    self._init_handles()
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/compiler/compiler.py", line 501, in _init_handles
    raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
  File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/compiler/compiler.py", line 493, in raise_
    raise err
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 589832, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.
WARNING:tritonbench.utils.triton_op:Failing input: --input-id 0 --num-inputs 1 --input-sample-mode first-k
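
Until such a fallback exists, a caller-side guard along these lines can work around the crash (a sketch, not Helion's actual API: se_block_fwd_tuned is assumed to be the same kernel as above, declared without the config= argument so Helion autotunes it):

from triton.runtime.errors import OutOfResources

def se_block_fwd_safe(x, w):
    try:
        # Try the kernel pinned to the hardcoded config from the repro above.
        return se_block_fwd(x, w)
    except OutOfResources:
        # The pinned config does not fit this shape; dispatch to an untuned
        # copy of the kernel so Helion autotunes a valid config instead.
        return se_block_fwd_tuned(x, w)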
