
[Bug]: Can both enable_inductor and enable_piecewise_cuda_graph be set to true? #8784

@morganliang-llm

Description


System Info

GPU: H200

[TensorRT-LLM] TensorRT LLM version: 1.2.0rc1

CUDA version: 13.0

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Wed_Aug_20_01:58:59_PM_PDT_2025
Cuda compilation tools, release 13.0, V13.0.88
Build cuda_13.0.r13.0/compiler.36424714_0

torch 2.8.0+cu128

Who can help?

Question: Can both enable_inductor and enable_piecewise_cuda_graph be set to true?

When I use the following configuration to run gemma3 and llama3, I get an error from this code:
https://github.com/NVIDIA/TensorRT-LLM/blob/v1.2.0rc1/tensorrt_llm/_torch/compilation/piecewise_optimizer.py#L167

torch_compile_config:
  enable_fullgraph: True
  enable_inductor: True
  enable_piecewise_cuda_graph: True
  capture_num_tokens: [1024, 2048, 4096]
  enable_userbuffers: True
  max_num_streams: 3

or when running quickstart_advanced.py:

python3 quickstart_advanced.py \
    --model_dir /data0/morganliang/models/gemma3_12b/ \
    --kv_cache_dtype auto \
    --attention_backend TRTLLM \
    --tp_size 1 \
    --pp_size 1 \
    --kv_cache_fraction 0.8 \
    --max_batch_size 1 \
    --use_cuda_graph \
    --temperature 0.8 \
    --top_p 0.95 \
    --enable_chunked_prefill \
    --use_torch_compile \
    --use_piecewise_cuda_graph \
    --cuda_graph_padding_enable

Information

  • [x] The official example scripts
  • [ ] My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The error can be reproduced when both enable_inductor and enable_piecewise_cuda_graph are set to true:

python3 quickstart_advanced.py \
    --model_dir /data0/morganliang/models/gemma3_12b/ \
    --kv_cache_dtype auto \
    --attention_backend TRTLLM \
    --tp_size 1 \
    --pp_size 1 \
    --kv_cache_fraction 0.8 \
    --max_batch_size 1 \
    --use_cuda_graph \
    --temperature 0.8 \
    --top_p 0.95 \
    --enable_chunked_prefill \
    --use_torch_compile \
    --use_piecewise_cuda_graph \
    --cuda_graph_padding_enable

or

python3 quickstart_advanced.py \
    --model_dir /data0/morganliang/models/Llama-3.1-8B-Instruct \
    --kv_cache_dtype auto \
    --attention_backend TRTLLM \
    --tp_size 1 \
    --pp_size 1 \
    --kv_cache_fraction 0.8 \
    --max_batch_size 1 \
    --use_cuda_graph \
    --temperature 0.8 \
    --top_p 0.95 \
    --enable_chunked_prefill \
    --use_torch_compile \
    --use_piecewise_cuda_graph \
    --cuda_graph_padding_enable

Expected behavior

The model can perform inference normally without any errors.

Actual behavior

Error

[10/30/2025-07:59:15] [TRT-LLM] [E] Traceback (most recent call last):
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/executor/worker.py", line 362, in worker_main
worker: GenerationExecutorWorker = worker_cls(
^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/executor/worker.py", line 64, in __init__
self.setup_engine()
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/executor/base_worker.py", line 177, in setup_engine
self.engine = _create_py_executor(
^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/executor/base_worker.py", line 147, in _create_py_executor
_executor = create_executor(**args)
^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py", line 590, in create_py_executor
py_executor = create_py_executor_instance(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/pyexecutor/_util.py", line 789, in create_py_executor_instance
return PyExecutor(
^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/pyexecutor/py_executor.py", line 272, in __init__
self.model_engine.warmup(self.resource_manager)
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 473, in wrapper
return method(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 510, in warmup
self._run_torch_compile_warmup(resource_manager)
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 550, in _run_torch_compile_warmup
self.forward(batch,
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/utils.py", line 73, in wrapper
return func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 2332, in forward
outputs = self._forward_step(inputs, gather_ids,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/nvtx/nvtx.py", line 122, in inner
result = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 2395, in _forward_step
outputs = self.model_forward(
^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 2382, in model_forward
return self.model.forward(**kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/models/modeling_speculative.py", line 545, in forward
hidden_states = self.model(
^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 375, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 736, in compile_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/models/modeling_llama.py", line 940, in forward
def forward(
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1241, in forward
return compiled_fn(full_args)
^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 384, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 750, in inner_fn
outs = compiled_fn(args)
^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 556, in wrapper
return compiled_fn(runtime_args)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 100, in g
return f(*args)
^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/fx/graph_module.py", line 424, in __call__
raise e
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/fx/graph_module.py", line 411, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<eval_with_key>.66", line 5, in forward
submod_0 = self.submod_0(arg0_1, arg1_1, arg2_1, arg3_1); arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/workspace/trtllm_1.2.0_cuda13/tensorrt-llm/tensorrt_llm/_torch/compilation/piecewise_optimizer.py", line 167, in __call__
entry.callable = compile_fx(entry.callable, args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2418, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 109, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1199, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 1140, in load
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1184, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 687, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 198, in inner
flat_f_outs = f(*flat_f_args)
^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 906, in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1241, in forward
return compiled_fn(full_args)
^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 384, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 556, in wrapper
return compiled_fn(runtime_args)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 584, in __call__
return self.current_callable(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_inductor/utils.py", line 2713, in run
old_tensors, new_tensors = copy_misaligned_inputs(
^^^^^^^^^^^^^^^^^^^^^^^
File "/data0/morganliang/py_env/cuda13_trtllm/lib/python3.12/site-packages/torch/_inductor/utils.py", line 2760, in copy_misaligned_inputs
if _inp.data_ptr() % ALIGNMENT:
^^^^^^^^^^^^^^^
RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

Additional notes

None

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and checked the documentation and examples for answers to frequently asked questions.

Labels: Pytorch<NV> (Pytorch backend related issues), bug (Something isn't working)
