
Conversation

@zzhx1 (Contributor) commented Nov 13, 2025

What this PR does / why we need it?

The flashcomm2 technical report notes that FC2 introduces fully redundant storage of the o_proj matrix, which puts pressure on memory. The report therefore proposes a compromise, o_proj tensor parallelism of 2 (otp2), but that introduces additional reduce-scatter communication.

We propose a shared linear feature (#2931) that distributes the weights layer by layer across the cards, avoiding TP splitting of o_proj and resolving the memory issue.

This PR depends on #3232 and #2931.
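For intuition, here is a minimal sketch of the layer-wise shared-weight idea; the names and layout are hypothetical and are not the shared_weight_layer.py API from #2931. Each rank keeps the full o_proj weight for only a subset of layers and broadcasts it to the other ranks right before that layer runs, instead of TP-splitting the matrix.

# Hedged sketch of layer-wise o_proj weight sharing; illustrative only.
import torch
import torch.distributed as dist


def owner_rank(layer_idx: int, world_size: int) -> int:
    # The full (unsplit) o_proj weight of each layer is resident on exactly one rank.
    return layer_idx % world_size


def fetch_oproj_weight(layer_idx, resident_weights, shape, device, group=None):
    # Every rank in the group must call this for the same layer, since
    # broadcast is a collective operation.
    src = owner_rank(layer_idx, dist.get_world_size(group))
    if dist.get_rank(group) == src:
        weight = resident_weights[layer_idx]        # owner sends its resident copy
    else:
        weight = torch.empty(shape, device=device)  # other ranks receive into a buffer
    dist.broadcast(weight, src=src, group=group)
    return weight

The prefetch_step seen later in shared_weight_layer.py suggests the real implementation overlaps this broadcast with the computation of earlier layers rather than issuing it synchronously as above.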

Flashcomm2 flowchart

[Flowchart image: PixPin_2025-11-14_13-37-39]

Does this PR introduce any user-facing change?

Yes. Enable the feature via the following environment variables:

export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces support for shared o_proj linear layers for Flashcomm2, which involves changes across configuration, distributed state management, and the attention mechanism. The core logic for shared weights is implemented in vllm_ascend/torchair/ops/shared_weight_layer.py, which has been refactored for better usability.

My review focuses on ensuring the correctness and robustness of the new feature. I've identified a few critical issues:

  • Incorrect validation logic for the new flashcomm2_oproj_shared configuration that could lead to silent failures.
  • A potential crash in the shared weight layer logic when handling a series with a single layer.

I have provided suggestions to fix these issues. The rest of the changes look good and the refactoring of the shared weight layer API is a nice improvement.

Comment on lines 137 to 140

if self.flashcomm2_oproj_tensor_parallel_size is None:
    raise AssertionError(
        "flashcomm2_oproj_shared must be enabled simultaneously with flashcomm2_oproj_tensor_parallel_size"
    )
critical

The validation if self.flashcomm2_oproj_tensor_parallel_size is None: is incorrect. The value of self.flashcomm2_oproj_tensor_parallel_size is an integer returned from get_flashcomm2_config_and_validate (which gets it from an environment variable with a default of 0), so it will never be None. The check should be against 0, as flashcomm2_oproj_shared requires flashcomm2_oproj_tensor_parallel_size to be greater than 0.

Suggested change

- if self.flashcomm2_oproj_tensor_parallel_size is None:
-     raise AssertionError(
-         "flashcomm2_oproj_shared must be enabled simultaneously with flashcomm2_oproj_tensor_parallel_size"
-     )
+ if self.flashcomm2_oproj_tensor_parallel_size == 0:
+     raise AssertionError(
+         "flashcomm2_oproj_shared must be enabled with flashcomm2_oproj_tensor_parallel_size > 0"
+     )
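To make the bot's reasoning concrete, a small illustration of why the value is always an int; the environment-variable name below is taken from the PR description and is an assumption about how the config reads it, not the actual parsing code:

import os

# int(...) over an env var with a default of "0" can only yield an integer,
# so an `is None` check can never fire.
flashcomm2_oproj_tp_size = int(
    os.environ.get("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", "0"))

assert flashcomm2_oproj_tp_size is not None   # always holds
enabled = flashcomm2_oproj_tp_size > 0        # the meaningful condition to validate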

self.layers.sort(key=lambda x: x.layer_idx)
self.num_layers = len(self.layers)
assert self.num_layers > 0, "No layers in the series"
assert self.prefetch_step >= 0 and self.prefetch_step <= self.num_layers - 2, "prefetch_step must be in [0, num_layers - 2]"
critical

The assertion self.prefetch_step <= self.num_layers - 2 will cause a crash if a shared weight series contains only one layer (self.num_layers == 1), because self.num_layers - 2 would be -1. For a single-layer series, prefetching is not applicable, and prefetch_step should be 0. To prevent this crash, the assertion should be adjusted to handle this edge case.

Suggested change

- assert self.prefetch_step >= 0 and self.prefetch_step <= self.num_layers - 2, "prefetch_step must be in [0, num_layers - 2]"
+ assert self.prefetch_step >= 0 and self.prefetch_step <= max(0, self.num_layers - 2), "prefetch_step must be in [0, num_layers - 2]"
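A quick standalone check of the edge case (plain Python, independent of vLLM) shows why the clamped bound is needed:

def validate_prefetch(prefetch_step: int, num_layers: int) -> None:
    assert num_layers > 0, "No layers in the series"
    # The original bound would be num_layers - 2 == -1 for a single-layer
    # series, so even prefetch_step == 0 would fail; max(0, ...) keeps 0 valid.
    assert 0 <= prefetch_step <= max(0, num_layers - 2), \
        "prefetch_step must be in [0, num_layers - 2]"


validate_prefetch(prefetch_step=0, num_layers=1)   # passes with the clamped bound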

Comment on lines 868 to 873

if flashcomm2_oproj_shared:
    if flashcomm2_oproj_tp_size is None:
        raise AssertionError(
            "flashcomm2_oproj_shared must be enabled simultaneously with flashcomm2_oproj_tensor_parallel_size"
        )
    logger.info("Enable Flashcomm2 with flashcomm2_oproj_shared")
high

This validation logic for flashcomm2_oproj_shared is redundant with the logic in vllm_ascend/ascend_config.py. It's better to have validation in one place to avoid inconsistencies. Since ascend_config.py is the configuration entry point, it's a better place for this check. Additionally, the check if flashcomm2_oproj_tp_size is None: is incorrect, as flashcomm2_oproj_tp_size is an integer. I've suggested a fix in ascend_config.py and recommend removing this redundant block.

@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch from a951ad1 to da0d630 on November 14, 2025 05:05
@zzhx1 zzhx1 changed the title from "Flashcomm2 use o_shared linear" to "[Feat] Flashcomm2 use o_shared linear" on Nov 14, 2025
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch 3 times, most recently from 47501e5 to 8bf8ed5 on November 14, 2025 09:14
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch 2 times, most recently from faaa68e to bfbda42 on November 17, 2025 05:08
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch from 60f1aac to 89c3923 on November 24, 2025 07:12
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch 3 times, most recently from 18803f9 to a9fae57 on December 1, 2025 08:28

github-actions bot commented Dec 1, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch 4 times, most recently from fd7c9fa to ba1a760 on December 4, 2025 07:40
@zzhx1 (Contributor, Author) commented Dec 4, 2025

@wangxiyuan this PR is ready. Please check again, and if there are no issues, please help merge it.

@jianzs jianzs added the ready (read for review) and ready-for-test (start test by label for PR) labels on Dec 4, 2025
kv_no_split.contiguous(), need_gather_q_kv)

if self.fc2_enable and is_hidden_layer(self.vllm_config, self.o_proj):
    reach_layer_for_shared_weight_series(self.o_proj)
Contributor:

Why is the first broadcast performed here? I think it's not general enough, because other models in the profile-run phase are not aware of the information related to the o_proj layer. Should the first broadcast be performed after post_process_after_loading_for_shared_weight_series instead?

Contributor:

That's already included in post_process_after_loading_for_shared_weight_series. See https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/torchair/ops/shared_weight_layer.py#L73

Contributor:

OK, understood. So why do we broadcast in the profile run if the first broadcast is already included in post_process_after_loading_for_shared_weight_series?

Contributor:

To handle multi-DP cases. When some DP ranks are running dummy_run, they should also broadcast their weights to the DP ranks executing the model.
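A hedged sketch of that point, with illustrative names only: the broadcast is a collective, so DP ranks executing only a dummy run still have to enter it, otherwise the ranks running the real model would block waiting for them.

import torch.distributed as dist


def oproj_forward(layer_idx, weight_buf, owner_rank, group, hidden=None):
    # Every DP rank must reach this collective for the same layer index,
    # whether it is processing real tokens or a dummy batch.
    dist.broadcast(weight_buf, src=owner_rank, group=group)
    if hidden is None:          # dummy_run / profile-run path: no real input
        return None
    return hidden @ weight_buf.t()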


github-actions bot commented Dec 5, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch 6 times, most recently from 8404356 to 5f7c45c on December 6, 2025 03:51
from typing import Callable, Optional

import torch
import torch.distributed as dist
@wangxiyuan (Collaborator) commented Dec 6, 2025

This is not an op; we should consider moving it to a better place.


github-actions bot commented Dec 6, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Signed-off-by: zzhx1 <[email protected]>
Co-authored-by: clrs97 <[email protected]>
Co-authored-by: Levi-JQ <[email protected]>
Signed-off-by: zzhx1 <[email protected]>
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch 2 times, most recently from 2427a23 to 0835a03 on December 6, 2025 17:39
Signed-off-by: zzhx1 <[email protected]>
@zzhx1 zzhx1 force-pushed the flashcomm_oshared branch from 0835a03 to 3a1bf01 on December 6, 2025 18:04
@zzhx1 (Contributor, Author) commented Dec 7, 2025

@wangxiyuan this PR is ready; please help merge it.
