
Commit 002b07c

[Bugfix] vLLM should check Inductor config for compile cache enablement status (#27637)
Signed-off-by: Yanan Cao <[email protected]>
1 parent 752ddea commit 002b07c

File tree

vllm/compilation/backends.py
vllm/compilation/compiler_interface.py

2 files changed: +26 −7 lines

vllm/compilation/backends.py

Lines changed: 5 additions & 2 deletions
@@ -33,6 +33,7 @@
     EagerAdaptor,
     InductorAdaptor,
     InductorStandaloneAdaptor,
+    is_compile_cache_enabled,
 )
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
@@ -239,7 +240,7 @@ def compile(
         assert compiled_graph is not None, "Failed to compile the graph"

         # store the artifact in the cache
-        if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
+        if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
             self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
             compilation_counter.num_cache_entries_updated += 1
             self.is_cache_updated = True
@@ -611,7 +612,9 @@ def __call__(
             os.makedirs(local_cache_dir, exist_ok=True)
             self.compilation_config.local_cache_dir = local_cache_dir

-        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
+        disable_cache = not is_compile_cache_enabled(
+            self.compilation_config.inductor_compile_config
+        )

         if disable_cache:
             logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
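
With both call sites above routed through is_compile_cache_enabled, a "force_disable_caches" entry in the compilation config's inductor_compile_config dict (the field read in the __call__ hunk) now turns vLLM's compile cache off as well. A hedged usage sketch follows; the CompilationConfig import path and constructor call are assumptions for illustration, not part of this commit:

    # Hypothetical usage sketch; assumes vllm.config.CompilationConfig accepts
    # the inductor_compile_config dict referenced in the diff above.
    from vllm.config import CompilationConfig

    # Before this fix, only VLLM_DISABLE_COMPILE_CACHE=1 disabled vLLM's own
    # compile cache; now the Inductor-level switch is honored too.
    compilation_config = CompilationConfig(
        inductor_compile_config={"force_disable_caches": True},
    )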

vllm/compilation/compiler_interface.py

Lines changed: 21 additions & 5 deletions
@@ -163,6 +163,23 @@ def get_inductor_factors() -> list[Any]:
     return factors


+def is_compile_cache_enabled(
+    vllm_additional_inductor_config: dict[str, Any],
+) -> bool:
+    vllm_inductor_config_disable_cache = vllm_additional_inductor_config.get(
+        "force_disable_caches", False
+    )
+
+    # TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
+    # with torch.compiler.config.force_disable_caches when minimum PyTorch
+    # version reaches 2.10
+    return (
+        not envs.VLLM_DISABLE_COMPILE_CACHE
+        and not torch._inductor.config.force_disable_caches
+        and not vllm_inductor_config_disable_cache
+    )
+
+
 class InductorStandaloneAdaptor(CompilerInterface):
     """
     The adaptor for the Inductor compiler.
@@ -222,7 +239,8 @@ def compile(
         # Save the compiled artifact to disk in the specified path
         assert key is not None
         path = os.path.join(self.cache_dir, key)
-        if not envs.VLLM_DISABLE_COMPILE_CACHE:
+
+        if is_compile_cache_enabled(compiler_config):
             compiled_graph.save(path=path, format=self.save_format)
             compilation_counter.num_compiled_artifacts_saved += 1
         return compiled_graph, (key, path)
@@ -472,10 +490,8 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
             config_patches=current_config,
         )

-        # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
-        # compilation cache. So turn off the checks if we disable the
-        # compilation cache.
-        if not envs.VLLM_DISABLE_COMPILE_CACHE:
+        # Turn off the checks if we disable the compilation cache.
+        if is_compile_cache_enabled(compiler_config):
             if hash_str is None:
                 raise RuntimeError(
                     "vLLM failed to compile the model. The most "
