
Commit e675118

ApostaC and hmellor authored
[Add] cmdline argument parsing for KV cache offloading modules (#27621)
Signed-off-by: ApostaC <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
1 parent e2347db commit e675118

4 files changed: +142 -1 lines changed

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Tests for KV cache offloading configuration."""
+
+import pytest
+
+from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig
+
+pytestmark = pytest.mark.cpu_test
+
+
+@pytest.mark.parametrize(
+    "kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
+    [
+        ("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
+        # bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
+        ("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30) / 4),
+        ("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
+        # size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
+        ("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
+        (None, None, 1, 1, None, None),
+    ],
+)
+def test_kv_connector(
+    kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes
+):
+    kv_transfer_config = (
+        KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"})
+        if expected_backend is not None
+        else None
+    )
+
+    vllm_config = VllmConfig(
+        cache_config=CacheConfig(
+            kv_offloading_backend=kv_offloading_backend,
+            kv_offloading_size=kv_offloading_size,
+        ),
+        kv_transfer_config=kv_transfer_config,
+        parallel_config=ParallelConfig(
+            tensor_parallel_size=tp, pipeline_parallel_size=pp
+        ),
+    )
+
+    # No KV transfer config expected
+    if expected_backend is None:
+        assert vllm_config.kv_transfer_config is expected_backend
+        return
+
+    kv_transfer_config = vllm_config.kv_transfer_config
+    kv_connector_extra_config = kv_transfer_config.kv_connector_extra_config
+
+    assert kv_transfer_config.kv_connector == expected_backend
+    assert kv_transfer_config.kv_role == "kv_both"
+
+    if kv_offloading_backend == "native":
+        assert kv_connector_extra_config["kv_bytes_per_rank"] == expected_bytes
+        assert kv_connector_extra_config["num_cpu_blocks"] == 0
+        # Existing config should be preserved
+        assert kv_connector_extra_config["existing_key"] == "existing_value"
+    elif kv_offloading_backend == "lmcache":
+        assert kv_connector_extra_config["lmcache.local_cpu"] is True
+        assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes
+        # Existing config should be replaced
+        assert "existing_key" not in kv_connector_extra_config

vllm/config/cache.py

Lines changed: 12 additions & 0 deletions

@@ -24,6 +24,7 @@
 CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
 MambaDType = Literal["auto", "float32"]
 PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
+KVOffloadingBackend = Literal["native", "lmcache"]


 @config
@@ -128,6 +129,17 @@ class CacheConfig:
     gpu_memory_utilization. Note that kv_cache_memory_bytes
     (when not-None) ignores gpu_memory_utilization"""

+    kv_offloading_size: float | None = None
+    """Size of the KV cache offloading buffer in GiB. When TP > 1, this is
+    the total buffer size summed across all TP ranks. By default, this is set
+    to None, which means no KV offloading is enabled. When set together with
+    kv_offloading_backend, vLLM will enable KV cache offloading to CPU."""
+
+    kv_offloading_backend: KVOffloadingBackend | None = None
+    """The backend to use for KV cache offloading. Supported backends are
+    'native' (vLLM native CPU offloading) and 'lmcache'. This option must be
+    used together with kv_offloading_size."""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
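The two new fields are an opt-in pair: both must be set for offloading to take effect. A minimal sketch of setting them directly (the identifiers are from this diff; the 16 GiB figure is an arbitrary example):

    from vllm.config import CacheConfig

    # 16 GiB of CPU offloading budget, summed across all TP ranks. Setting
    # kv_offloading_backend without kv_offloading_size raises a ValueError
    # once VllmConfig.__post_init__ runs (see vllm/config/vllm.py below).
    cache_config = CacheConfig(
        kv_offloading_size=16.0,
        kv_offloading_backend="native",
    )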

vllm/config/vllm.py

Lines changed: 45 additions & 0 deletions

@@ -289,6 +289,48 @@ def with_hf_config(

         return replace(self, model_config=model_config)

+    def _post_init_kv_transfer_config(self) -> None:
+        """Update KVTransferConfig based on top-level configs in VllmConfig.
+
+        Right now, this function reads the offloading settings from
+        CacheConfig and configures the KVTransferConfig accordingly.
+        """
+        if (kv_offloading_backend := self.cache_config.kv_offloading_backend) is None:
+            return
+
+        # If no KVTransferConfig is provided, create a default one.
+        if self.kv_transfer_config is None:
+            self.kv_transfer_config = KVTransferConfig()
+
+        if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
+            raise ValueError(
+                "You must set kv_offloading_size when kv_offloading_backend is set."
+            )
+        num_kv_ranks = (
+            self.parallel_config.tensor_parallel_size
+            * self.parallel_config.pipeline_parallel_size
+        )
+
+        if kv_offloading_backend == "native":
+            self.kv_transfer_config.kv_connector = "OffloadingConnector"
+            kv_bytes_per_rank = kv_offloading_size * (1 << 30) / num_kv_ranks
+
+            # NOTE(ApostaC): the actual calculation for num_cpu_blocks should be
+            # done after the model's KV cache is initialized
+            self.kv_transfer_config.kv_connector_extra_config.update(
+                {"kv_bytes_per_rank": kv_bytes_per_rank, "num_cpu_blocks": 0}
+            )
+        elif kv_offloading_backend == "lmcache":
+            self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
+            kv_gb_per_rank = kv_offloading_size / num_kv_ranks
+            self.kv_transfer_config.kv_connector_extra_config = {
+                "lmcache.local_cpu": True,
+                "lmcache.max_local_cpu_size": kv_gb_per_rank,
+            }
+
+        # This is the same for all backends
+        self.kv_transfer_config.kv_role = "kv_both"
+
     def __post_init__(self):
         """Verify configs are valid & consistent with each other."""

@@ -646,6 +688,9 @@ def has_blocked_weights():
             if "-quant_fp8" not in custom_ops:
                 custom_ops.append("+quant_fp8")

+        # Handle the KV connector configs
+        self._post_init_kv_transfer_config()
+
     def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
         # remove the sizes that not multiple of tp_size when
         # enable sequence parallelism
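The per-rank split is the heart of this helper: the native backend expects bytes per rank, while lmcache expects GiB per rank. A worked example of the arithmetic (a sketch, not code from the commit):

    # 8 GiB total offloading budget with TP=2 and PP=2 -> 4 KV ranks.
    kv_offloading_size = 8.0  # GiB, summed across all ranks
    num_kv_ranks = 2 * 2      # tensor_parallel_size * pipeline_parallel_size

    kv_bytes_per_rank = kv_offloading_size * (1 << 30) / num_kv_ranks
    assert kv_bytes_per_rank == 2 * (1 << 30)  # native: 2 GiB per rank, in bytes

    kv_gb_per_rank = kv_offloading_size / num_kv_ranks
    assert kv_gb_per_rank == 2.0  # lmcache: 2 GiB per rank, kept in GiB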

vllm/engine/arg_utils.py

Lines changed: 20 additions & 1 deletion

@@ -54,7 +54,13 @@
     VllmConfig,
     get_attr_docs,
 )
-from vllm.config.cache import BlockSize, CacheDType, MambaDType, PrefixCachingHashAlgo
+from vllm.config.cache import (
+    BlockSize,
+    CacheDType,
+    KVOffloadingBackend,
+    MambaDType,
+    PrefixCachingHashAlgo,
+)
 from vllm.config.device import Device
 from vllm.config.model import (
     ConvertOption,
@@ -553,6 +559,11 @@ class EngineArgs:

     kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill

+    kv_offloading_size: float | None = CacheConfig.kv_offloading_size
+    kv_offloading_backend: KVOffloadingBackend | None = (
+        CacheConfig.kv_offloading_backend
+    )
+
     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
         # without having to manually construct a
@@ -896,6 +907,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         cache_group.add_argument(
             "--mamba-block-size", **cache_kwargs["mamba_block_size"]
         )
+        cache_group.add_argument(
+            "--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
+        )
+        cache_group.add_argument(
+            "--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
+        )

         # Multimodal related configs
         multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1387,6 +1404,8 @@ def create_engine_config(
             mamba_cache_dtype=self.mamba_cache_dtype,
             mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
             mamba_block_size=self.mamba_block_size,
+            kv_offloading_size=self.kv_offloading_size,
+            kv_offloading_backend=self.kv_offloading_backend,
         )

         ray_runtime_env = None
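End to end, the new flags flow from the CLI through EngineArgs into CacheConfig, and then into the derived KVTransferConfig. A hedged usage sketch (the model name is a placeholder, and create_engine_config is assumed to be callable with default arguments):

    from vllm.engine.arg_utils import EngineArgs

    # Equivalent to the new CLI flags:
    #   --kv-offloading-size 8 --kv-offloading-backend lmcache
    args = EngineArgs(
        model="facebook/opt-125m",  # placeholder model
        kv_offloading_size=8.0,
        kv_offloading_backend="lmcache",
    )
    vllm_config = args.create_engine_config()
    assert vllm_config.kv_transfer_config.kv_connector == "LMCacheConnectorV1"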
