
Commit 0e391e7

[Bugfix] Fix RequestOutput miss lora_request (#30636)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent 0d0c929 commit 0e391e7
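For context, a minimal usage sketch of the behavior this commit restores: every RequestOutput should carry the LoRARequest it was generated with. This is only a sketch; the model name and adapter path below are placeholders, not part of this commit.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Hypothetical model and adapter path, for illustration only.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_lora=True)
lora = LoRARequest("my-adapter", 1, "/path/to/adapter")

outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(max_tokens=8),
    lora_request=lora,
)
for out in outputs:
    # Before this fix, the v1 output processor did not populate out.lora_request.
    assert out.lora_request is not None
    assert out.lora_request.lora_name == lora.lora_name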

File tree

3 files changed: 19 additions, 7 deletions

  tests/lora/test_gptoss_tp.py
  tests/lora/test_llama_tp.py
  vllm/v1/engine/output_processor.py


tests/lora/test_gptoss_tp.py

Lines changed: 5 additions & 1 deletion
@@ -76,6 +76,8 @@ def test_gpt_oss_lora(gptoss20b_lora_files):
         enable_lora=True,
         max_loras=4,
         max_lora_rank=8,
+        max_num_seqs=2,
+        max_num_batched_tokens=2048,
         compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
             cudagraph_specialize_lora=False,
         ),
@@ -94,8 +96,10 @@ def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
         enable_lora=True,
         max_loras=2,
         max_lora_rank=8,
-        max_num_seqs=16,
+        max_num_seqs=2,
+        max_num_batched_tokens=2048,
         tensor_parallel_size=2,
+        gpu_memory_utilization=0.8,
         fully_sharded_loras=fully_sharded_loras,
         compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
             cudagraph_specialize_lora=False,
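Both tests build the engine through vllm.LLM, so the tightened settings map directly onto engine arguments. A sketch of that constructor call, with a placeholder model path and only the arguments shown in the diff above:

import vllm

llm = vllm.LLM(
    "openai/gpt-oss-20b",              # placeholder path, assumed from the test name
    enable_lora=True,
    max_loras=2,
    max_lora_rank=8,
    max_num_seqs=2,                    # fewer concurrent sequences per step
    max_num_batched_tokens=2048,       # caps tokens scheduled per step
    tensor_parallel_size=2,
    gpu_memory_utilization=0.8,        # leave GPU headroom to avoid OOM
    compilation_config=vllm.config.CompilationConfig(
        cudagraph_specialize_lora=False,
    ),
)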

tests/lora/test_llama_tp.py

Lines changed: 8 additions & 1 deletion
@@ -76,11 +76,18 @@ def do_sample(
         if lora_id
         else None,
     )
-    # Print the outputs.
+    lora_request = LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None
     generated_texts: list[str] = []
     for output in outputs:
         prompt = output.prompt
         generated_text = output.outputs[0].text
+        # The output should include correct lora_request info
+        if lora_request is not None:
+            assert output.lora_request.lora_name == lora_request.lora_name
+            assert output.lora_request.lora_int_id == lora_request.lora_int_id
+            assert output.lora_request.lora_path == lora_request.lora_path
+        else:
+            assert output.lora_request is None
         generated_texts.append(generated_text)
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
     return generated_texts
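The added assertions check that the adapter identity round-trips from the request to the RequestOutput. A small sketch of the same check factored into a helper; the helper name is hypothetical, outputs is assumed to be the list returned by llm.generate, and lora_request the adapter (or None) the prompts were submitted with:

from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput

def check_lora_round_trip(
    outputs: list[RequestOutput], lora_request: LoRARequest | None
) -> None:
    # Each output should echo the adapter it was generated with, or None.
    for output in outputs:
        if lora_request is None:
            assert output.lora_request is None
            continue
        assert output.lora_request.lora_name == lora_request.lora_name
        assert output.lora_request.lora_int_id == lora_request.lora_int_id
        assert output.lora_request.lora_path == lora_request.lora_path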

vllm/v1/engine/output_processor.py

Lines changed: 6 additions & 5 deletions
@@ -8,6 +8,7 @@
 
 import torch
 
+from vllm.lora.request import LoRARequest
 from vllm.outputs import (
     CompletionOutput,
     PoolingOutput,
@@ -93,7 +94,7 @@ def __init__(
         request_id: str,
         parent_req: ParentRequest | None,
         request_index: int,
-        lora_name: str | None,
+        lora_request: LoRARequest | None,
         output_kind: RequestOutputKind,
         prompt: str | None,
         prompt_token_ids: list[int] | None,
@@ -112,7 +113,8 @@ def __init__(
         self.request_id = request_id
         self.parent_req = parent_req
         self.request_index = request_index
-        self.lora_name = lora_name
+        self.lora_request = lora_request
+        self.lora_name = lora_request.lora_name if lora_request is not None else None
         self.output_kind = output_kind
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
@@ -178,9 +180,7 @@ def from_new_request(
             request_id=request.request_id,
             parent_req=parent_req,
             request_index=request_index,
-            lora_name=(
-                request.lora_request.name if request.lora_request is not None else None
-            ),
+            lora_request=request.lora_request,
             output_kind=output_kind,
             prompt=prompt,
             prompt_token_ids=request.prompt_token_ids,
@@ -289,6 +289,7 @@ def _new_request_output(
 
         return RequestOutput(
             request_id=request_id,
+            lora_request=self.lora_request,
             prompt=self.prompt,
             prompt_token_ids=prompt_token_ids,
             prompt_logprobs=prompt_logprobs,
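Taken together, the processor change replaces the name-only field with the full request object: the per-request state keeps the LoRARequest, still derives lora_name from it, and forwards the object into every RequestOutput it builds. A minimal sketch of that pattern, with a simplified stand-in class rather than the real RequestState:

from vllm.lora.request import LoRARequest

class MiniRequestState:
    # Hypothetical stand-in for vllm.v1.engine.output_processor.RequestState.
    def __init__(self, request_id: str, lora_request: LoRARequest | None):
        self.request_id = request_id
        self.lora_request = lora_request
        # lora_name is still kept, now derived from the request (as in the __init__ change above).
        self.lora_name = (
            lora_request.lora_name if lora_request is not None else None
        )

    def output_kwargs(self) -> dict:
        # Mirrors _new_request_output: the full LoRARequest rides along on the output.
        return {"request_id": self.request_id, "lora_request": self.lora_request}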
