[feature] Prompt Embeddings Support for v1 Engine #3026
Merged: wangxiyuan merged 40 commits into vllm-project:main from jesse996:enable-prompt-embeds-in-v1 on Oct 30, 2025 (+447 −17).
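
For context, here is a minimal usage sketch of what this feature enables: passing precomputed prompt embeddings instead of token IDs. The flag name `enable_prompt_embeds` appears in the diff below; the `{"prompt_embeds": ...}` prompt form, the model name, and the hidden size are assumptions based on upstream vLLM's prompt-embeddings interface, not taken from this PR page.

```python
import torch
from vllm import LLM, SamplingParams

# Assumption: the engine exposes an `enable_prompt_embeds` flag, as the
# diff's `self.model_config.enable_prompt_embeds` suggests.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_prompt_embeds=True)

# Build prompt embeddings offline, e.g. from the model's own embedding table.
# Shape: (num_prompt_tokens, hidden_size); dtype should match the model dtype.
hidden_size = 896  # hypothetical value for the model above
prompt_embeds = torch.randn(16, hidden_size, dtype=torch.bfloat16)

# Assumption: embeddings are passed via a `prompt_embeds` prompt field,
# mirroring vLLM's embeds-prompt input type.
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```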
Changes from 1 commit

Commits (40), all authored by jesse996:
- e63fe82 Prompt Embeddings Support for v1 Engine
- 3f29bec merge
- 777046f [Fix] Update input embeddings condition to include prompt embeddings …
- c522293 fix param
- f239a9b merge main
- a8187a5 format
- 75cfdd7 add test
- 6548360 merge main
- 6d47582 fix test
- f706331 fix test
- 7d6f819 fix test
- 24706a5 fix test
- 1526fc1 Merge branch 'main' into enable-prompt-embeds-in-v1
- cff8886 fix test
- 4daf970 fix test
- 41b6cdc fix test
- 565e3bf fix test
- b84600a fix test
- 8973cd3 fix test
- a32adc3 Merge branch 'vllm-project:main' into enable-prompt-embeds-in-v1
- e41c84a Merge branch 'vllm-project:main' into enable-prompt-embeds-in-v1
- aef7626 fix test
- 60dbffe fix code
- 6a5ea17 add example
- 75ee4a2 fix
- 7f7f992 Merge branch 'vllm-project:main' into enable-prompt-embeds-in-v1
- cf0f217 Merge branch 'vllm-project:main' into enable-prompt-embeds-in-v1
- b072b3c fix comment
- be3fe4c Merge branch 'main' into enable-prompt-embeds-in-v1
- dfada67 remove unused
- d1327c6 Merge branch 'main' into enable-prompt-embeds-in-v1
- 407bf75 add test to workflows
- d3b9fbb fix test
- 60256c4 fix test
- b4e4098 fix test
- 46358ed fix test
- c74dc8f fix test
- e5a5c3c Merge branch 'main' into enable-prompt-embeds-in-v1
- 256414b fix test
- 182c911 fix test
@@ -67,7 +67,8 @@
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LazyLoader, cdiv, get_dtype_size,
-                        is_pin_memory_available)
+                        is_pin_memory_available,
+                        length_from_prompt_token_ids_or_embeds)
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import \
     reorder_batch_to_split_decodes_and_prefills
@@ -285,11 +286,14 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.is_multimodal_model = self.model_config.is_multimodal_model
         self.is_pooling_model = self.model_config.pooler_config is not None
+        self.enable_prompt_embeds = self.model_config.enable_prompt_embeds
         if self.is_multimodal_model:
-            self.inputs_embeds = torch.zeros(
-                (self.max_num_tokens, self.model_config.get_hidden_size()),
-                dtype=self.dtype,
-                device=self.device)
+            self.inputs_embeds = self._make_buffer(self.max_num_tokens,
+                                                   self.model_config.get_hidden_size(),
+                                                   dtype=self.dtype,
+                                                   numpy=False)
+            self.is_token_ids = self._make_buffer(self.max_num_tokens,
+                                                  dtype=torch.bool)

         # Set up Attention
         self.attn_backend = get_attn_backend(
@@ -572,6 +576,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
+                prompt_embeds=new_req_data.prompt_embeds,
                 mm_kwargs=new_req_data.mm_kwargs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
@@ -843,7 +848,8 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
                 self.input_batch.num_computed_tokens_cpu[index]
             num_scheduled_tokens = \
                 scheduler_output.num_scheduled_tokens[req_id]
-            num_prompt_tokens = len(req.prompt_token_ids)
+            num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+                req.prompt_token_ids, req.prompt_embeds)

             if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                 prompt_part_len = max(0,
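
`length_from_prompt_token_ids_or_embeds` replaces direct `len(prompt_token_ids)` calls because embeds-only requests have no token IDs. A plausible sketch of such a helper, with the signature and behavior inferred from its usage here rather than copied from vLLM:

```python
from typing import Optional

import torch


def length_from_prompt_token_ids_or_embeds(
        prompt_token_ids: Optional[list[int]],
        prompt_embeds: Optional[torch.Tensor]) -> int:
    """Return the prompt length from whichever representation is present."""
    if prompt_token_ids is not None:
        return len(prompt_token_ids)
    if prompt_embeds is not None:
        # prompt_embeds has shape (num_prompt_tokens, hidden_size).
        return prompt_embeds.shape[0]
    raise ValueError("Request has neither prompt_token_ids nor prompt_embeds")
```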
@@ -1016,6 +1022,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
             self.input_ids[:total_num_scheduled_tokens].copy_(
                 self.input_ids_cpu[:total_num_scheduled_tokens],
                 non_blocking=True)
+            self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
+            self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
             return

         # Async scheduling case, where some decode requests from the previous
@@ -1043,6 +1051,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
             self.input_ids[:total_num_scheduled_tokens].copy_(
                 self.input_ids_cpu[:total_num_scheduled_tokens],
                 non_blocking=True)
+            self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
+            self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
             if num_commmon_tokens == 0:
                 # No requests in common with the previous iteration
                 # So input_ids_cpu will have all the input ids.
@@ -1056,6 +1066,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
                 self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
                                                         0],
                 non_blocking=True)
+            self.is_token_ids.gpu[:num_commmon_tokens] = True
             return
         # Upload the index tensors asynchronously
         # so the scatter can be non-blocking.

Collaborator (inline comment on the added `self.is_token_ids.gpu[:num_commmon_tokens] = True` line): ditto
@@ -1195,15 +1206,60 @@ def _prepare_inputs(
         # where M is the max_model_len.
         token_indices = (positions_np +
                          req_indices * self.input_batch.token_ids_cpu.shape[1])

+        token_indices_tensor = torch.from_numpy(token_indices)
         # Prepare input_ids.
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
         # tensors.
         torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                            0,
-                           torch.from_numpy(token_indices),
+                           token_indices_tensor,
                            out=self.input_ids_cpu[:total_num_scheduled_tokens])
+        is_token_ids = self.input_batch.is_token_ids.flatten()
+        torch.index_select(
+            is_token_ids,
+            0,
+            token_indices_tensor,
+            out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
+
+        # Because we did not pre-allocate a massive prompt_embeds CPU tensor on
+        # the InputBatch, we need to fill in the prompt embeds into the expected
+        # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
+        if self.input_batch.req_prompt_embeds:
+            output_idx = 0
+            for req_idx in range(num_reqs):
+                num_sched = num_scheduled_tokens[req_idx]
+
+                # Skip if this request doesn't have embeddings
+                if req_idx not in self.input_batch.req_prompt_embeds:
+                    output_idx += num_sched
+                    continue
+
+                # Skip if no tokens scheduled
+                if num_sched <= 0:
+                    output_idx += num_sched
+                    continue
+
+                req_embeds = self.input_batch.req_prompt_embeds[req_idx]
+                start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]
+
+                # Skip if trying to read beyond available embeddings
+                if start_pos >= req_embeds.shape[0]:
+                    output_idx += num_sched
+                    continue
+
+                # Copy available embeddings
+                end_pos = start_pos + num_sched
+                actual_end = min(end_pos, req_embeds.shape[0])
+                actual_num_sched = actual_end - start_pos
+
+                if actual_num_sched > 0:
+                    self.inputs_embeds.cpu[output_idx:output_idx +
+                                           actual_num_sched].copy_(
+                                               req_embeds[start_pos:actual_end])
+
+                output_idx += num_sched

         # Prepare some information for building Attention-Metadata
         # Compute and commit slot mapping
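
The `is_token_ids` mask marks which positions in the flattened batch carry real token IDs (and therefore still need an embedding lookup) versus positions already covered by user-supplied prompt embeddings. One plausible way a forward pass could combine the two is shown below purely as an illustration of the intent; the actual merge logic is not part of this hunk.

```python
import torch


def merge_token_and_prompt_embeds(
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        is_token_ids: torch.Tensor,
        embed_tokens: torch.nn.Embedding) -> torch.Tensor:
    """Look up embeddings only at positions that carry token IDs.

    Positions where is_token_ids is False keep the prompt embeddings that
    were scattered into `inputs_embeds` during input preparation.
    """
    token_positions = is_token_ids.nonzero(as_tuple=True)[0]
    if token_positions.numel() > 0:
        inputs_embeds[token_positions] = embed_tokens(
            input_ids[token_positions])
    return inputs_embeds
```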
@@ -1985,6 +2041,7 @@ def execute_model(
                 self.input_batch.token_ids_cpu[req_idx,
                                                start_idx:end_idx] = sampled_ids
+                self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
                 self.input_batch.num_tokens_no_spec[req_idx] = end_idx
                 self.input_batch.num_tokens[req_idx] = end_idx
                 req_id = self.input_batch.req_ids[req_idx]
@@ -2200,6 +2257,9 @@ def _dummy_run(
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_tokens]
+        elif self.enable_prompt_embeds:
+            input_ids = None
+            inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
         else:
             input_ids = self.input_ids[:num_tokens]
             inputs_embeds = None
@@ -3070,6 +3130,9 @@ def _get_prompt_logprobs_dict(
             # Get metadata for this request.
             request = self.requests[req_id]
+            if request.prompt_token_ids is None:
+                # Prompt logprobs is incompatible with prompt embeddings
+                continue
             num_prompt_tokens = len(request.prompt_token_ids)
             prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                 self.device, non_blocking=True)
(The diff for an additional changed file failed to load.)
Review comment: The `inputs_embeds` buffer is only initialized for multimodal models. However, it is also required when `enable_prompt_embeds` is true for non-multimodal models. Without this initialization, an `AttributeError` will be raised when `self.inputs_embeds` is accessed later in methods like `_dummy_run` or `_prepare_input_ids`. The condition should be updated to include `self.enable_prompt_embeds`.
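
A sketch of the change this review comment is suggesting, based on the `__init__` hunk above. It is a suggestion-style snippet, not standalone code, and assumes that allocating the same buffers for prompt-embeds-enabled models as for multimodal models resolves the described `AttributeError`.

```python
# Inside the model runner's __init__ (abridged).
if self.is_multimodal_model or self.enable_prompt_embeds:
    self.inputs_embeds = self._make_buffer(self.max_num_tokens,
                                           self.model_config.get_hidden_size(),
                                           dtype=self.dtype,
                                           numpy=False)
    self.is_token_ids = self._make_buffer(self.max_num_tokens,
                                          dtype=torch.bool)
```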