30 changes: 26 additions & 4 deletions vllm_ascend/distributed/mooncake_layerwise_connector.py
@@ -8,11 +8,12 @@
import struct
import threading
import time
-from collections import defaultdict, deque
+from collections import OrderedDict, defaultdict, deque
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from uuid import uuid4

import httpx
import msgspec
@@ -67,6 +68,27 @@
metaserver: Optional[str]


@dataclass
class SizedDict(OrderedDict):
Comment on lines +70 to +71 (Contributor, severity: high)

The use of @dataclass on the SizedDict class is unnecessary and potentially misleading. This class defines a custom __init__ method and inherits from OrderedDict, which is not a dataclass. The @dataclass decorator is designed for classes that primarily store data and can auto-generate methods like __init__ and __repr__. In this case, it provides no benefit and could cause confusion or unexpected behavior during future maintenance. It's better to define it as a regular class for clarity and correctness.

Suggested change:
-@dataclass
-class SizedDict(OrderedDict):
+class SizedDict(OrderedDict):
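The reviewer's concern can be made concrete with a standalone sketch (not the PR's code verbatim). With no annotated dataclass fields, `@dataclass` leaves the hand-written `__init__` alone, but it still injects `__eq__` and `__repr__` that compare and print zero fields, silently shadowing the dict's own content-based equality:

```python
from collections import OrderedDict
from dataclasses import dataclass


# Standalone sketch: @dataclass on a class with a hand-written __init__
# and no fields still generates __eq__/__repr__ over an empty field set.
@dataclass
class BadSizedDict(OrderedDict):

    def __init__(self, max_size=2, *args, **kwargs):
        self.max_size = max_size
        super().__init__(*args, **kwargs)


a = BadSizedDict()
a["k"] = 1
b = BadSizedDict()

print(a == b)              # True: generated __eq__ compares zero fields
print(dict(a) == dict(b))  # False: the contents actually differ
print(repr(a) == repr(b))  # True: generated __repr__ shows no contents
```

This is why dropping the decorator, as suggested, is the safer fix: the class regains `OrderedDict`'s content-based `__eq__` and `__repr__`.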


def __init__(self, max_size=2, *args, **kwargs):
self.max_size = max_size
super().__init__(*args, **kwargs)

def __setitem__(self, key, value):
super().__setitem__(key, value)
if len(self) > self.max_size:
self.popitem(last=False)

def __getitem__(self, key):
try:
return super().__getitem__(key)
except KeyError:
value = {}

Check failure on line 87 in vllm_ascend/distributed/mooncake_layerwise_connector.py (GitHub Actions / lint / pre-commit):
Need type annotation for "value" (hint: "value: dict[<type>, <type>] = ...") [var-annotated]
self[key] = value
return value
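The semantics of the class above can be illustrated with a standalone re-typing (adding the type annotation the lint check asks for): inserting beyond `max_size` evicts the oldest-inserted entry, and a missing key is materialized as an empty dict, defaultdict-style.

```python
from collections import OrderedDict


# Standalone copy of the SizedDict from the diff, for illustration only.
class SizedDict(OrderedDict):

    def __init__(self, max_size=2, *args, **kwargs):
        self.max_size = max_size
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        if len(self) > self.max_size:
            self.popitem(last=False)  # evict the oldest-inserted entry

    def __getitem__(self, key):
        try:
            return super().__getitem__(key)
        except KeyError:
            value: dict = {}  # annotation satisfies the mypy check above
            self[key] = value
            return value


d = SizedDict(max_size=2)
d["a"], d["b"], d["c"] = 1, 2, 3  # inserting "c" evicts "a"
print(list(d))                    # ['b', 'c']
print(d["missing"])               # {} — and this insert evicts "b"
```

Note that eviction is insertion-ordered, not LRU: reads never call `move_to_end`, so a frequently read key can still be the next one evicted.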


class KVCacheSendingLayerThread(threading.Thread):

def __init__(self,
@@ -365,7 +387,7 @@
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
assert vllm_config.kv_transfer_config is not None
-self.engine_id = vllm_config.kv_transfer_config.engine_id
+self.engine_id = f"{vllm_config.kv_transfer_config.engine_id}-{uuid4()}"
self._connector_metadata = MooncakeLayerwiseConnectorMetadata()

if role == KVConnectorRole.SCHEDULER:
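The intent of the `engine_id` change above can be sketched in isolation (the configured id below is illustrative): appending a `uuid4()` suffix makes each connector instance's id unique even when several instances are started from the same configuration.

```python
from uuid import uuid4

# Illustrative sketch of the engine_id change: a shared configured id
# gains a per-instance uuid4 suffix, so two connectors never collide.
configured_id = "engine-0"  # hypothetical configured engine_id
instance_a = f"{configured_id}-{uuid4()}"
instance_b = f"{configured_id}-{uuid4()}"
print(instance_a != instance_b)  # True: uuid4 values are distinct
```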
@@ -695,9 +717,9 @@
self.encoder = msgspec.msgpack.Encoder()

self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
-    defaultdict(dict)
+    SizedDict()
self.remote_te_port: dict[str, dict[int, int]] = \
-    defaultdict(dict)
+    SizedDict()
Comment on lines 697 to +721 (Contributor, severity: high)

The new SizedDict is initialized with its default max_size of 2. This means it will only cache metadata for the two most recently used remote engines. In a deployment with more than two peer engines, this could lead to cache thrashing, where metadata is frequently evicted and then re-fetched over the network, potentially impacting performance.

It would be more robust to make this cache size configurable. For example, you could add a configuration option and pass it to the SizedDict constructor:

# In MooncakeLayerwiseConnectorWorker.__init__
max_cached_engines = self.vllm_config.kv_transfer_config.get_from_extra_config(
    'max_cached_engines', 128)  # A more reasonable default

# ...

self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
    SizedDict(max_size=max_cached_engines)
self.remote_te_port: dict[str, dict[int, int]] = \
    SizedDict(max_size=max_cached_engines)

This would allow operators to tune the cache size based on their specific deployment topology. For now, I'm suggesting a larger default.

Suggested change:
-self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
-    SizedDict()
-self.remote_te_port: dict[str, dict[int, int]] = \
-    SizedDict()
+self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
+    SizedDict(max_size=128)
+self.remote_te_port: dict[str, dict[int, int]] = \
+    SizedDict(max_size=128)
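The thrashing scenario the reviewer describes can be simulated with a standalone sketch (the eviction logic is re-typed from the diff; engine names and the fetch counter are illustrative): with `max_size=2` and three peer engines polled round-robin, every lookup misses and triggers a simulated re-fetch, while a larger cache fetches each engine's metadata once.

```python
from collections import OrderedDict


class SizedDict(OrderedDict):  # eviction logic re-typed from the diff

    def __init__(self, max_size=2, *args, **kwargs):
        self.max_size = max_size
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        if len(self) > self.max_size:
            self.popitem(last=False)  # evict the oldest-inserted entry


def count_fetches(max_size: int, engines: list[str], rounds: int = 3) -> int:
    """Count simulated network re-fetches over round-robin lookups."""
    cache = SizedDict(max_size=max_size)
    fetches = 0
    for _ in range(rounds):
        for e in engines:
            if e not in cache:
                fetches += 1  # simulated metadata re-fetch
                cache[e] = {"addr": e}
    return fetches


peers = ["prefill-0", "prefill-1", "prefill-2"]
print(count_fetches(2, peers))    # 9: every lookup misses (thrashing)
print(count_fetches(128, peers))  # 3: one fetch per engine
```

With the default `max_size=2`, the entry evicted on each insert is exactly the one needed next, so all nine lookups miss; raising the bound to cover the deployment's peer count eliminates the re-fetches.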

self.remote_sockets_lock = threading.Lock()
self.remote_sockets: dict[ # type: ignore
str, deque[zmq.Socket]] = defaultdict( # type: ignore