Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,7 @@ def update_from_output(
kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
num_nans_in_logits=request.num_nans_in_logits,
)
)
else:
Expand Down Expand Up @@ -1259,7 +1260,6 @@ def make_stats(
prefix_cache_stats=prefix_cache_stats,
connector_prefix_cache_stats=connector_prefix_cache_stats,
spec_decoding_stats=spec_decoding_stats,
num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running),
kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None,
)

Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ class EngineCoreOutput(
# The number of tokens with prefix cache hits.
num_cached_tokens: int = 0

# The number of NaNs in logits.
# A value greater than 0 indicates that the output is corrupted.
num_nans_in_logits: int = 0

@property
def finished(self) -> bool:
    """Report whether a finish reason has been recorded on this output.

    Returns:
        True when ``self.finish_reason`` is anything other than ``None``.
    """
    reason = self.finish_reason
    return reason is not None
Expand Down
28 changes: 24 additions & 4 deletions vllm/v1/metrics/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from prometheus_client import Counter, Gauge, Histogram

import vllm.envs as envs
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
from vllm.logger import init_logger
Expand Down Expand Up @@ -110,11 +111,13 @@ def _reset(self, now):
# Tracked stats over current local logging interval.
self.num_prompt_tokens: int = 0
self.num_generation_tokens: int = 0
self.num_corrupted_reqs: int = 0

def _track_iteration_stats(self, iteration_stats: IterationStats):
    """Add one iteration's counters to the totals tracked on ``self``.

    Args:
        iteration_stats: Source of the ``num_prompt_tokens``,
            ``num_generation_tokens`` and ``num_corrupted_reqs`` values
            that are accumulated into the same-named attributes of ``self``.
    """
    # Counters are accumulated in the same order as they are listed here.
    tracked_counters = (
        "num_prompt_tokens",
        "num_generation_tokens",
        "num_corrupted_reqs",
    )
    for counter_name in tracked_counters:
        running_total = getattr(self, counter_name)
        increment = getattr(iteration_stats, counter_name)
        setattr(self, counter_name, running_total + increment)

def _get_throughput(self, tracked_stats: int, now: float) -> float:
# Compute summary metrics for tracked stats
Expand Down Expand Up @@ -198,6 +201,10 @@ def log(self):
self.last_scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
]

if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
log_parts.append("Corrupted: %d reqs")
log_args.append(self.num_corrupted_reqs)
if not self.connector_prefix_caching_metrics.empty:
log_parts.append("External prefix cache hit rate: %.1f%%")
log_args.append(self.connector_prefix_caching_metrics.hit_rate * 100)
Expand Down Expand Up @@ -269,9 +276,6 @@ def aggregate_scheduler_stats(self):
self.last_scheduler_stats.num_running_reqs += (
last_scheduler_stats.num_running_reqs
)
self.last_scheduler_stats.num_corrupted_reqs += (
last_scheduler_stats.num_corrupted_reqs
)
self.last_scheduler_stats.kv_cache_usage += (
last_scheduler_stats.kv_cache_usage
)
Expand Down Expand Up @@ -446,6 +450,19 @@ def __init__(
gauge_kv_cache_usage, engine_indexes, model_name
)

if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
counter_corrupted_requests = self._counter_cls(
name="vllm:corrupted_requests",
documentation=(
"Corrupted requests, in terms of total number of requests "
"with NaNs in logits."
),
labelnames=labelnames,
)
self.counter_corrupted_requests = make_per_engine(
counter_corrupted_requests, engine_indexes, model_name
)

counter_prefix_cache_queries = self._counter_cls(
name="vllm:prefix_cache_queries",
documentation=(
Expand Down Expand Up @@ -939,7 +956,10 @@ def record(

if iteration_stats is None:
return

if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
self.counter_corrupted_requests[engine_idx].inc(
iteration_stats.num_corrupted_reqs
)
self.counter_num_preempted_reqs[engine_idx].inc(
iteration_stats.num_preempted_reqs
)
Expand Down
22 changes: 20 additions & 2 deletions vllm/v1/metrics/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import vllm.envs as envs
from vllm.v1.spec_decode.metrics import SpecDecodingStats

if TYPE_CHECKING:
Expand Down Expand Up @@ -169,8 +170,6 @@ class SchedulerStats:
spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats: dict[str, Any] | None = None

num_corrupted_reqs: int = 0


@dataclass
class LoRAStats:
Expand All @@ -196,6 +195,9 @@ class RequestStateStats:
# first token latency
first_token_latency: float = 0.0

# Track if this request is corrupted (NaNs in logits)
is_corrupted: bool = False


@dataclass
class FinishedRequestStats:
Expand All @@ -211,6 +213,7 @@ class FinishedRequestStats:
inference_time: float = 0.0
decode_time: float = 0.0
mean_time_per_output_token: float = 0.0
is_corrupted: bool = False


class IterationStats:
Expand All @@ -228,6 +231,7 @@ def __init__(self):
self.inter_token_latencies_iter: list[float] = []
self.waiting_lora_adapters: dict[str, int] = {}
self.running_lora_adapters: dict[str, int] = {}
self.num_corrupted_reqs: int = 0

def __repr__(self) -> str:
field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items())
Expand Down Expand Up @@ -258,6 +262,15 @@ def update_from_output(

req_stats.num_generation_tokens += num_new_generation_tokens

# Mark the request as corrupted the first time NaNs appear in its logits.
# Once the flag is set, the condition short-circuits and skips the NaN check.
if (
envs.VLLM_COMPUTE_NANS_IN_LOGITS
and not req_stats.is_corrupted
and output.num_nans_in_logits > 0
):
req_stats.is_corrupted = True

# Process request-level engine core events
if output.events is not None:
self.update_from_events(
Expand Down Expand Up @@ -339,9 +352,14 @@ def update_from_finished_request(
inference_time=inference_time,
decode_time=decode_time,
mean_time_per_output_token=mean_time_per_output_token,
is_corrupted=req_stats.is_corrupted,
)
self.finished_requests.append(finished_req)

# Count corrupted requests when they finish (only once per request)
if req_stats.is_corrupted:
self.num_corrupted_reqs += 1


class LoRARequestStates:
"""Per-LoRA request state stats."""
Expand Down
4 changes: 0 additions & 4 deletions vllm/v1/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,6 @@ def append_output_token_ids(
def use_structured_output(self) -> bool:
return self.structured_output_request is not None

@property
def is_output_corrupted(self) -> bool:
    """Report whether any NaNs have been counted in this request's logits.

    Returns:
        True when ``self.num_nans_in_logits`` is greater than zero.
    """
    nan_count = self.num_nans_in_logits
    return nan_count > 0

@property
def num_tokens(self) -> int:
    """Return the length of ``self._all_token_ids``."""
    all_ids = self._all_token_ids
    return len(all_ids)
Expand Down