Skip to content

Commit f485e35

Browse files
wuhang2014 and robertgshaw2-redhat
authored and committed
[Spec Decoding]Support Spec Decoding Metrics in DP Mode (vllm-project#24049)
Signed-off-by: wuhang <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
1 parent a7a7c96 commit f485e35

File tree

2 files changed

+54
-38
lines changed

2 files changed

+54
-38
lines changed

vllm/v1/metrics/loggers.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -169,15 +169,11 @@ def __init__(self,
169169
model_name = vllm_config.model_config.served_model_name
170170
max_model_len = vllm_config.model_config.max_model_len
171171

172-
if (len(self.engine_indexes) > 1
173-
and vllm_config.speculative_config is not None):
174-
raise NotImplementedError("Prometheus metrics with Spec Decoding "
175-
"with >1 EngineCore per AsyncLLM is not "
176-
"supported yet.")
177-
spec_decode_labelvalues = [
178-
vllm_config.model_config.served_model_name,
179-
str(self.engine_indexes[0])
180-
]
172+
spec_decode_labelvalues: dict[int, list[str]] = {
173+
idx: [model_name, str(idx)]
174+
for idx in engine_indexes
175+
}
176+
181177
self.spec_decoding_prom = self._spec_decoding_cls(
182178
vllm_config.speculative_config, labelnames,
183179
spec_decode_labelvalues)
@@ -530,7 +526,7 @@ def record(self,
530526

531527
if scheduler_stats.spec_decoding_stats is not None:
532528
self.spec_decoding_prom.observe(
533-
scheduler_stats.spec_decoding_stats)
529+
scheduler_stats.spec_decoding_stats, engine_idx)
534530

535531
if iteration_stats is None:
536532
return

vllm/v1/spec_decode/metrics.py

Lines changed: 48 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -140,27 +140,32 @@ def __init__(
140140
self,
141141
speculative_config: Optional[SpeculativeConfig],
142142
labelnames: list[str],
143-
labelvalues: list[str],
143+
per_engine_labelvalues: dict[int, list[str]],
144144
):
145145
self.spec_decoding_enabled = speculative_config is not None
146146
if not self.spec_decoding_enabled:
147147
return
148148

149-
self.counter_spec_decode_num_drafts = \
150-
self._counter_cls(
151-
name="vllm:spec_decode_num_drafts",
152-
documentation="Number of spec decoding drafts.",
153-
labelnames=labelnames).labels(*labelvalues)
154-
self.counter_spec_decode_num_draft_tokens = \
155-
self._counter_cls(
156-
name="vllm:spec_decode_num_draft_tokens",
157-
documentation="Number of draft tokens.",
158-
labelnames=labelnames,).labels(*labelvalues)
159-
self.counter_spec_decode_num_accepted_tokens = \
160-
self._counter_cls(
161-
name="vllm:spec_decode_num_accepted_tokens",
162-
documentation="Number of accepted tokens.",
163-
labelnames=labelnames).labels(*labelvalues)
149+
counter_drafts = self._counter_cls(
150+
name="vllm:spec_decode_num_drafts",
151+
documentation="Number of spec decoding drafts.",
152+
labelnames=labelnames)
153+
self.counter_spec_decode_num_drafts = make_per_engine(
154+
counter_drafts, per_engine_labelvalues)
155+
156+
counter_draft_tokens = self._counter_cls(
157+
name="vllm:spec_decode_num_draft_tokens",
158+
documentation="Number of draft tokens.",
159+
labelnames=labelnames)
160+
self.counter_spec_decode_num_draft_tokens = make_per_engine(
161+
counter_draft_tokens, per_engine_labelvalues)
162+
163+
counter_accepted_tokens = self._counter_cls(
164+
name="vllm:spec_decode_num_accepted_tokens",
165+
documentation="Number of accepted tokens.",
166+
labelnames=labelnames)
167+
self.counter_spec_decode_num_accepted_tokens = make_per_engine(
168+
counter_accepted_tokens, per_engine_labelvalues)
164169

165170
assert speculative_config is not None
166171
num_spec_tokens = (speculative_config.num_speculative_tokens
@@ -171,21 +176,36 @@ def __init__(
171176
documentation="Accepted tokens per draft position.",
172177
labelnames=pos_labelnames,
173178
)
174-
self.counter_spec_decode_num_accepted_tokens_per_pos: list[
175-
prometheus_client.Counter] = []
176-
for pos in range(num_spec_tokens):
177-
pos_labelvalues = labelvalues + [str(pos)]
178-
self.counter_spec_decode_num_accepted_tokens_per_pos.append(
179-
base_counter.labels(*pos_labelvalues))
180-
181-
def observe(self, spec_decoding_stats: SpecDecodingStats):
179+
self.counter_spec_decode_num_accepted_tokens_per_pos: dict[
180+
int, list[prometheus_client.Counter]] = {
181+
idx: [
182+
base_counter.labels(*lv, str(pos))
183+
for pos in range(num_spec_tokens)
184+
]
185+
for idx, lv in per_engine_labelvalues.items()
186+
}
187+
188+
def observe(self,
189+
spec_decoding_stats: SpecDecodingStats,
190+
engine_idx: int = 0):
182191
if not self.spec_decoding_enabled:
183192
return
184-
self.counter_spec_decode_num_drafts.inc(spec_decoding_stats.num_drafts)
185-
self.counter_spec_decode_num_draft_tokens.inc(
193+
self.counter_spec_decode_num_drafts[engine_idx].inc(
194+
spec_decoding_stats.num_drafts)
195+
self.counter_spec_decode_num_draft_tokens[engine_idx].inc(
186196
spec_decoding_stats.num_draft_tokens)
187-
self.counter_spec_decode_num_accepted_tokens.inc(
197+
self.counter_spec_decode_num_accepted_tokens[engine_idx].inc(
188198
spec_decoding_stats.num_accepted_tokens)
189199
for pos, counter in enumerate(
190-
self.counter_spec_decode_num_accepted_tokens_per_pos):
200+
self.
201+
counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]):
191202
counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos])
203+
204+
205+
def make_per_engine(counter: prometheus_client.Counter,
206+
per_engine_labelvalues: dict[int, list[str]]):
207+
"""Create a counter for each label value."""
208+
return {
209+
idx: counter.labels(*labelvalues)
210+
for idx, labelvalues in per_engine_labelvalues.items()
211+
}

0 commit comments

Comments
 (0)