Merged
Commits (40)
e63fe82
Prompt Embeddings Support for v1 Engine
jesse996 Sep 19, 2025
3f29bec
merge
jesse996 Sep 22, 2025
777046f
[Fix] Update input embeddings condition to include prompt embeddings …
jesse996 Sep 22, 2025
c522293
fix param
jesse996 Sep 22, 2025
f239a9b
merge main
jesse996 Sep 25, 2025
a8187a5
format
jesse996 Sep 26, 2025
75cfdd7
add test
jesse996 Sep 26, 2025
6548360
merge main
jesse996 Sep 26, 2025
6d47582
fix test
jesse996 Oct 9, 2025
f706331
fix test
jesse996 Oct 9, 2025
7d6f819
fix test
jesse996 Oct 10, 2025
24706a5
fix test
jesse996 Oct 10, 2025
1526fc1
Merge branch 'main' into enable-prompt-embeds-in-v1
jesse996 Oct 10, 2025
cff8886
fix test
jesse996 Oct 10, 2025
4daf970
fix test
jesse996 Oct 11, 2025
41b6cdc
fix test
jesse996 Oct 11, 2025
565e3bf
fix test
jesse996 Oct 11, 2025
b84600a
fix test
jesse996 Oct 13, 2025
8973cd3
fix test
jesse996 Oct 13, 2025
a32adc3
Merge branch 'vllm-project:main' into enable-prompt-embeds-in-v1
jesse996 Oct 15, 2025
e41c84a
Merge branch 'vllm-project:main' into enable-prompt-embeds-in-v1
jesse996 Oct 16, 2025
aef7626
fix test
jesse996 Oct 16, 2025
60dbffe
fix code
jesse996 Oct 16, 2025
6a5ea17
add example
jesse996 Oct 16, 2025
75ee4a2
fix
jesse996 Oct 16, 2025
7f7f992
Merge branch 'vllm-project:main' into enable-prompt-embeds-in-v1
jesse996 Oct 21, 2025
cf0f217
Merge branch 'vllm-project:main' into enable-prompt-embeds-in-v1
jesse996 Oct 22, 2025
b072b3c
fix comment
jesse996 Oct 23, 2025
be3fe4c
Merge branch 'main' into enable-prompt-embeds-in-v1
jesse996 Oct 24, 2025
dfada67
remove unused
jesse996 Oct 24, 2025
d1327c6
Merge branch 'main' into enable-prompt-embeds-in-v1
jesse996 Oct 26, 2025
407bf75
add test to workflows
jesse996 Oct 26, 2025
d3b9fbb
fix test
jesse996 Oct 26, 2025
60256c4
fix test
jesse996 Oct 26, 2025
b4e4098
fix test
jesse996 Oct 26, 2025
46358ed
fix test
jesse996 Oct 26, 2025
c74dc8f
fix test
jesse996 Oct 27, 2025
e5a5c3c
Merge branch 'main' into enable-prompt-embeds-in-v1
jesse996 Oct 28, 2025
256414b
fix test
jesse996 Oct 28, 2025
182c911
fix test
jesse996 Oct 28, 2025
79 changes: 71 additions & 8 deletions vllm_ascend/worker/model_runner_v1.py
@@ -67,7 +67,8 @@
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LazyLoader, cdiv, get_dtype_size,
is_pin_memory_available)
is_pin_memory_available,
length_from_prompt_token_ids_or_embeds)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import \
reorder_batch_to_split_decodes_and_prefills
@@ -285,11 +286,14 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

self.is_multimodal_model = self.model_config.is_multimodal_model
self.is_pooling_model = self.model_config.pooler_config is not None
self.enable_prompt_embeds = self.model_config.enable_prompt_embeds
if self.is_multimodal_model:
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.model_config.get_hidden_size()),
dtype=self.dtype,
device=self.device)
self.inputs_embeds = self._make_buffer(self.max_num_tokens,
self.model_config.get_hidden_size(),
dtype=self.dtype,
numpy=False)
Contributor review comment (critical):

The inputs_embeds buffer is only initialized for multimodal models. However, it is also required when enable_prompt_embeds is true for non-multimodal models. Without this initialization, an AttributeError will be raised when self.inputs_embeds is accessed later in methods like _dummy_run or _prepare_input_ids. The condition should be updated to include self.enable_prompt_embeds.

Suggested change
if self.is_multimodal_model:
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.model_config.get_hidden_size()),
dtype=self.dtype,
device=self.device)
self.inputs_embeds = self._make_buffer(self.max_num_tokens,
self.model_config.get_hidden_size(),
dtype=self.dtype,
numpy=False)
if self.is_multimodal_model or self.enable_prompt_embeds:
self.inputs_embeds = self._make_buffer(self.max_num_tokens,
self.model_config.get_hidden_size(),
dtype=self.dtype,
numpy=False)
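
For context, and not part of the suggested change: a minimal user-facing sketch of the path that hits this buffer for a text-only model. The model name, dtype, and random embeddings are illustrative assumptions, not taken from this PR.

import torch
from vllm import LLM, SamplingParams

# Text-only model: is_multimodal_model is False, but the runner still needs
# its inputs_embeds buffer because the prompt arrives as embeddings.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_prompt_embeds=True)

# (num_prompt_tokens, hidden_size); in practice these come from the model's
# embedding layer rather than torch.randn, and the dtype must match the model.
hidden_size = llm.llm_engine.model_config.get_hidden_size()
prompt_embeds = torch.randn(8, hidden_size, dtype=torch.bfloat16)

outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
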

self.is_token_ids = self._make_buffer(self.max_num_tokens,
dtype=torch.bool)

# Set up Attention
self.attn_backend = get_attn_backend(
@@ -572,6 +576,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt_embeds=new_req_data.prompt_embeds,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
@@ -843,7 +848,8 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
self.input_batch.num_computed_tokens_cpu[index]
num_scheduled_tokens = \
scheduler_output.num_scheduled_tokens[req_id]
num_prompt_tokens = len(req.prompt_token_ids)
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
req.prompt_token_ids, req.prompt_embeds)

if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
prompt_part_len = max(0,
Expand Down Expand Up @@ -1016,6 +1022,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens],
non_blocking=True)
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
return

# Async scheduling case, where some decode requests from the previous
Expand Down Expand Up @@ -1043,6 +1051,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens],
non_blocking=True)
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
if num_commmon_tokens == 0:
# No requests in common with the previous iteration
# So input_ids_cpu will have all the input ids.
@@ -1056,6 +1066,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
0],
non_blocking=True)
self.is_token_ids.gpu[:num_commmon_tokens] = True
Collaborator review comment: ditto.

return
# Upload the index tensors asynchronously
# so the scatter can be non-blocking.
@@ -1195,15 +1206,60 @@ def _prepare_inputs(
# where M is the max_model_len.
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])

token_indices_tensor = torch.from_numpy(token_indices)
# Prepare input_ids.
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
token_indices_tensor,
out=self.input_ids_cpu[:total_num_scheduled_tokens])
is_token_ids = self.input_batch.is_token_ids.flatten()
torch.index_select(
is_token_ids,
0,
token_indices_tensor,
out=self.is_token_ids.cpu[:total_num_scheduled_tokens])

# Because we did not pre-allocate a massive prompt_embeds CPU tensor on
# the InputBatch, we need to fill in the prompt embeds into the expected
# spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
if self.input_batch.req_prompt_embeds:
output_idx = 0
for req_idx in range(num_reqs):
num_sched = num_scheduled_tokens[req_idx]

# Skip if this request doesn't have embeddings
if req_idx not in self.input_batch.req_prompt_embeds:
output_idx += num_sched
continue

# Skip if no tokens scheduled
if num_sched <= 0:
output_idx += num_sched
continue

req_embeds = self.input_batch.req_prompt_embeds[req_idx]
start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]

# Skip if trying to read beyond available embeddings
if start_pos >= req_embeds.shape[0]:
output_idx += num_sched
continue

# Copy available embeddings
end_pos = start_pos + num_sched
actual_end = min(end_pos, req_embeds.shape[0])
actual_num_sched = actual_end - start_pos

if actual_num_sched > 0:
self.inputs_embeds.cpu[output_idx:output_idx +
actual_num_sched].copy_(
req_embeds[start_pos:actual_end]
)

output_idx += num_sched

# Prepare some information for building Attention-Metadata
# Compute and commit slot mapping
@@ -1985,6 +2041,7 @@ def execute_model(

self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
req_id = self.input_batch.req_ids[req_idx]
@@ -2200,6 +2257,9 @@ def _dummy_run(
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
elif self.enable_prompt_embeds:
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
@@ -3070,6 +3130,9 @@ def _get_prompt_logprobs_dict(

# Get metadata for this request.
request = self.requests[req_id]
if request.prompt_token_ids is None:
# Prompt logprobs is incompatible with prompt embeddings
continue
num_prompt_tokens = len(request.prompt_token_ids)
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
self.device, non_blocking=True)
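
A simplified, standalone sketch of the per-request embedding scatter added to _prepare_inputs above. The behavior of length_from_prompt_token_ids_or_embeds is an assumption inferred from its name and call sites, not the actual vLLM implementation.

from typing import Optional
import torch

def length_from_prompt_token_ids_or_embeds(
        prompt_token_ids: Optional[list[int]],
        prompt_embeds: Optional[torch.Tensor]) -> int:
    # Presumed behavior: the prompt length comes from whichever input exists.
    if prompt_token_ids is not None:
        return len(prompt_token_ids)
    assert prompt_embeds is not None
    return prompt_embeds.shape[0]

def scatter_prompt_embeds(req_prompt_embeds: dict[int, torch.Tensor],
                          num_scheduled_tokens: list[int],
                          num_computed_tokens: list[int],
                          out: torch.Tensor) -> None:
    # Each request owns a contiguous slice of `out` whose length equals its
    # scheduled token count; only requests that came with embeddings fill it.
    output_idx = 0
    for req_idx, num_sched in enumerate(num_scheduled_tokens):
        embeds = req_prompt_embeds.get(req_idx)
        if embeds is not None and num_sched > 0:
            start = num_computed_tokens[req_idx]
            end = min(start + num_sched, embeds.shape[0])
            if start < end:
                out[output_idx:output_idx + (end - start)].copy_(
                    embeds[start:end])
        output_idx += num_sched
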
55 changes: 48 additions & 7 deletions vllm_ascend/worker/npu_input_batch.py
@@ -28,7 +28,7 @@
PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
@@ -45,7 +45,7 @@
class CachedRequestState:

req_id: str
prompt_token_ids: list[int]
prompt_token_ids: Optional[list[int]]
mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange]
mm_hashes: list[str]
@@ -61,9 +61,11 @@ class CachedRequestState:
mrope_position_delta: Optional[int] = None

lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None

def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids)
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)

@property
def num_tokens(self) -> int:
@@ -78,6 +80,10 @@ def mm_inputs(self) -> list[MultiModalKwargs]:

def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
if self.prompt_token_ids is None:
raise ValueError(
f"Tried to access token index {idx}, but that token was "
"provided via prompt_embeds, and its ID is unknown.")
return self.prompt_token_ids[idx]
else:
return self.output_token_ids[idx - self.num_prompt_tokens]
@@ -122,6 +128,14 @@ def __init__(
pin_memory=False,
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
device="cpu",
dtype=bool,
pin_memory=False)
# Store prompt embeddings per request to avoid OOM from large upfront
# allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
@@ -326,15 +340,23 @@ def add_request(
self.req_id_to_index[req_id] = req_index

# Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids)
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds)
self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None:
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
self.is_token_ids[req_index, :num_prompt_tokens] = True
else:
self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
# Number of token ids in token_ids_cpu.
self.is_token_ids[req_index, start_idx:end_idx] = True
# Number of token ids in prompt (token_ids_cpu or prompt_embeds).
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens.
@@ -534,6 +556,20 @@ def swap_states(self, i1: int, i2: int) -> None:
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp

self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]

# Swap prompt embeddings if they exist
embeds_i1 = self.req_prompt_embeds.get(i1)
embeds_i2 = self.req_prompt_embeds.get(i2)
if embeds_i1 is not None:
self.req_prompt_embeds[i2] = embeds_i1
else:
self.req_prompt_embeds.pop(i2, None)
if embeds_i2 is not None:
self.req_prompt_embeds[i1] = embeds_i2
else:
self.req_prompt_embeds.pop(i1, None)

swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)

@@ -612,6 +648,11 @@ def condense(self) -> None:
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens]
self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
last_req_index, :num_tokens]
if last_req_index in self.req_prompt_embeds:
self.req_prompt_embeds[
empty_index] = self.req_prompt_embeds.pop(last_req_index)
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index]
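
As a design note, keeping prompt embeddings in a per-request dict (rather than a pre-allocated (max_num_reqs, max_model_len, hidden_size) tensor) means swap_states and condense only have to move dictionary entries. A minimal, hypothetical illustration of that bookkeeping, separate from the real InputBatch:

import torch

# Batch slots 0 and 2 were given prompt embeddings; slot 1 uses token ids.
req_prompt_embeds: dict[int, torch.Tensor] = {
    0: torch.randn(4, 16),
    2: torch.randn(7, 16),
}

def swap(i1: int, i2: int) -> None:
    # Mirror of the swap_states bookkeeping: exchange whichever entries exist.
    e1, e2 = req_prompt_embeds.pop(i1, None), req_prompt_embeds.pop(i2, None)
    if e1 is not None:
        req_prompt_embeds[i2] = e1
    if e2 is not None:
        req_prompt_embeds[i1] = e2

def condense(empty_index: int, last_req_index: int) -> None:
    # Mirror of the condense bookkeeping: compact the tail request downwards.
    if last_req_index in req_prompt_embeds:
        req_prompt_embeds[empty_index] = req_prompt_embeds.pop(last_req_index)

swap(0, 1)      # slot 1 now holds the embeddings that were in slot 0
condense(0, 2)  # the request in slot 2 is compacted into the empty slot 0
print(sorted(req_prompt_embeds))  # [0, 1]
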