vllm/v1/core/sched/scheduler.py (2 changes: 1 addition & 1 deletion)
@@ -1019,6 +1019,7 @@ def update_from_output(
                         kv_transfer_params=kv_transfer_params,
                         trace_headers=request.trace_headers,
                         num_cached_tokens=request.num_cached_tokens,
+                        num_nans_in_logits=request.num_nans_in_logits,
                     )
                 )
             else:
@@ -1255,7 +1256,6 @@ def make_stats(
             prefix_cache_stats=prefix_cache_stats,
             connector_prefix_cache_stats=connector_prefix_cache_stats,
             spec_decoding_stats=spec_decoding_stats,
-            num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running),
             kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None,
         )

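For context: the scheduler change only forwards a count that is attached to the request upstream; the scheduler itself never inspects logits. A minimal sketch of how such a per-request NaN count can be derived, assuming a `[num_reqs, vocab_size]` logits tensor (the helper name and layout are assumptions, not the actual model-runner code):

```python
import torch

def count_nans_per_request(logits: torch.Tensor) -> list[int]:
    # One NaN count per request row; a nonzero count means that
    # request's sampled output can no longer be trusted.
    return torch.isnan(logits).sum(dim=-1).tolist()
```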
vllm/v1/engine/__init__.py (4 changes: 4 additions & 0 deletions)
@@ -122,6 +122,10 @@ class EngineCoreOutput(
     # The number of tokens with prefix cache hits.
     num_cached_tokens: int = 0
 
+    # The number of NaNs in logits.
+    # A value greater than 0 indicates that the output is corrupted.
+    num_nans_in_logits: int = 0
+
     @property
     def finished(self) -> bool:
         return self.finish_reason is not None
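Because the new field defaults to 0, `EngineCoreOutput` stays backward compatible: producers that never compute NaN counts still serialize cleanly, and consumers can treat any positive value as corruption. A hypothetical consumer-side check (not part of this diff):

```python
def step_is_corrupted(num_nans_in_logits: int) -> bool:
    # 0 means clean (or NaN counting disabled); anything positive
    # flags this step's output as corrupted.
    return num_nans_in_logits > 0
```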
vllm/v1/metrics/loggers.py (28 changes: 24 additions & 4 deletions)
@@ -9,6 +9,7 @@
 
 from prometheus_client import Counter, Gauge, Histogram
 
+import vllm.envs as envs
 from vllm.config import SupportsMetricsInfo, VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
     KVConnectorLogging,
@@ -116,11 +117,13 @@ def _reset(self, now):
         # Tracked stats over current local logging interval.
         self.num_prompt_tokens: int = 0
         self.num_generation_tokens: int = 0
+        self.num_corrupted_reqs: int = 0
 
     def _track_iteration_stats(self, iteration_stats: IterationStats):
         # Save tracked stats for token counters.
         self.num_prompt_tokens += iteration_stats.num_prompt_tokens
         self.num_generation_tokens += iteration_stats.num_generation_tokens
+        self.num_corrupted_reqs += iteration_stats.num_corrupted_reqs
 
     def _get_throughput(self, tracked_stats: int, now: float) -> float:
         # Compute summary metrics for tracked stats
@@ -204,6 +207,10 @@ def log(self):
             self.last_scheduler_stats.kv_cache_usage * 100,
             self.prefix_caching_metrics.hit_rate * 100,
         ]
+
+        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
+            log_parts.append("Corrupted: %d reqs")
+            log_args.append(self.num_corrupted_reqs)
         if not self.connector_prefix_caching_metrics.empty:
             log_parts.append("External prefix cache hit rate: %.1f%%")
             log_args.append(self.connector_prefix_caching_metrics.hit_rate * 100)
@@ -275,9 +282,6 @@ def aggregate_scheduler_stats(self):
         self.last_scheduler_stats.num_running_reqs += (
             last_scheduler_stats.num_running_reqs
         )
-        self.last_scheduler_stats.num_corrupted_reqs += (
-            last_scheduler_stats.num_corrupted_reqs
-        )
         self.last_scheduler_stats.kv_cache_usage += (
             last_scheduler_stats.kv_cache_usage
         )
@@ -481,6 +485,19 @@ def __init__(
             gauge_kv_cache_usage, engine_indexes, model_name
         )
 
+        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
+            counter_corrupted_requests = self._counter_cls(
+                name="vllm:corrupted_requests",
+                documentation=(
+                    "Corrupted requests, in terms of total number of requests "
+                    "with NaNs in logits."
+                ),
+                labelnames=labelnames,
+            )
+            self.counter_corrupted_requests = make_per_engine(
+                counter_corrupted_requests, engine_indexes, model_name
+            )
+
         counter_prefix_cache_queries = self._counter_cls(
             name="vllm:prefix_cache_queries",
             documentation=(
@@ -979,7 +996,10 @@ def record(
 
         if iteration_stats is None:
             return
-
+        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
+            self.counter_corrupted_requests[engine_idx].inc(
+                iteration_stats.num_corrupted_reqs
+            )
         self.counter_num_preempted_reqs[engine_idx].inc(
             iteration_stats.num_preempted_reqs
         )
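Since the counter is gated on `VLLM_COMPUTE_NANS_IN_LOGITS`, it only shows up on the metrics endpoint when NaN counting is enabled. A self-contained sketch of the resulting metric using `prometheus_client` directly, with a simplified label set (the real logger builds per-engine counters via `make_per_engine`):

```python
from prometheus_client import Counter, start_http_server

# Sketch only: label names are assumed; vLLM attaches its own label set.
corrupted_requests = Counter(
    "vllm:corrupted_requests",
    "Corrupted requests, in terms of total number of requests with NaNs in logits.",
    ["model_name", "engine"],
)

start_http_server(8000)  # exposes the metric at :8000/metrics
# Exported as vllm:corrupted_requests_total{engine="0",model_name="demo"}
corrupted_requests.labels(model_name="demo", engine="0").inc(2)
```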
vllm/v1/metrics/stats.py (22 changes: 20 additions & 2 deletions)
@@ -6,6 +6,7 @@
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any
 
+import vllm.envs as envs
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 
 if TYPE_CHECKING:
@@ -169,8 +170,6 @@ class SchedulerStats:
     spec_decoding_stats: SpecDecodingStats | None = None
     kv_connector_stats: dict[str, Any] | None = None
 
-    num_corrupted_reqs: int = 0
-
 
 @dataclass
 class LoRAStats:
@@ -196,6 +195,9 @@ class RequestStateStats:
     # first token latency
     first_token_latency: float = 0.0
 
+    # Track if this request is corrupted (NaNs in logits)
+    is_corrupted: bool = False
+
 
 @dataclass
 class FinishedRequestStats:
@@ -211,6 +213,7 @@ class FinishedRequestStats:
     inference_time: float = 0.0
     decode_time: float = 0.0
     mean_time_per_output_token: float = 0.0
+    is_corrupted: bool = False
 
 
 class IterationStats:
@@ -228,6 +231,7 @@ def __init__(self):
         self.inter_token_latencies_iter: list[float] = []
         self.waiting_lora_adapters: dict[str, int] = {}
         self.running_lora_adapters: dict[str, int] = {}
+        self.num_corrupted_reqs: int = 0
 
     def __repr__(self) -> str:
         field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items())
@@ -258,6 +262,15 @@ def update_from_output(
 
         req_stats.num_generation_tokens += num_new_generation_tokens
 
+        # Track if this request is corrupted (only check once per request)
+        # Early exit if already marked as corrupted to avoid redundant checks
+        if (
+            envs.VLLM_COMPUTE_NANS_IN_LOGITS
+            and not req_stats.is_corrupted
+            and output.num_nans_in_logits > 0
+        ):
+            req_stats.is_corrupted = True
+
         # Process request-level engine core events
         if output.events is not None:
             self.update_from_events(
@@ -339,9 +352,14 @@ def update_from_finished_request(
             inference_time=inference_time,
             decode_time=decode_time,
             mean_time_per_output_token=mean_time_per_output_token,
+            is_corrupted=req_stats.is_corrupted,
         )
         self.finished_requests.append(finished_req)
 
+        # Count corrupted requests when they finish (only once per request)
+        if req_stats.is_corrupted:
+            self.num_corrupted_reqs += 1
+
 
 class LoRARequestStates:
     """Per-LoRA request state stats."""
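Taken together, the stats changes give once-per-request accounting: a request is flagged the first time any step reports NaNs, and `num_corrupted_reqs` is incremented exactly once, at request completion, so corruption spanning many decode steps is not double counted. A condensed, illustrative model of that lifecycle (class and method names here are not the vLLM API):

```python
class CorruptionTracker:
    def __init__(self) -> None:
        self.is_corrupted = False  # mirrors RequestStateStats.is_corrupted
        self.num_corrupted = 0     # mirrors IterationStats.num_corrupted_reqs

    def on_step(self, num_nans_in_logits: int) -> None:
        # Sticky flag: set by the first step that reports NaNs.
        if not self.is_corrupted and num_nans_in_logits > 0:
            self.is_corrupted = True

    def on_finish(self) -> None:
        # Counted once, when the request finishes.
        if self.is_corrupted:
            self.num_corrupted += 1

tracker = CorruptionTracker()
for nans in (0, 3, 1, 0):  # NaNs appear in two different steps
    tracker.on_step(nans)
tracker.on_finish()
assert tracker.num_corrupted == 1  # still counted exactly once
```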
vllm/v1/request.py (4 changes: 0 additions & 4 deletions)
@@ -168,10 +168,6 @@ def append_output_token_ids(
     def use_structured_output(self) -> bool:
         return self.structured_output_request is not None
 
-    @property
-    def is_output_corrupted(self) -> bool:
-        return self.num_nans_in_logits > 0
-
     @property
     def num_tokens(self) -> int:
         return len(self._all_token_ids)