diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 972abb39ae4..8e27f2dc707 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -2,4 +2,5 @@ self-hosted-runner: # Labels of self-hosted runner in array of strings. labels: - linux-arm64-npu-1 + - linux-arm64-npu-2 - linux-arm64-npu-4 diff --git a/.github/workflows/accuracy_report.yaml b/.github/workflows/accuracy_report.yaml new file mode 100644 index 00000000000..6da1f845db1 --- /dev/null +++ b/.github/workflows/accuracy_report.yaml @@ -0,0 +1,150 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +name: Accuracy Report +on: + workflow_dispatch: + inputs: + branch: + description: 'choose a dev branch to pr' + required: true + vllm-ascend-version: + description: 'what vllm-ascend version to accuracy test?' + required: true + type: string +jobs: + download: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + ref: ${{ github.event.inputs.branch }} + + - name: Debug List Artifacts + run: gh api /repos/${{ github.repository }}/actions/artifacts + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Query artifact run id for Qwen2.5-VL-7B-Instruct V0 latest artifact + id: get_Qwen2_5_VL_7B_Instruct_latest_run_id_V0 + run: | + ARTIFACT_JSON=$(gh api "repos/${{ github.repository }}/actions/artifacts") + RUN_ID=$(echo "$ARTIFACT_JSON" | \ + jq -r '[.artifacts[] | select(.name=="${{ github.event.inputs.vllm-ascend-version }}-Qwen2.5-VL-7B-Instruct-V0-report")] | sort_by(.created_at) | last | .workflow_run.id') + echo "runid=$RUN_ID" >> "$GITHUB_OUTPUT" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Query artifact run id for Qwen2.5-7B-Instruct V0 latest artifact + id: get_Qwen2_5_7B_Instruct_latest_run_id_V0 + run: | + ARTIFACT_JSON=$(gh api "repos/${{ github.repository }}/actions/artifacts") + RUN_ID=$(echo "$ARTIFACT_JSON" | \ + jq -r '[.artifacts[] | select(.name=="${{ github.event.inputs.vllm-ascend-version }}-Qwen2.5-7B-Instruct-V0-report")] | sort_by(.created_at) | last | .workflow_run.id') + echo "runid=$RUN_ID" >> "$GITHUB_OUTPUT" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Query artifact run id for Llama-3.1-8B-Instruct V0 latest artifact + id: get_Llama_3_1_8B_Instruct_latest_run_id_V0 + run: | + ARTIFACT_JSON=$(gh api "repos/${{ github.repository }}/actions/artifacts") + RUN_ID=$(echo "$ARTIFACT_JSON" | \ + jq -r '[.artifacts[] | select(.name=="${{ github.event.inputs.vllm-ascend-version }}-Llama-3.1-8B-Instruct-V0-report")] | sort_by(.created_at) | last | .workflow_run.id') + echo "runid=$RUN_ID" >> "$GITHUB_OUTPUT" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Query artifact run id for Qwen3-8B V0 latest artifact + id: get_Qwen3_8B_latest_run_id_V0 + run: | + ARTIFACT_JSON=$(gh api "repos/${{ github.repository }}/actions/artifacts") + RUN_ID=$(echo "$ARTIFACT_JSON" | \ + jq -r '[.artifacts[] | select(.name=="${{ github.event.inputs.vllm-ascend-version }}-Qwen3-8B-V0-report")] | sort_by(.created_at) | last | .workflow_run.id') + echo "runid=$RUN_ID" >> "$GITHUB_OUTPUT" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Download Qwen/Qwen2.5-VL-7B-Instruct V0 Artifact + uses: actions/download-artifact@v4 + with: + name: ${{ github.event.inputs.vllm-ascend-version }}-Qwen2.5-VL-7B-Instruct-V0-report + path: ./docs/source/developer_guide/evaluation/accuracy_report + github-token: ${{ secrets.GITHUB_TOKEN }} + repository: vllm-project/vllm-ascend + run-id: ${{ steps.get_Qwen2_5_VL_7B_Instruct_latest_run_id_V0.outputs.runid }} + + - name: Download Qwen/Qwen2.5-7B-Instruct Artifact + uses: actions/download-artifact@v4 + with: + name: ${{ github.event.inputs.vllm-ascend-version }}-Qwen2.5-7B-Instruct-V0-report + path: ./docs/source/developer_guide/evaluation/accuracy_report + github-token: ${{ secrets.GITHUB_TOKEN }} + repository: vllm-project/vllm-ascend + run-id: ${{ steps.get_Qwen2_5_7B_Instruct_latest_run_id_V0.outputs.runid }} + + - name: Download meta-llama/Llama-3.1-8B-Instruct Artifact + uses: actions/download-artifact@v4 + with: + name: ${{ github.event.inputs.vllm-ascend-version }}-Llama-3.1-8B-Instruct-V0-report + path: ./docs/source/developer_guide/evaluation/accuracy_report + github-token: ${{ secrets.GITHUB_TOKEN }} + repository: vllm-project/vllm-ascend + run-id: ${{ steps.get_Llama_3_1_8B_Instruct_latest_run_id_V0.outputs.runid }} + + - name: Download Qwen/Qwen3-8B Artifact + uses: actions/download-artifact@v4 + with: + name: ${{ github.event.inputs.vllm-ascend-version }}-Qwen3-8B-V0-report + path: ./docs/source/developer_guide/evaluation/accuracy_report + github-token: ${{ secrets.GITHUB_TOKEN }} + repository: vllm-project/vllm-ascend + run-id: ${{ steps.get_Qwen3_8B_latest_run_id_V0.outputs.runid }} + + - name: Display Files + working-directory: ./docs/source/developer_guide/evaluation/accuracy_report + run: | + cat ./Qwen2.5-VL-7B-Instruct.md + cat ./Llama-3.1-8B-Instruct.md + cat ./Qwen2.5-7B-Instruct.md + cat ./Qwen3-8B.md + + - name: Create Pull Request for markdown update + uses: peter-evans/create-pull-request@v7 + with: + token: ${{ secrets.PR_TOKEN }} + base: ${{ github.ref_name }} + branch: auto-pr/accuracy-test + commit-message: "Update accuracy report for ${{ github.event.inputs.branch }}" + add-paths: ./docs/source/developer_guide/evaluation/accuracy_report/*.md + title: "[Doc]Update accuracy report for ${{ github.event.inputs.branch }}" + body: | + The accuracy results running on Ascend NPU have changed, I'm updating the report. + Please review the changes. + + - [Workflow run][1] + - [Qwen2.5-7B-Instruct accuracy report][2] + - [Llama-3.1-8B-Instruct accuracy report][3] + - [Qwen2.5-VL-7B-Instruct accuracy report][4] + - [Qwen3-8B accuracy report][5] + + [1]: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + [2]: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ steps.get_Qwen2_5_7B_Instruct_latest_run_id_V0.outputs.runid }} + [3]: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ steps.get_Llama_3_1_8B_Instruct_latest_run_id_V0.outputs.runid }} + [4]: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ steps.get_Qwen2_5_VL_7B_Instruct_latest_run_id_V0.outputs.runid }} + [5]: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ steps.get_Qwen3_8B_latest_run_id_V0.outputs.runid }} \ No newline at end of file diff --git a/.github/workflows/accuracy_test.yaml b/.github/workflows/accuracy_test.yaml new file mode 100644 index 00000000000..33f3be7fdd0 --- /dev/null +++ b/.github/workflows/accuracy_test.yaml @@ -0,0 +1,203 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +name: Accuracy Tests + +on: + workflow_dispatch: + inputs: + vllm-version: + description: 'what vllm version to accuracy test?' + required: true + type: string + vllm-ascend-version: + description: 'what vllm-ascend version to accuracy test?' + required: true + type: string + models: + description: 'choose model(all/Qwen2.5-7B-Instruct/Llama-3.1-8B-Instruct/Qwen2.5-VL-7B-Instruct/Qwen3-8B)' + required: true + type: choice + options: + - all + - Qwen/Qwen2.5-7B-Instruct + - meta-llama/Llama-3.1-8B-Instruct + - Qwen/Qwen2.5-VL-7B-Instruct + - Qwen/Qwen3-8B + default: 'all' + +# Bash shells do not use ~/.profile or ~/.bashrc so these shells need to be explicitly +# declared as "shell: bash -el {0}" on steps that need to be properly activated. +# It's used to activate ascend-toolkit environment variables. +defaults: + run: + shell: bash -el {0} + +jobs: + model_tests: + name: Model Test - ${{ matrix.model_name }} + runs-on: 'linux-arm64-npu-2' + strategy: + matrix: + include: ${{ fromJSON( + (github.event.inputs.models == 'all' && '[{"model_name":"Qwen/Qwen2.5-7B-Instruct","output_file":"Qwen2.5-7B-Instruct"},{"model_name":"meta-llama/Llama-3.1-8B-Instruct","output_file":"Llama-3.1-8B-Instruct"},{"model_name":"Qwen/Qwen2.5-VL-7B-Instruct","output_file":"Qwen2.5-VL-7B-Instruct"}, {"model_name":"Qwen/Qwen3-8B","output_file":"Qwen3-8B"}]') || + (github.event.inputs.models == 'Qwen/Qwen2.5-7B-Instruct' && '[{"model_name":"Qwen/Qwen2.5-7B-Instruct","output_file":"Qwen2.5-7B-Instruct"}]') || + (github.event.inputs.models == 'meta-llama/Llama-3.1-8B-Instruct' && '[{"model_name":"meta-llama/Llama-3.1-8B-Instruct","output_file":"Llama-3.1-8B-Instruct"}]') || + (github.event.inputs.models == 'Qwen/Qwen2.5-VL-7B-Instruct' && '[{"model_name":"Qwen/Qwen2.5-VL-7B-Instruct","output_file":"Qwen2.5-VL-7B-Instruct"}]') || + (github.event.inputs.models == 'Qwen/Qwen3-8B' && '[{"model_name":"Qwen/Qwen3-8B","output_file":"Qwen3-8B"}]') + ) }} + fail-fast: false + + container: + image: quay.io/ascend/cann:8.0.0-910b-ubuntu22.04-py3.10 + env: + HF_ENDPOINT: https://hf-mirror.com + HF_TOKEN: ${{ secrets.HF_TOKEN }} + DATASET_SOURCE: ModelScope + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Check npu and CANN info + run: | + npu-smi info + cat /usr/local/Ascend/ascend-toolkit/latest/"$(uname -i)"-linux/ascend_toolkit_install.info + + - name: Config mirrors + run: | + sed -i 's|ports.ubuntu.com|mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list + pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + apt-get update -y + apt install git -y + git config --global url."https://gh-proxy.test.osinfra.cn/https://github.com/".insteadOf https://github.com/ + + - name: Install system dependencies + run: | + apt-get -y install `cat packages.txt` + apt-get -y install gcc g++ cmake libnuma-dev + + + - name: Install system dependencies + run: | + apt-get -y install `cat packages.txt` + apt-get -y install gcc g++ cmake libnuma-dev + + - name: Checkout vllm-project/vllm repo + uses: actions/checkout@v4 + with: + repository: vllm-project/vllm + path: ./vllm-empty + ref: ${{ github.event.inputs.vllm-version }} + + - name: Install vllm-project/vllm from source + working-directory: ./vllm-empty + run: VLLM_TARGET_DEVICE=empty pip install -e . + + + - name: Checkout vllm-project/vllm-ascend repo + uses: actions/checkout@v4 + with: + repository: vllm-project/vllm-ascend + path: ./vllm-ascend + ref: ${{ github.event.inputs.vllm-ascend-version }} + fetch-depth: 0 + + - name: Install pta + run: | + if [ ! -d /root/.cache/pta ]; then + mkdir -p /root/.cache/pta + fi + if [ ! -f /root/.cache/pta/torch_npu-2.5.1.dev20250320-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl ]; then + cd /root/.cache/pta + rm -rf pytorch_v2.5.1_py310* + wget https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/Daily/v2.5.1/20250320.3/pytorch_v2.5.1_py310.tar.gz + tar -zxvf pytorch_v2.5.1_py310.tar.gz + fi + pip install /root/.cache/pta/torch_npu-2.5.1.dev20250320-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl + + - name: Install vllm-project/vllm-ascend + working-directory: ./vllm-ascend + run: | + pip install -r requirements-dev.txt + pip install -e . + + - name: Checkout EleutherAI/lm-evaluation-harness repo + uses: actions/checkout@v4 + with: + repository: EleutherAI/lm-evaluation-harness + path: ./lm-eval + fetch-depth: 0 + + - name: Install EleutherAI/lm-evaluation-harness + working-directory: ./lm-eval + run: | + pip install -e . + pip install ray datasets==2.16.0 transformers==4.50.3 huggingface-hub==0.29.3 + + - name: Collect version info + run: | + for dir in /usr/local/Ascend/ascend-toolkit/*; do + dname=$(basename "$dir") + if [ "$dname" != "latest" ]; then + TOOLKIT_DIR="$dname" + break + fi + done + INFO_FILE="/usr/local/Ascend/ascend-toolkit/${TOOLKIT_DIR}/$(uname -i)-linux/ascend_toolkit_install.info" + CANN_VERSION=$(grep "version=" "$INFO_FILE" \ + | head -n1 \ + | cut -d'=' -f2 \ + | tr -d '"') + { + echo "CANN_VERSION=$CANN_VERSION" + pip show torch | grep "Version:" | awk '{print "TORCH_VERSION="$2}' + pip show torch_npu | grep "Version:" | awk '{print "TORCH_NPU_VERSION="$2}' + pip show vllm | grep "Version:" | awk '{print "VLLM_VERSION="$2}' | sed 's/+.*//' + } >> "$GITHUB_ENV" + + - name: Print versions + run: | + echo "CANN: ${{ env.CANN_VERSION }}" + echo "Torch NPU: ${{ env.TORCH_NPU_VERSION }}" + echo "Torch: ${{ env.TORCH_VERSION }}" + echo "vLLM: ${{ env.VLLM_VERSION }}" + + - name: Run Accuracy Test for V0 + working-directory: ./benchmarks + env: + VLLM_USE_V1: 0 + PYTORCH_NPU_ALLOC_CONF: max_split_size_mb:256 + run: | + mkdir -p ./accuracy/V0 + python ./scripts/run_accuracy.py \ + --model "${{ matrix.model_name }}" \ + --output "./accuracy/V0/${{ matrix.output_file }}.md" \ + --vllm_ascend_version "${{ github.event.inputs.vllm-ascend-version }}" \ + --cann_version "${{ env.CANN_VERSION }}" \ + --torch_npu_version "${{ env.TORCH_NPU_VERSION }}" \ + --torch_version "${{ env.TORCH_VERSION }}" \ + --vllm_version "${{ env.VLLM_VERSION }}" + + - name: Upload Report for V0 + uses: actions/upload-artifact@v4 + with: + name: "${{ github.event.inputs.vllm-ascend-version }}-${{ matrix.output_file }}-V0-report" + path: ./benchmarks/accuracy/V0/${{ matrix.output_file }}.md + if-no-files-found: warn + retention-days: 90 + overwrite: true diff --git a/benchmarks/scripts/run_accuracy.py b/benchmarks/scripts/run_accuracy.py new file mode 100644 index 00000000000..18579d64ec0 --- /dev/null +++ b/benchmarks/scripts/run_accuracy.py @@ -0,0 +1,231 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import argparse +import gc +import json +import multiprocessing +import sys +from multiprocessing import Queue + +import lm_eval +import torch + +UNIMODAL_MODEL_NAME = [ + "Qwen/Qwen2.5-7B-Instruct", "meta-llama/Llama-3.1-8B-Instruct", + "Qwen/Qwen3-8B" +] +UNIMODAL_TASK = ["ceval-valid", "mmlu", "gsm8k"] +MULTIMODAL_NAME = ["Qwen/Qwen2.5-VL-7B-Instruct"] +MULTIMODAL_TASK = ["mmmu_val"] + +batch_size_dict = {"ceval-valid": 1, "mmlu": 1, "gsm8k": "auto", "mmmu_val": 1} + +MODEL_RUN_INFO = { + "Qwen/Qwen2.5-7B-Instruct": + ("export MODEL_AEGS='{model}, max_model_len=4096,dtype=auto,tensor_parallel_size=2,gpu_memory_utilization=0.6'\n" + "lm_eval --model vllm --modlel_args $MODEL_ARGS --tasks {datasets} \ \n" + "--apply_chat_template --fewshot_as_multiturn --num_fewshot 5 --batch_size 1" + ), + "LLM-Research/Meta-Llama-3.1-8B-Instruct": + ("export MODEL_AEGS='{model}, max_model_len=4096,dtype=auto,tensor_parallel_size=2,gpu_memory_utilization=0.6'\n" + "lm_eval --model vllm --modlel_args $MODEL_ARGS --tasks {datasets} \ \n" + "--apply_chat_template --fewshot_as_multiturn --num_fewshot 5 --batch_size 1" + ), + "Qwen/Qwen3-8B": + ("export MODEL_AEGS='{model}, max_model_len=4096,dtype=auto,tensor_parallel_size=2,gpu_memory_utilization=0.6'\n" + "lm_eval --model vllm --modlel_args $MODEL_ARGS --tasks {datasets} \ \n" + "--apply_chat_template --fewshot_as_multiturn --num_fewshot 5 --batch_size 1" + ), + "Qwen/Qwen2.5-VL-7B-Instruct": + ("export MODEL_AEGS='{model}, max_model_len=8192,dtype=auto,tensor_parallel_size=2,max_images=2'\n" + "lm_eval --model vllm-vlm --modlel_args $MODEL_ARGS --tasks {datasets} \ \n" + "--apply_chat_template --fewshot_as_multiturn --batch_size 1"), +} + + +def run_accuracy_unimodal(queue, model, dataset): + try: + model_args = f"pretrained={model},max_model_len=4096,dtype=auto,tensor_parallel_size=2,gpu_memory_utilization=0.6" + results = lm_eval.simple_evaluate( + model="vllm", + model_args=model_args, + tasks=dataset, + apply_chat_template=True, + fewshot_as_multiturn=True, + batch_size=batch_size_dict[dataset], + num_fewshot=5, + ) + print(f"Success: {model} on {dataset}") + measured_value = results["results"] + queue.put(measured_value) + except Exception as e: + print(f"Error in run_accuracy_unimodal: {e}") + queue.put(e) + sys.exit(1) + finally: + torch.npu.empty_cache() + gc.collect() + + +def run_accuracy_multimodal(queue, model, dataset): + try: + model_args = f"pretrained={model},max_model_len=8192,dtype=auto,tensor_parallel_size=2,max_images=2" + results = lm_eval.simple_evaluate( + model="vllm-vlm", + model_args=model_args, + tasks=dataset, + apply_chat_template=True, + fewshot_as_multiturn=True, + batch_size=batch_size_dict[dataset], + ) + print(f"Success: {model} on {dataset}") + measured_value = results["results"] + queue.put(measured_value) + except Exception as e: + print(f"Error in run_accuracy_multimodal: {e}") + queue.put(e) + sys.exit(1) + finally: + torch.npu.empty_cache() + gc.collect() + + +def generate_md(model_name, tasks_list, args, datasets): + run_cmd = MODEL_RUN_INFO[model_name].format(model=model_name, + datasets=datasets) + model = model_name.split("/")[1] + preamble = f"""# {model} Accuracy Test +