diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 83d8cdae1ed3..4bd0124ed3aa 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -33,6 +33,7 @@
     EagerAdaptor,
     InductorAdaptor,
     InductorStandaloneAdaptor,
+    is_compile_cache_enabled,
 )
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
@@ -240,7 +241,7 @@ def compile(
         assert compiled_graph is not None, "Failed to compile the graph"
 
         # store the artifact in the cache
-        if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
+        if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
             self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
             compilation_counter.num_cache_entries_updated += 1
             self.is_cache_updated = True
@@ -612,7 +613,9 @@ def __call__(
                 os.makedirs(local_cache_dir, exist_ok=True)
                 self.compilation_config.local_cache_dir = local_cache_dir
 
-        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
+        disable_cache = not is_compile_cache_enabled(
+            self.compilation_config.inductor_compile_config
+        )
 
         if disable_cache:
             logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index d15481b3045d..b0cdb08884a3 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -163,6 +163,23 @@ def get_inductor_factors() -> list[Any]:
     return factors
 
 
+def is_compile_cache_enabled(
+    vllm_additional_inductor_config: dict[str, Any],
+) -> bool:
+    vllm_inductor_config_disable_cache = vllm_additional_inductor_config.get(
+        "force_disable_caches", False
+    )
+
+    # TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
+    # with torch.compiler.config.force_disable_caches when minimum PyTorch
+    # version reaches 2.10
+    return (
+        not envs.VLLM_DISABLE_COMPILE_CACHE
+        and not torch._inductor.config.force_disable_caches
+        and not vllm_inductor_config_disable_cache
+    )
+
+
 class InductorStandaloneAdaptor(CompilerInterface):
     """
     The adaptor for the Inductor compiler.
@@ -222,7 +239,8 @@ def compile(
         # Save the compiled artifact to disk in the specified path
         assert key is not None
         path = os.path.join(self.cache_dir, key)
-        if not envs.VLLM_DISABLE_COMPILE_CACHE:
+
+        if is_compile_cache_enabled(compiler_config):
             compiled_graph.save(path=path, format=self.save_format)
             compilation_counter.num_compiled_artifacts_saved += 1
         return compiled_graph, (key, path)
@@ -472,10 +490,8 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
             config_patches=current_config,
         )
 
-        # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch
-        # compilation cache. So turn off the checks if we disable the
-        # compilation cache.
-        if not envs.VLLM_DISABLE_COMPILE_CACHE:
+        # Turn off the checks if we disable the compilation cache.
+        if is_compile_cache_enabled(compiler_config):
             if hash_str is None:
                 raise RuntimeError(
                     "vLLM failed to compile the model. The most "
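
Note for reviewers: the new is_compile_cache_enabled helper turns the compile cache into an AND of three independent switches: the VLLM_DISABLE_COMPILE_CACHE environment variable, the global torch._inductor.config.force_disable_caches flag, and a per-model "force_disable_caches" key in vLLM's additional inductor config. Below is a minimal standalone sketch of that gating logic; plain booleans stand in for vllm.envs and the torch config so the snippet runs without vLLM installed.

    # Illustrative sketch of the gating introduced by is_compile_cache_enabled.
    # The switch names mirror the patch; the module-level booleans are
    # stand-ins for envs.VLLM_DISABLE_COMPILE_CACHE and
    # torch._inductor.config.force_disable_caches.
    from typing import Any

    VLLM_DISABLE_COMPILE_CACHE = False  # stand-in for the env-var switch
    TORCH_FORCE_DISABLE_CACHES = False  # stand-in for the global Inductor flag

    def is_compile_cache_enabled(
        vllm_additional_inductor_config: dict[str, Any],
    ) -> bool:
        # The cache stays enabled only if none of the three switches is set:
        # 1. VLLM_DISABLE_COMPILE_CACHE (environment variable),
        # 2. torch._inductor.config.force_disable_caches (global config),
        # 3. "force_disable_caches" in vLLM's per-model inductor config.
        return (
            not VLLM_DISABLE_COMPILE_CACHE
            and not TORCH_FORCE_DISABLE_CACHES
            and not vllm_additional_inductor_config.get("force_disable_caches", False)
        )

    assert is_compile_cache_enabled({}) is True
    assert is_compile_cache_enabled({"force_disable_caches": True}) is False

Any one of the three switches being set is enough to disable both the artifact saving and the cache-consistency checks that this patch touches.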