Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions examples/dynamo/low_cpu_memory_compilation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""

.. _low_cpu_memory_compilation:

Low CPU Memory Compilation Example
==================================

This example demonstrates compiling a model with a bounded CPU (host) memory
budget using Torch-TensorRT Dynamo. Limiting host RAM use is helpful on
memory-constrained machines or when compiling very large models.

Key notes:
- The toy model below has roughly 430 MB of parameters. We set the CPU
memory budget to 2 GiB. At compile time, only about 900 MB of host RAM
may remain available. We expect at most 403 * 4 = 1612 MB of memory to be used by the model.
So the model is partitioned into two subgraphs to fit the memory budget.

- Performance impact varies by model. When the number of TensorRT engines
created is small, the impact is typically minimal.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo.conversion import CompilationSettings


class net(nn.Module):
def __init__(self):
super().__init__()
# Intentionally large layers to stress host memory during compilation.
self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1)
self.bn1 = nn.BatchNorm2d(4096)
self.conv2 = nn.Conv2d(4096, 1024, 3, padding=1)
self.bn2 = nn.BatchNorm2d(1024)
self.fc1 = nn.Linear(1024 * 56 * 56, 10)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = F.max_pool2d(x, (2, 2))
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
x = F.max_pool2d(x, (2, 2))
x = torch.flatten(x, 1)
return self.fc1(x)


model = net().eval()
model.to("cuda")
inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")]

enabled_precisions = {torch.float}
use_python_runtime = False

compilation_options = {
"use_python_runtime": use_python_runtime,
"enabled_precisions": enabled_precisions,
"min_block_size": 1,
"immutable_weights": True,
"reuse_cached_engines": False,
"cpu_memory_budget": 2 * 1024 * 1024 * 1024, # 2 GiB in bytes
}

settings = CompilationSettings(**compilation_options)
with torchtrt.dynamo.Debugger(
log_level="debug",
logging_dir="/home/profile/logging/moe",
engine_builder_monitor=False,
):

exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torchtrt.dynamo.compile(
exp_program,
inputs=inputs,
**compilation_options,
)

# Expect two back-to-back TensorRT engines due to partitioning under the memory budget.
print(trt_gm)
21 changes: 19 additions & 2 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
post_lowering,
pre_export_lowering,
)
from torch_tensorrt.dynamo.partitioning._resource_partitioner import (
resource_partition,
)
from torch_tensorrt.dynamo.utils import (
deallocate_module,
get_flat_args_with_check,
Expand Down Expand Up @@ -104,6 +107,7 @@ def cross_compile_for_windows(
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
Expand Down Expand Up @@ -178,6 +182,7 @@ def cross_compile_for_windows(
tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
cpu_memory_budget (int): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. If set to -1, the compilation will use all available CPU memory.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
Expand Down Expand Up @@ -333,6 +338,7 @@ def cross_compile_for_windows(
"tiling_optimization_level": tiling_optimization_level,
"l2_limit_for_tiling": l2_limit_for_tiling,
"use_distributed_mode_trace": use_distributed_mode_trace,
"cpu_memory_budget": cpu_memory_budget,
}

# disable the following settings is not supported for cross compilation for windows feature
Expand Down Expand Up @@ -434,6 +440,7 @@ def compile(
l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
cpu_memory_budget: int = _defaults.CPU_MEMORY_BUDGET,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
Expand Down Expand Up @@ -614,6 +621,10 @@ def compile(
"'arg_inputs' and 'inputs' should not be used at the same time."
)

assert (
cpu_memory_budget >= 2 * 1024 * 1024 * 1024
), "CPU memory budget must be greater than 10GB"

arg_inputs = inputs or arg_inputs

if kwarg_inputs is None:
Expand Down Expand Up @@ -680,8 +691,8 @@ def compile(
"l2_limit_for_tiling": l2_limit_for_tiling,
"offload_module_to_cpu": offload_module_to_cpu,
"use_distributed_mode_trace": use_distributed_mode_trace,
"cpu_memory_budget": cpu_memory_budget,
}

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
exported_program = pre_export_lowering(exported_program, settings)
Expand Down Expand Up @@ -850,6 +861,12 @@ def preserve_module_specs(
require_full_compilation=settings.require_full_compilation,
)

partitioned_module = resource_partition(
gm,
partitioned_module,
cpu_memory_budget=settings.cpu_memory_budget,
)

dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators

# The global partitioner leaves non-TRT nodes as-is
Expand Down Expand Up @@ -1243,7 +1260,7 @@ def convert_exported_program_to_serialized_trt_engine(

# Prepare torch_trt inputs
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)
device = to_torch_tensorrt_device(device)
enabled_precisions = {dtype._from(p) for p in enabled_precisions}

Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import platform
import tempfile

import psutil
import torch
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
Expand Down Expand Up @@ -57,6 +58,7 @@
L2_LIMIT_FOR_TILING = -1
USE_DISTRIBUTED_MODE_TRACE = False
OFFLOAD_MODULE_TO_CPU = False
CPU_MEMORY_BUDGET = psutil.virtual_memory().available

if platform.system() == "Linux":
import pwd
Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
CACHE_BUILT_ENGINES,
CPU_MEMORY_BUDGET,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
Expand Down Expand Up @@ -140,6 +141,7 @@ class CompilationSettings:
l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
cpu_memory_budget: int = CPU_MEMORY_BUDGET
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider adding this arg in the _SETTINGS_TO_BE_ENGINE_INVARIANT below.


def __getstate__(self) -> dict[str, Any]:
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
Expand Down
Loading