diff --git a/tests/lora/test_gptoss_tp.py b/tests/lora/test_gptoss_tp.py index f4269750feb6..2fa61f280587 100644 --- a/tests/lora/test_gptoss_tp.py +++ b/tests/lora/test_gptoss_tp.py @@ -76,6 +76,8 @@ def test_gpt_oss_lora(gptoss20b_lora_files): enable_lora=True, max_loras=4, max_lora_rank=8, + max_num_seqs=2, + max_num_batched_tokens=2048, compilation_config=vllm.config.CompilationConfig( # Avoid OOM cudagraph_specialize_lora=False, ), @@ -94,8 +96,10 @@ def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras): enable_lora=True, max_loras=2, max_lora_rank=8, - max_num_seqs=16, + max_num_seqs=2, + max_num_batched_tokens=2048, tensor_parallel_size=2, + gpu_memory_utilization=0.8, fully_sharded_loras=fully_sharded_loras, compilation_config=vllm.config.CompilationConfig( # Avoid OOM cudagraph_specialize_lora=False, diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 18704fa6e45d..483235ff5129 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -76,11 +76,18 @@ def do_sample( if lora_id else None, ) - # Print the outputs. + lora_request = LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None generated_texts: list[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text + # The output should include correct lora_request info + if lora_request is not None: + assert output.lora_request.lora_name == lora_request.lora_name + assert output.lora_request.lora_int_id == lora_request.lora_int_id + assert output.lora_request.lora_path == lora_request.lora_path + else: + assert output.lora_request is None generated_texts.append(generated_text) print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") return generated_texts diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 9be3f4da7352..8f7d8a71f1a2 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -8,6 +8,7 @@ import torch +from vllm.lora.request import LoRARequest from vllm.outputs import ( CompletionOutput, PoolingOutput, @@ -93,7 +94,7 @@ def __init__( request_id: str, parent_req: ParentRequest | None, request_index: int, - lora_name: str | None, + lora_request: LoRARequest | None, output_kind: RequestOutputKind, prompt: str | None, prompt_token_ids: list[int] | None, @@ -112,7 +113,8 @@ def __init__( self.request_id = request_id self.parent_req = parent_req self.request_index = request_index - self.lora_name = lora_name + self.lora_request = lora_request + self.lora_name = lora_request.lora_name if lora_request is not None else None self.output_kind = output_kind self.prompt = prompt self.prompt_token_ids = prompt_token_ids @@ -178,9 +180,7 @@ def from_new_request( request_id=request.request_id, parent_req=parent_req, request_index=request_index, - lora_name=( - request.lora_request.name if request.lora_request is not None else None - ), + lora_request=request.lora_request, output_kind=output_kind, prompt=prompt, prompt_token_ids=request.prompt_token_ids, @@ -289,6 +289,7 @@ def _new_request_output( return RequestOutput( request_id=request_id, + lora_request=self.lora_request, prompt=self.prompt, prompt_token_ids=prompt_token_ids, prompt_logprobs=prompt_logprobs,