@@ -140,27 +140,32 @@ def __init__(
140140 self ,
141141 speculative_config : Optional [SpeculativeConfig ],
142142 labelnames : list [str ],
143- labelvalues : list [str ],
143+ per_engine_labelvalues : dict [ int , list [str ] ],
144144 ):
145145 self .spec_decoding_enabled = speculative_config is not None
146146 if not self .spec_decoding_enabled :
147147 return
148148
149- self .counter_spec_decode_num_drafts = \
150- self ._counter_cls (
151- name = "vllm:spec_decode_num_drafts" ,
152- documentation = "Number of spec decoding drafts." ,
153- labelnames = labelnames ).labels (* labelvalues )
154- self .counter_spec_decode_num_draft_tokens = \
155- self ._counter_cls (
156- name = "vllm:spec_decode_num_draft_tokens" ,
157- documentation = "Number of draft tokens." ,
158- labelnames = labelnames ,).labels (* labelvalues )
159- self .counter_spec_decode_num_accepted_tokens = \
160- self ._counter_cls (
161- name = "vllm:spec_decode_num_accepted_tokens" ,
162- documentation = "Number of accepted tokens." ,
163- labelnames = labelnames ).labels (* labelvalues )
149+ counter_drafts = self ._counter_cls (
150+ name = "vllm:spec_decode_num_drafts" ,
151+ documentation = "Number of spec decoding drafts." ,
152+ labelnames = labelnames )
153+ self .counter_spec_decode_num_drafts = make_per_engine (
154+ counter_drafts , per_engine_labelvalues )
155+
156+ counter_draft_tokens = self ._counter_cls (
157+ name = "vllm:spec_decode_num_draft_tokens" ,
158+ documentation = "Number of draft tokens." ,
159+ labelnames = labelnames )
160+ self .counter_spec_decode_num_draft_tokens = make_per_engine (
161+ counter_draft_tokens , per_engine_labelvalues )
162+
163+ counter_accepted_tokens = self ._counter_cls (
164+ name = "vllm:spec_decode_num_accepted_tokens" ,
165+ documentation = "Number of accepted tokens." ,
166+ labelnames = labelnames )
167+ self .counter_spec_decode_num_accepted_tokens = make_per_engine (
168+ counter_accepted_tokens , per_engine_labelvalues )
164169
165170 assert speculative_config is not None
166171 num_spec_tokens = (speculative_config .num_speculative_tokens
@@ -171,21 +176,36 @@ def __init__(
171176 documentation = "Accepted tokens per draft position." ,
172177 labelnames = pos_labelnames ,
173178 )
174- self .counter_spec_decode_num_accepted_tokens_per_pos : list [
175- prometheus_client .Counter ] = []
176- for pos in range (num_spec_tokens ):
177- pos_labelvalues = labelvalues + [str (pos )]
178- self .counter_spec_decode_num_accepted_tokens_per_pos .append (
179- base_counter .labels (* pos_labelvalues ))
180-
181- def observe (self , spec_decoding_stats : SpecDecodingStats ):
179+ self .counter_spec_decode_num_accepted_tokens_per_pos : dict [
180+ int , list [prometheus_client .Counter ]] = {
181+ idx : [
182+ base_counter .labels (* lv , str (pos ))
183+ for pos in range (num_spec_tokens )
184+ ]
185+ for idx , lv in per_engine_labelvalues .items ()
186+ }
187+
def observe(self,
            spec_decoding_stats: SpecDecodingStats,
            engine_idx: int = 0):
    """Add one set of speculative-decoding stats to an engine's counters.

    Does nothing when speculative decoding is disabled.

    Args:
        spec_decoding_stats: Object providing ``num_drafts``,
            ``num_draft_tokens``, ``num_accepted_tokens`` and the
            indexable ``num_accepted_tokens_per_pos``.
        engine_idx: Key selecting which engine's labelled counters are
            incremented.
    """
    if not self.spec_decoding_enabled:
        return

    stats = spec_decoding_stats

    # Aggregate counters: one labelled child per engine index.
    drafts_counter = self.counter_spec_decode_num_drafts[engine_idx]
    drafts_counter.inc(stats.num_drafts)

    draft_tokens_counter = (
        self.counter_spec_decode_num_draft_tokens[engine_idx])
    draft_tokens_counter.inc(stats.num_draft_tokens)

    accepted_counter = (
        self.counter_spec_decode_num_accepted_tokens[engine_idx])
    accepted_counter.inc(stats.num_accepted_tokens)

    # Per-position counters: the list index is the draft position, and the
    # stats are indexed by that same position.
    position_counters = (
        self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx])
    for position, position_counter in enumerate(position_counters):
        position_counter.inc(stats.num_accepted_tokens_per_pos[position])
203+
204+
def make_per_engine(counter: prometheus_client.Counter,
                    per_engine_labelvalues: dict[int, list[str]]):
    """Build one labelled child of ``counter`` per engine.

    Args:
        counter: Parent metric; its ``labels(*values)`` method is called
            once per engine.
        per_engine_labelvalues: Maps an engine index to the label values
            to pass to ``counter.labels``.

    Returns:
        A dict with the same keys as ``per_engine_labelvalues``, each
        mapped to the result of ``counter.labels(*values)``.
    """
    children = {}
    for engine_idx, values in per_engine_labelvalues.items():
        children[engine_idx] = counter.labels(*values)
    return children
0 commit comments