
Commit a4a4f0f

[KV Connector] Update lmcache connector with latest compatibility (#27681)

Authored by Shaoting-Feng and Samuel Shen
Signed-off-by: Samuel Shen <[email protected]>
Co-authored-by: Samuel Shen <[email protected]>

1 parent 0d8161b commit a4a4f0f

File tree: 1 file changed (+20 −2 lines)

vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py

Lines changed: 20 additions & 2 deletions
@@ -44,8 +44,8 @@
 )
 from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group
 from vllm.sampling_params import SamplingParams
-from vllm.utils import get_kv_cache_torch_dtype
 from vllm.utils.math_utils import cdiv
+from vllm.utils.torch_utils import get_kv_cache_torch_dtype
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.version import __version__ as VLLM_VERSION
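
This hunk tracks vLLM's relocation of get_kv_cache_torch_dtype from vllm.utils to vllm.utils.torch_utils. For an out-of-tree integration that must run against both old and new vLLM releases, one could imagine a fallback import; the shim below is a sketch and is not part of this commit.

# Hypothetical compatibility shim, not part of this commit: prefer the
# new module path and fall back to the old one on older vLLM releases.
try:
    from vllm.utils.torch_utils import get_kv_cache_torch_dtype
except ImportError:
    from vllm.utils import get_kv_cache_torch_dtype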

@@ -389,7 +389,7 @@ def from_request_tracker(


 def need_gpu_interm_buffer(lmcache_config: LMCacheEngineConfig):
-    return lmcache_config.enable_pd
+    return not lmcache_config.enable_pd


 def _calculate_mtp_layers(vllm_config, model_config):
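
The second hunk inverts the predicate: reading the new code, the intermediate GPU buffer is needed exactly when enable_pd is off. A self-contained sketch of the corrected behavior follows; the stub config class, and the reading of enable_pd as a prefill/decode-disaggregation switch, are assumptions for illustration.

# Self-contained sketch of the corrected predicate; StubConfig and the
# interpretation of `enable_pd` are assumptions, not code from the commit.
class StubConfig:
    def __init__(self, enable_pd: bool) -> None:
        self.enable_pd = enable_pd

def need_gpu_interm_buffer(lmcache_config: StubConfig) -> bool:
    # The buffer is only required when PD mode is disabled; the old
    # code returned the flag itself, inverting the intent.
    return not lmcache_config.enable_pd

assert need_gpu_interm_buffer(StubConfig(enable_pd=False))
assert not need_gpu_interm_buffer(StubConfig(enable_pd=True))
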
@@ -403,6 +403,20 @@ def _calculate_mtp_layers(vllm_config, model_config):
         num_mtp_layers = getattr(
             model_config.hf_config, "num_nextn_predict_layers", 0
         )
+
+    elif vllm_config.speculative_config.use_eagle():
+        try:
+            draft_model_config = vllm_config.speculative_config.draft_model_config
+            num_mtp_layers = draft_model_config.get_num_layers(
+                vllm_config.parallel_config
+            )
+            logger.info("EAGLE detected %d extra layer(s)", num_mtp_layers)
+        except Exception:
+            logger.info(
+                "EAGLE detected, but failed to get the number of extra layers, "
+                "falling back to 1"
+            )
+            num_mtp_layers = 1
     return num_mtp_layers


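
The new EAGLE branch asks the draft model for its layer count and conservatively falls back to one extra layer if that lookup fails. A standalone sketch of the same try/except-with-fallback pattern is below; the stand-in objects are assumptions, since the real ones come from vllm_config.speculative_config.

import logging

logger = logging.getLogger(__name__)

# Stand-in for the draft-model config; the real object is
# vllm_config.speculative_config.draft_model_config.
class StubDraftConfig:
    def get_num_layers(self, parallel_config) -> int:
        return 1

def count_extra_layers(draft_model_config, parallel_config) -> int:
    try:
        num_layers = draft_model_config.get_num_layers(parallel_config)
        logger.info("EAGLE detected %d extra layer(s)", num_layers)
        return num_layers
    except Exception:
        # Mirror the commit's conservative fallback of one extra layer.
        logger.info(
            "EAGLE detected, but failed to get the number of extra layers, "
            "falling back to 1"
        )
        return 1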

@@ -1208,6 +1222,10 @@ def update_state_after_alloc(self, request: "Request", num_external_tokens: int)
         if the CacheManager this allocated blocks for us.
         """

+        # Clear local status in lookup client when a new request is
+        # successfully scheduled.
+        self.lookup_client.clear_lookup_status(request.request_id)
+
         kv_transfer_params = (
             request.kv_transfer_params
             if hasattr(request, "kv_transfer_params")
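
The final hunk clears per-request lookup state as soon as the scheduler commits to a request, so stale lookup bookkeeping cannot leak into later handling of the same request id. A dict-backed illustration of what such a client might track is sketched below; only the clear_lookup_status name appears in the diff, and the rest is an assumption.

# Hypothetical dict-backed lookup client; only clear_lookup_status is
# taken from the diff, the internal store is illustrative.
class StubLookupClient:
    def __init__(self) -> None:
        self._hit_tokens: dict[str, int] = {}

    def record_lookup(self, request_id: str, hit_tokens: int) -> None:
        # Cache the lookup result while the request awaits scheduling.
        self._hit_tokens[request_id] = hit_tokens

    def clear_lookup_status(self, request_id: str) -> None:
        # Drop the entry once the request is scheduled so later
        # bookkeeping never observes a stale hit count.
        self._hit_tokens.pop(request_id, None)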
