
Conversation

@LHXuuu
Contributor

@LHXuuu LHXuuu commented Nov 6, 2025

What this PR does / why we need it?

While using the LLM Compressor quantization tool from the vLLM community to generate quantized weights, the vLLM Ascend engine needs to be adapted to support the compressed tensors quantization format. A brief sketch of the two quantization schemes follows the list below.

  1. Add AscendCompressedTensorsConfig to replace CompressedTensorsConfig in vllm.
  2. Support CompressedTensorsW8A8 static weight.
    • weight: per-channel, int8, symmetric; activation: per-tensor, int8, symmetric.
  3. Support CompressedTensorsW8A8Dynamic weight.
    • weight: per-channel, int8, symmetric; activation: per-token, int8, symmetric, dynamic.
  4. Modify the override_quantization_method in AscendQuantConfig.
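For clarity, a minimal sketch of what the two schemes compute, written in plain PyTorch with illustrative function names (this is not the vLLM Ascend implementation, which dispatches to torch_npu kernels):

```python
import torch

def quantize_weight_per_channel(w: torch.Tensor):
    # Symmetric per-channel int8: one scale per output channel (each row of w).
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    w_q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return w_q, scale.squeeze(1)

def quantize_act_static_per_tensor(x: torch.Tensor, scale: torch.Tensor):
    # Static per-tensor int8: a single scale calibrated offline and fixed at runtime.
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)

def quantize_act_dynamic_per_token(x: torch.Tensor):
    # Dynamic per-token int8: one scale per token (row of x), computed on the fly.
    scale = x.abs().amax(dim=-1, keepdim=True) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale
```

W8A8 static pairs the per-channel weight quantization with the static per-tensor activation path; W8A8 dynamic uses the same weights with per-token activation scales computed at runtime.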

Co-authored-by: taoqun110 [email protected]
Co-authored-by: chenxi-hh [email protected]

Does this PR introduce any user-facing change?

No

How was this patch tested?

@github-actions

github-actions bot commented Nov 6, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description, to help reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for w8a8 static and dynamic quantization using the compressed tensors format on Ascend hardware. The changes include a new AscendCompressedTensorsConfig, corresponding quantization schemes, and integration into the vLLM-Ascend platform and worker.

The implementation looks good overall, but I've found a few issues:

  • A critical bug in AscendCompressedTensorsConfig that could lead to a runtime crash due to a missing None check.
  • Some robustness issues, such as an unsafe list removal and the use of assert for configuration validation, which could cause crashes.
  • A performance issue in the w8a8 static quantization scheme where a transpose operation is inefficiently performed on every forward pass.

I've provided detailed comments and suggestions to address these points.

Comment on lines 111 to 113
if is_310p():
    # On 300I Duo platform, we need transpose again if
    # using nz. This transpose can be skipped in torchair.
    output = torch_npu.npu_quant_matmul(
        x,
        layer.weight.data.transpose(1, 0),
        layer.deq_scale,
        bias=bias,
        output_dtype=layer.params_dtype,
    )

Severity: high

The transpose operation on layer.weight.data is performed on every forward pass for the is_310p() case, which is inefficient. The transposed weight should be computed once and cached to improve performance. A good place for this one-time operation would be in process_weights_after_loading.

        if is_310p():
            # On 300I Duo platform, we need transpose again if
            # using nz. This transpose can be skipped in torchair.
            # The transpose is cached to avoid re-computation on every forward pass.
            if not hasattr(layer, "_weight_transposed_for_310p"):
                layer._weight_transposed_for_310p = layer.weight.data.transpose(1, 0).contiguous()
            output = torch_npu.npu_quant_matmul(
                x,
                layer._weight_transposed_for_310p,
                layer.deq_scale,
                bias=bias,
                output_dtype=layer.params_dtype,
            )
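Alternatively, a minimal sketch of hoisting the transpose into process_weights_after_loading so it runs only once at load time; the method layout is an assumption, not the merged implementation, and it reuses the is_310p / torch_npu helpers already imported by this file:

```python
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    if is_310p():
        # One-time transpose at load time; apply() can then pass layer.weight.data
        # straight to torch_npu.npu_quant_matmul without re-transposing on every call.
        layer.weight.data = layer.weight.data.transpose(1, 0).contiguous()
```

If the scheme already defines process_weights_after_loading (e.g. for NZ format conversion), the transpose would simply be appended there.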

@MengqingCao
Collaborator

MengqingCao commented Nov 7, 2025

Thanks for this great work! Could you please add an e2e test of w8a8 static and dynamic quant? A unit test is also expected, but we could add it in follow-up PRs.

And are there any accuracy and performance metrics for your PR?

also cc @wangxiyuan @22dimensions
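For reference, a minimal sketch of what such an e2e smoke test could look like with the vLLM Python API; the checkpoint path is a placeholder and this is not the test that was ultimately added:

```python
from vllm import LLM, SamplingParams

def test_w8a8_compressed_tensors_smoke():
    # Placeholder path to a compressed-tensors w8a8 checkpoint produced by LLM Compressor.
    llm = LLM(model="/path/to/Qwen3-w8a8-compressed-tensors", max_model_len=1024)
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    # A smoke test only checks that generation runs and returns non-empty text.
    assert outputs and outputs[0].outputs[0].text.strip()
```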

@MengqingCao
Collaborator

You can solve the DCO and lint issues by referring to the contributing doc at https://vllm-ascend.readthedocs.io/

@LHXuuu LHXuuu closed this Nov 7, 2025
@LHXuuu LHXuuu reopened this Nov 7, 2025
@LHXuuu
Contributor Author

LHXuuu commented Nov 7, 2025

Thanks for this great work! Could you please add an e2e test of w8a8 static and dynamic quant? A unit test is also expected, but we could add it in follow-up PRs.

And are there any accuracy and performance metrics for your PR?

also cc @wangxiyuan @22dimensions

Thanks for your reply. I’m currently running accuracy and performance tests. Once they’re complete, I’ll post them in the comment.

@LHXuuu LHXuuu force-pushed the compressor_tensor branch 3 times, most recently from bfc2302 to ad6ab6d Compare November 11, 2025 08:28
@LHXuuu
Contributor Author

LHXuuu commented Nov 12, 2025

@MengqingCao @wangxiyuan Hi! The precision results are shown in the table below. The w8a8 static checkpoint falls back to unquantized for all down_proj linear layers, while the w8a8 dynamic checkpoint is fully quantized.

Qwen3-32B precision test

|              | ceval | gsm8k | mmlu  |
|--------------|-------|-------|-------|
| BF16         | 88.94 | 95.45 | 89.27 |
| w8a8 static  | 89.03 | 95.45 | 88.91 |
| w8a8 dynamic | 88.72 | 96.51 | 89.16 |

@LHXuuu LHXuuu force-pushed the compressor_tensor branch 2 times, most recently from 79c3d6a to d8b7eed Compare November 13, 2025 06:25
@github-actions github-actions bot added the documentation label Nov 18, 2025
@LHXuuu LHXuuu force-pushed the compressor_tensor branch 4 times, most recently from 402a5f2 to 9d969ee Compare November 20, 2025 06:50
@github-actions

This pull request has conflicts, please resolve those before we can evaluate the pull request.


@MengqingCao MengqingCao added the ready and ready-for-test labels Nov 25, 2025
chenxi-hh and others added 8 commits November 26, 2025 14:32
@LHXuuu
Contributor Author

LHXuuu commented Nov 28, 2025

@MengqingCao @wangxiyuan Hello, this PR is ready to merge.

@wangxiyuan wangxiyuan merged commit bdc6697 into vllm-project:main Nov 28, 2025
22 checks passed
ChenCangtao pushed a commit to ChenCangtao/vllm-ascend that referenced this pull request Dec 3, 2025
Mercykid-bash pushed a commit to Mercykid-bash/vllm-ascend that referenced this pull request Dec 4, 2025
Mercykid-bash pushed a commit to Mercykid-bash/vllm-ascend that referenced this pull request Dec 4, 2025
"Falling back to UnquantizedLinearMethod")
return None

else:
Contributor


Redundant else; you could remove it.
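A minimal illustration of the suggested early-return shape; the guard condition and return value here are placeholders, not the actual names in this PR:

```python
def get_quant_method(self, layer, prefix):
    if self._should_fall_back(layer, prefix):  # hypothetical helper
        logger.warning("Falling back to UnquantizedLinearMethod")
        return None
    # No `else:` needed -- execution only reaches here when the guard did not return.
    return self._build_quant_method(layer, prefix)  # hypothetical helper
```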


@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.int8, torch.float16, torch.bfloat16]
Contributor


it's good to return (torch.int8, torch.float16, torch.bfloat16)


@classmethod
def get_config_filenames(cls) -> list[str]:
return []
Contributor


it's good to return ()

# Only symmetric weight quantization supported.
return is_8_bits and is_tensor and is_symmetric and is_static

def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
Contributor


_is_dynamic_token_w8a8 and _is_static_tensor_w8a8 are largely duplicated; please extract a shared method.
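A hedged sketch of the shared helper being asked for; the attribute names follow compressed-tensors' QuantizationArgs (num_bits, symmetric, strategy, dynamic), but the exact checks in the real scheme may differ:

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

def _is_w8a8(self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs,
             *, act_strategy: QuantizationStrategy, act_dynamic: bool) -> bool:
    # Shared checks for both schemes: 8-bit, symmetric, matching activation strategy.
    is_8_bits = weight_quant.num_bits == 8 and input_quant.num_bits == 8
    is_symmetric = weight_quant.symmetric and input_quant.symmetric
    act_matches = (input_quant.strategy == act_strategy
                   and bool(input_quant.dynamic) == act_dynamic)
    return is_8_bits and is_symmetric and act_matches

def _is_static_tensor_w8a8(self, weight_quant, input_quant) -> bool:
    return self._is_w8a8(weight_quant, input_quant,
                         act_strategy=QuantizationStrategy.TENSOR, act_dynamic=False)

def _is_dynamic_token_w8a8(self, weight_quant, input_quant) -> bool:
    return self._is_w8a8(weight_quant, input_quant,
                         act_strategy=QuantizationStrategy.TOKEN, act_dynamic=True)
```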

Meihan-chen pushed a commit to Meihan-chen/vllm-ascend that referenced this pull request Dec 5, 2025
Clorist33 pushed a commit to Clorist33/vllm-ascend that referenced this pull request Dec 9, 2025

Labels

documentation, module:core, module:quantization, module:tests, ready, ready-for-test
