Is your feature request related to a problem? Please describe.
When a configuration is hardcoded for a specific tensor shape, it should not be applied to other shapes: using such a config with an incompatible shape can fail, for example with out-of-memory (OOM) errors. Today, the kernel simply raises an error when the hardcoded config is invalid for a new shape. Ideally, configurations that are invalid for a new tensor shape should be ignored, and if no valid config remains, the system should automatically re-autotune to determine an appropriate one.
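Until Helion handles this automatically, one possible user-side workaround (only a sketch under assumptions, not a documented mechanism of the library) is to keep a second, config-free definition of the kernel and fall back to it when the hardcoded config fails to launch. In the sketch below, se_block_fwd_autotuned and call_with_fallback are hypothetical names; the fallback catches the OutOfResources error that Triton raises in the traceback further down.

import torch
import helion
import helion.language as hl
from triton.runtime.errors import OutOfResources

# Hypothetical config-free variant of the kernel from the repro below: with no
# hardcoded helion.Config, Helion autotunes when it is first run on a new shape.
@helion.kernel(static_shapes=True)
def se_block_fwd_autotuned(x: torch.Tensor, w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    m, n = x.size()
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    s = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m in hl.tile(m):
        x_tile = x[tile_m, :]
        sigmoid_result = torch.sigmoid(x_tile @ w[:, :])
        s[tile_m, :] = sigmoid_result
        out[tile_m, :] = (2.0 * x_tile * sigmoid_result).to(x.dtype)
    return out, s

def call_with_fallback(configured_kernel, autotuned_kernel, *args):
    # Try the kernel carrying the shape-specific hardcoded config first; if it
    # cannot launch on these shapes (e.g. the shared-memory OutOfResources in
    # the traceback below), fall back to the autotuned variant instead of
    # surfacing the error to the caller.
    try:
        return configured_kernel(*args)
    except OutOfResources:
        return autotuned_kernel(*args)

# Usage, where se_block_fwd is the hardcoded-config kernel from the repro below:
# out, s = call_with_fallback(se_block_fwd, se_block_fwd_autotuned, x, w)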
An internal example:
WARNING:helion.runtime.kernel:# === HELION KERNEL REPRO ===
import helion
import helion.language as hl
import torch
from torch import Tensor
from torch._dynamo.testing import rand_strided

@helion.kernel(config=helion.Config(block_sizes=[64], indexing=['pointer', 'tensor_descriptor', 'pointer', 'pointer'], load_eviction_policies=['last', 'last'], num_stages=1, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[0], range_unroll_factors=[0], range_warp_specializes=[None]), static_shapes=True)
def se_block_fwd(x: Tensor, w: Tensor) -> tuple[Tensor, Tensor]:
    """
    Performs 2 * x * sigmoid(x @ w); we consider a special case where n is small.
    Args:
        x: 2D tensor of shape [m, n].
        w: 2D tensor of shape [n, n].
    Returns:
        out: Resulting matrix of shape [m, n].
        s: sigmoid(x @ w) of shape [m, n].
    """
    m, n = x.size()
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    s = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m in hl.tile(m):
        x_tile = x[tile_m, :]  # [tile_m, n]
        # No tiling for n - load entire n at once since it's small
        sigmoid_result = torch.sigmoid(
            x_tile @ w[:, :]
        )  # [tile_m, n] @ [n, n] = [tile_m, n]
        s[tile_m, :] = sigmoid_result
        # Compute output: 2 * x * sigmoid, cast to input dtype
        acc = 2.0 * x_tile * sigmoid_result
        out[tile_m, :] = acc.to(x.dtype)
    return out, s

def helion_repro_caller():
    torch.manual_seed(0)
    x = rand_strided((1152000, 384), (384, 1), dtype=torch.bfloat16, device='cuda:0')
    x.requires_grad_(True)
    w = rand_strided((384, 384), (384, 1), dtype=torch.bfloat16, device='cuda:0')
    return se_block_fwd(x, w)

helion_repro_caller()
# === END HELION KERNEL REPRO ===
WARNING:tritonbench.utils.triton_op:Caught exception on backend helion_se_block, terminating early with partial results
Traceback (most recent call last):
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/utils/triton_op.py", line 1156, in run
y_vals: Dict[str, BenchmarkOperatorMetrics] = functools.reduce(
^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/utils/triton_op.py", line 1137, in _reduce_benchmarks
acc[bm_name] = self._do_bench(
^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/utils/triton_op.py", line 1714, in _do_bench
metrics.latency = do_bench_wrapper(
^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/components/do_bench/run.py", line 540, in do_bench_wrapper
raise e
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/components/do_bench/run.py", line 530, in do_bench_wrapper
times=bench_fn(
^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/components/do_bench/run.py", line 292, in _do_bench_profiler
estimate_ms = triton.testing.do_bench(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/testing.py", line 149, in do_bench
fn()
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/tritonbench/operators/fb/tlx_fused_se_block/operator.py", line 84, in inner
return helion_se_block(x, weight)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/ads_mkl/ops/helion/se_block.py", line 280, in helion_se_block
return se_block(x, w)
^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_library/custom_ops.py", line 687, in __call__
return self._opoverload(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_ops.py", line 836, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_library/autograd.py", line 110, in autograd_impl
result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/autograd/function.py", line 583, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_library/autograd.py", line 51, in forward
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_ops.py", line 843, in redispatch
return self._handle.redispatch_boxed(keyset, *args, **kwargs) # type: ignore[return-value]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_library/custom_ops.py", line 345, in backend_impl
result = self._backend_fns[device_type](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_compile.py", line 54, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_dynamo/eval_frame.py", line 1129, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/torch/_library/custom_ops.py", line 380, in wrapped_fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/ads_mkl/ops/helion/se_block.py", line 199, in se_block
return se_block_fwd(x, w)
^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/helion/runtime/kernel.py", line 293, in __call__
return self.bind(args)(*args)
^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/helion/runtime/kernel.py", line 656, in __call__
return self._run(*args)
^^^^^^^^^^^^^^^^
File "/var/tmp/torchinductor_mengluy/g3/cg3dhjpaa35ewgetazncbjnvun5s5gzjfawrcxllbv4cw6ptngoe.py", line 66, in se_block_fwd
_launcher(_helion_se_block_fwd, (triton.cdiv(1152000, _BLOCK_SIZE_0),), x, w, s, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1)
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/helion/runtime/__init__.py", line 73, in default_launcher
return triton_kernel.run(
^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/runtime/jit.py", line 755, in run
launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/compiler/compiler.py", line 527, in launch_metadata
self._init_handles()
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/compiler/compiler.py", line 501, in _init_handles
raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
File "/data/users/mengluy/fbsource/buck-out/v2/gen/fbcode/7621b45984e65cb6/pytorch/tritonbench/__run__/run-inplace#link-tree/triton/compiler/compiler.py", line 493, in raise_
raise err
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 589832, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.
WARNING:tritonbench.utils.triton_op:Failing input: --input-id 0 --num-inputs 1 --input-sample-mode first-k
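For context on how far off a shape-specific config can be, a rough back-of-envelope (my own arithmetic, under a stated assumption about what the generated kernel keeps in shared memory) lines up almost exactly with the numbers reported in the OutOfResources error:

# Rough back-of-envelope for the launch failure above. Assumption (mine, not
# from the log): under this config the generated kernel stages the bf16 w
# operand of the matmul in shared memory.
n = 384                                  # w is [n, n] bf16 in the repro
one_copy_of_w = n * n * 2                # 294,912 bytes
print(2 * one_copy_of_w)                 # 589824 -- within 8 bytes of the
                                         # "Required: 589832" in the error
hardware_limit = 232448                  # shared-memory limit from the error
print(one_copy_of_w > hardware_limit)    # True: even one copy of w exceeds it
# A config hardcoded for a different problem size is therefore not just
# suboptimal here but unlaunchable; discarding it and re-autotuning would be
# the preferable behavior.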