
Conversation

@jeejeelee (Collaborator) commented Dec 14, 2025

Purpose

Propagate the full `LoRARequest` through to `RequestOutput`, so that `output.lora_request` is populated for each request instead of being missing.

Test Plan

The LoRA tests in CI cover this change, and it can also be verified locally with the following script:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
    "How old are you ?",
]

LORA_NAME_PATH_MAP = {
    "Alice": LoRARequest("Alice", 1, "charent/self_cognition_Alice"),
    "Bob": LoRARequest("Bob", 2, "charent/self_cognition_Bob"),
    "Cat": LoRARequest("Bob", 3, "charent/self_cognition_Bob"),
}


# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)


def main():
    # Create an LLM.
    llm = LLM(model="Qwen/Qwen3-0.6B", enable_lora=True)

    # One LoRA request per prompt; None means the base model handles that prompt.
    lora_req_lst = [
        LORA_NAME_PATH_MAP["Alice"],
        LORA_NAME_PATH_MAP["Bob"],
        None,
        LORA_NAME_PATH_MAP["Cat"],
        LORA_NAME_PATH_MAP["Alice"],
    ]

    outputs = llm.generate(prompts, sampling_params, lora_request=lora_req_lst)
    # Each RequestOutput should now carry the LoRARequest (or None) used for its prompt.
    for output in outputs:
        lora_req = output.lora_request
        print(lora_req)
        print("-" * 60)


if __name__ == "__main__":
    main()

Test Result

test_llama_tp.py should pass in CI, and the script above should print the following:

LoRARequest(lora_name='Alice', lora_int_id=1, lora_path='charent/self_cognition_Alice', lora_local_path=None, long_lora_max_len=None, base_model_name=None, tensorizer_config_dict=None)
------------------------------------------------------------
LoRARequest(lora_name='Bob', lora_int_id=2, lora_path='charent/self_cognition_Bob', lora_local_path=None, long_lora_max_len=None, base_model_name=None, tensorizer_config_dict=None)
------------------------------------------------------------
None
------------------------------------------------------------
LoRARequest(lora_name='Bob', lora_int_id=3, lora_path='charent/self_cognition_Bob', lora_local_path=None, long_lora_max_len=None, base_model_name=None, tensorizer_config_dict=None)
------------------------------------------------------------
LoRARequest(lora_name='Alice', lora_int_id=1, lora_path='charent/self_cognition_Alice', lora_local_path=None, long_lora_max_len=None, base_model_name=None, tensorizer_config_dict=None)
------------------------------------------------------------
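For reference, once `lora_request` is populated, a caller can attribute each generation to its adapter. A minimal sketch building on the script above, using the standard `RequestOutput` fields `prompt`, `outputs`, and `lora_request`:

```python
# Minimal sketch: map each generation back to the adapter that produced it.
# `outputs` is the result of llm.generate(...) from the script above.
for output in outputs:
    adapter = output.lora_request.lora_name if output.lora_request else "base model"
    text = output.outputs[0].text  # first (and only) completion for this prompt
    print(f"[{adapter}] {output.prompt!r} -> {text!r}")
```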

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Jee Jee Li <[email protected]>

@mergify mergify bot added the llama and v1 labels Dec 14, 2025
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request addresses a bug where the lora_request information was missing from the RequestOutput. The changes correctly propagate the full LoRARequest object through the RequestState to the final RequestOutput. The fix is validated by an updated test in tests/lora/test_llama_tp.py, which now asserts that the lora_request is correctly populated in the output. The implementation is clean and effectively resolves the issue.
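For readers less familiar with the v1 engine internals, here is a rough sketch of the shape of that propagation; the `*Stub` class names and the `make_request_output` helper are illustrative assumptions, not the code actually changed in this PR:

```python
# Illustrative sketch only: class and field names are assumptions for this
# explanation, not the actual vLLM classes touched by the diff.
from dataclasses import dataclass
from typing import Optional


@dataclass
class LoRARequestStub:
    lora_name: str
    lora_int_id: int
    lora_path: str


@dataclass
class RequestStateStub:
    request_id: str
    # The fix amounts to keeping the LoRARequest alongside the request state...
    lora_request: Optional[LoRARequestStub] = None


@dataclass
class RequestOutputStub:
    request_id: str
    # ...and copying it onto the output object handed back to the caller.
    lora_request: Optional[LoRARequestStub] = None


def make_request_output(state: RequestStateStub) -> RequestOutputStub:
    return RequestOutputStub(
        request_id=state.request_id,
        lora_request=state.lora_request,
    )
```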

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) December 14, 2025 08:20
@github-actions github-actions bot added the ready label Dec 14, 2025
@LucasWilkinson (Collaborator)

Trying to fix this here too: #30744, but this makes sense in the interim; thanks!

@jeejeelee (Collaborator, Author)

@LucasWilkinson Yeah, my changes are mainly aimed at minimizing edge-case OOM issues.

@vllm-bot vllm-bot merged commit 0e391e7 into vllm-project:main Dec 16, 2025
48 of 50 checks passed
@jeejeelee jeejeelee deleted the fix-lora-output branch December 16, 2025 09:38
markmc added a commit to markmc/vllm that referenced this pull request Dec 16, 2025
Since `LoRARequest` is part of the public API - e.g. as a parameter to `LLM.generate()`,
or, since vllm-project#30636, a member of `RequestOutput` - this warning looks outdated:

```python
class LoRARequest(...):
    """
    Request for a LoRA adapter.

    Note that this class should be used internally. For online
    serving, it is recommended to not allow users to use this class but
    instead provide another layer of abstraction to prevent users from
    accessing unauthorized LoRA adapters.
    ...
    """
```

Indeed, `LoRARequest` seems to have been part of the public API since it was added in vllm-project#1804.

Signed-off-by: Mark McLoughlin <[email protected]>
weiyu0824 pushed a commit to weiyu0824/vllm that referenced this pull request Dec 16, 2025
TheCodeWrangler pushed a commit to TheCodeWrangler/vllm that referenced this pull request Dec 17, 2025

Labels

gpt-oss (Related to GPT-OSS models), llama (Related to Llama models), ready (ONLY add when PR is ready to merge/full CI is needed), v1

Projects

Status: Done

4 participants