Skip to content

Commit be7b444

Browse files
committed
[Kernel] Add custom aclnn op DispatchGmmCombineDecode
Signed-off-by: wangqiankun <[email protected]>
1 parent b32ef53 commit be7b444

28 files changed

+7849
-26
lines changed

.github/workflows/vllm_ascend_test_nightly_a3.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,21 @@ jobs:
138138
image: 'swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/vllm-ascend:nightly-a3'
139139
tests: ${{ matrix.test_config.tests }}
140140
name: ${{ matrix.test_config.name }}
141+
custom-ops-tests:
142+
name: test ops
143+
if: always() && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch')
144+
needs: multi-node-tests
145+
strategy:
146+
fail-fast: false
147+
matrix:
148+
test_config:
149+
- name: custom-op-dispatch_gmm_combine_decode
150+
os: linux-aarch64-a3-16
151+
tests: tests/e2e/nightly/multicard_ops/test_dispatch_gmm_combine_decode.py
152+
uses: ./.github/workflows/_e2e_nightly_single_node.yaml
153+
with:
154+
runner: ${{ matrix.test_config.os }}
155+
input: 0.12.0
156+
image: 'swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/vllm-ascend:nightly-a3'
157+
tests: ${{ matrix.test_config.tests }}
158+
name: ${{ matrix.test_config.name }}

csrc/build_aclnn.sh

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
ROOT_DIR=$1
44
SOC_VERSION=$2
55

6-
git config --global --add safe.directory "$ROOT_DIR"
7-
86
if [[ "$SOC_VERSION" =~ ^ascend310 ]]; then
97
# ASCEND310P series
108
# currently, no custom aclnn ops for ASCEND310 series
@@ -17,37 +15,44 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
1715
SOC_ARG="ascend910b"
1816
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
1917
# ASCEND910C (A3) series
20-
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine"
18+
# depdendency: catlass
19+
git config --global --add safe.directory "$ROOT_DIR"
20+
CATLASS_PATH=${ROOT_DIR}/csrc/third_party/catlass/include
21+
if [[ ! -d "${CATLASS_PATH}" ]]; then
22+
echo "depdendency catlass is missing, try to fetch it..."
23+
if ! git submodule update --init --recursive; then
24+
echo "fetch failed"
25+
exit 1
26+
fi
27+
fi
28+
# depdendency: cann-toolkit file moe_distribute_base.h
29+
HCCL_STRUCT_FILE_PATH=$(find -L "${ASCEND_TOOLKIT_HOME}" -name "moe_distribute_base.h" 2>/dev/null | head -n1)
30+
if [ -z "$HCCL_STRUCT_FILE_PATH" ]; then
31+
echo "cannot find moe_distribute_base.h file in CANN env"
32+
exit 1
33+
fi
34+
# for dispatch_gmm_combine_decode
35+
yes | cp "${HCCL_STRUCT_FILE_PATH}" "${ROOT_DIR}/csrc/dispatch_gmm_combine_decode/op_kernel"
36+
# for dispatch_ffn_combine
37+
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
38+
TARGET_DIR="$SCRIPT_DIR/dispatch_ffn_combine/op_kernel/utils/"
39+
TARGET_FILE="$TARGET_DIR/$(basename "$HCCL_STRUCT_FILE_PATH")"
40+
41+
echo "*************************************"
42+
echo $HCCL_STRUCT_FILE_PATH
43+
echo "$TARGET_DIR"
44+
cp "$HCCL_STRUCT_FILE_PATH" "$TARGET_DIR"
45+
46+
sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE"
47+
sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE"
48+
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;lightning_indexer;sparse_flash_attention;dispatch_ffn_combine;dispatch_gmm_combine_decode;"
2149
SOC_ARG="ascend910_93"
2250
else
2351
# others
2452
# currently, no custom aclnn ops for other series
2553
exit 0
2654
fi
2755

28-
git submodule init
29-
git submodule update
30-
31-
32-
# For the compatibility of CANN8.5 and CANN8.3: copy and modify moe_distribute_base.h
33-
file_path=$(find /usr/local/Ascend/ascend-toolkit -name "moe_distribute_base.h" 2>/dev/null | head -n1)
34-
if [ -z "$file_path" ]; then
35-
echo "cannot find moe_distribute_base.h file in CANN env"
36-
exit 1
37-
fi
38-
39-
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
40-
TARGET_DIR="$SCRIPT_DIR/dispatch_ffn_combine/op_kernel/utils/"
41-
TARGET_FILE="$TARGET_DIR/$(basename "$file_path")"
42-
43-
echo "*************************************"
44-
echo $file_path
45-
echo "$TARGET_DIR"
46-
cp "$file_path" "$TARGET_DIR"
47-
48-
sed -i 's/struct HcclOpResParam {/struct HcclOpResParamCustom {/g' "$TARGET_FILE"
49-
sed -i 's/struct HcclRankRelationResV2 {/struct HcclRankRelationResV2Custom {/g' "$TARGET_FILE"
50-
5156

5257
# build custom ops
5358
cd csrc
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
2+
# This file is a part of the CANN Open Software.
3+
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
4+
# Please refer to the License for details. You may not use this file except in compliance with the License.
5+
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
6+
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
7+
# See LICENSE in the root of the software repository for the full text of the License.
8+
# ======================================================================================================================
9+
10+
set(_DISPATCH_GMM_INC_OPTS)
11+
if (EXISTS ${CMAKE_SOURCE_DIR}/third_party/catlass/include)
12+
list(APPEND _DISPATCH_GMM_INC_OPTS -I${CMAKE_SOURCE_DIR}/third_party/catlass/include)
13+
else()
14+
message(FATAL_ERROR "dependency catlass is missing, you can fetch it by running 'git submodule update --init --recursive'")
15+
endif()
16+
17+
add_ops_compile_options(
18+
OP_NAME DispatchGmmCombineDecode
19+
OPTIONS --cce-auto-sync=off
20+
-Wno-deprecated-declarations
21+
-Werror
22+
${_DISPATCH_GMM_INC_OPTS}
23+
)
24+
25+
target_sources(op_host_aclnnInner PRIVATE
26+
dispatch_gmm_combine_decode_def.cpp
27+
)
28+
29+
target_sources(opapi PRIVATE
30+
aclnn_dispatch_gmm_combine_decode.cpp
31+
)
32+
33+
if (NOT BUILD_OPEN_PROJECT)
34+
target_sources(aclnn_ops_train PRIVATE
35+
aclnn_dispatch_gmm_combine_decode.cpp
36+
)
37+
38+
target_sources(aclnn_ops_infer PRIVATE
39+
aclnn_dispatch_gmm_combine_decode.cpp
40+
)
41+
endif ()
42+
43+
target_sources(optiling PRIVATE
44+
dispatch_gmm_combine_decode_tiling.cpp
45+
)
46+
47+
target_include_directories(optiling PRIVATE
48+
${CMAKE_CURRENT_SOURCE_DIR}
49+
)
50+
51+
target_sources(opsproto PRIVATE
52+
dispatch_gmm_combine_decode_proto.cpp
53+
)
54+
55+
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_gmm_combine_decode.h")
56+
57+
install(FILES ${_GMM_Aclnn_header}
58+
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
59+
)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
3+
* This file is a part of the CANN Open Software.
4+
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
5+
* Please refer to the License for details. You may not use this file except in compliance with the License.
6+
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
7+
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
8+
* See LICENSE in the root of the software repository for the full text of the License.
9+
*/
10+
#include <string.h>
11+
#include "graph/types.h"
12+
#include "aclnn/opdev/platform.h"
13+
#include "aclnn_dispatch_gmm_combine_decode.h"
14+
15+
enum NnopbaseHcclServerType {
16+
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
17+
NNOPBASE_HCCL_SERVER_TYPE_MTE,
18+
NNOPBASE_HCCL_SERVER_TYPE_END
19+
};
20+
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
21+
22+
#ifdef __cplusplus
23+
extern "C" {
24+
#endif
25+
26+
extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
27+
const aclTensor *x,
28+
const aclTensor *expertIds,
29+
const aclTensor *gmm1PermutedWeight,
30+
const aclTensor *gmm1PermutedWeightScale,
31+
const aclTensor *gmm2Weight,
32+
const aclTensor *gmm2WeightScale,
33+
const aclTensor *expertSmoothScalesOptional,
34+
const aclTensor *expertScalesOptional,
35+
char *groupEp,
36+
int64_t epRankSize,
37+
int64_t epRankId,
38+
int64_t moeExpertNum,
39+
int64_t shareExpertNum,
40+
int64_t shareExpertRankNum,
41+
int64_t quantMode,
42+
int64_t globalBs,
43+
const aclTensor *output,
44+
const aclTensor *epRecvCount,
45+
uint64_t *workspaceSize,
46+
aclOpExecutor **executor);
47+
extern aclnnStatus aclnnInnerDispatchGmmCombineDecode(
48+
void *workspace,
49+
uint64_t workspaceSize,
50+
aclOpExecutor *executor,
51+
aclrtStream stream);
52+
53+
aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
54+
const aclTensor *x,
55+
const aclTensor *expertIds,
56+
const aclTensor *gmm1PermutedWeight,
57+
const aclTensor *gmm1PermutedWeightScale,
58+
const aclTensor *gmm2Weight,
59+
const aclTensor *gmm2WeightScale,
60+
const aclTensor *expertSmoothScalesOptional,
61+
const aclTensor *expertScalesOptional,
62+
char *groupEp,
63+
int64_t epRankSize,
64+
int64_t epRankId,
65+
int64_t moeExpertNum,
66+
int64_t shareExpertNum,
67+
int64_t shareExpertRankNum,
68+
int64_t quantMode,
69+
int64_t globalBs,
70+
const aclTensor *output,
71+
const aclTensor *epRecvCount,
72+
uint64_t *workspaceSize,
73+
aclOpExecutor **executor)
74+
{
75+
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
76+
gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize,
77+
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
78+
output, epRecvCount, workspaceSize, executor);
79+
}
80+
81+
aclnnStatus aclnnDispatchGmmCombineDecode(
82+
void *workspace,
83+
uint64_t workspaceSize,
84+
aclOpExecutor *executor,
85+
aclrtStream stream)
86+
{
87+
if (NnopbaseSetHcclServerType) {
88+
if (op::GetCurrentPlatformInfo().GetSocVersion() == op::SocVersion::ASCEND910B) {
89+
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_AICPU);
90+
} else {
91+
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
92+
}
93+
}
94+
return aclnnInnerDispatchGmmCombineDecode(workspace, workspaceSize, executor, stream);
95+
}
96+
97+
#ifdef __cplusplus
98+
}
99+
#endif
100+
101+
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
3+
* This file is a part of the CANN Open Software.
4+
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
5+
* Please refer to the License for details. You may not use this file except in compliance with the License.
6+
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
7+
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
8+
* See LICENSE in the root of the software repository for the full text of the License.
9+
*/
10+
#ifndef DISPATCH_GMM_COMBINE_DECODE
11+
#define DISPATCH_GMM_COMBINE_DECODE
12+
13+
#include "aclnn/acl_meta.h"
14+
15+
#ifdef __cplusplus
16+
extern "C" {
17+
#endif
18+
19+
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
20+
const aclTensor *x,
21+
const aclTensor *expertIds,
22+
const aclTensor *gmm1PermutedWeight,
23+
const aclTensor *gmm1PermutedWeightScale,
24+
const aclTensor *gmm2Weight,
25+
const aclTensor *gmm2WeightScale,
26+
const aclTensor *expertSmoothScalesOptional,
27+
const aclTensor *expertScalesOptional,
28+
char *groupEp,
29+
int64_t epRankSize,
30+
int64_t epRankId,
31+
int64_t moeExpertNum,
32+
int64_t shareExpertNum,
33+
int64_t shareExpertRankNum,
34+
int64_t quantMode,
35+
int64_t globalBs,
36+
const aclTensor *output,
37+
const aclTensor *epRecvCount,
38+
uint64_t *workspaceSize,
39+
aclOpExecutor **executor);
40+
41+
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode(
42+
void *workspace,
43+
uint64_t workspaceSize,
44+
aclOpExecutor *executor,
45+
aclrtStream stream);
46+
47+
#ifdef __cplusplus
48+
}
49+
#endif
50+
51+
#endif
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
3+
* This file is a part of the CANN Open Software.
4+
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
5+
* Please refer to the License for details. You may not use this file except in compliance with the License.
6+
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
7+
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
8+
* See LICENSE in the root of the software repository for the full text of the License.
9+
*/
10+
#include "register/op_def_registry.h"
11+
12+
namespace ops {
13+
class DispatchGmmCombineDecode : public OpDef
14+
{
15+
public:
16+
explicit DispatchGmmCombineDecode(const char *name) : OpDef(name)
17+
{
18+
this->Input("x")
19+
.ParamType(REQUIRED)
20+
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
21+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
22+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
23+
this->Input("expert_ids")
24+
.ParamType(REQUIRED)
25+
.DataType({ge::DT_INT32, ge::DT_INT32})
26+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
27+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
28+
this->Input("gmm1_permuted_weight")
29+
.ParamType(REQUIRED)
30+
.DataType({ge::DT_INT8, ge::DT_INT8})
31+
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
32+
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
33+
this->Input("gmm1_permuted_weight_scale")
34+
.ParamType(REQUIRED)
35+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
36+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
37+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
38+
this->Input("gmm2_weight")
39+
.ParamType(REQUIRED)
40+
.DataType({ge::DT_INT8, ge::DT_INT8})
41+
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
42+
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
43+
this->Input("gmm2_weight_scale")
44+
.ParamType(REQUIRED)
45+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
46+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
47+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
48+
this->Input("expert_smooth_scales")
49+
.ParamType(OPTIONAL)
50+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
51+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
52+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
53+
this->Input("expert_scales")
54+
.ParamType(OPTIONAL)
55+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
56+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
57+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
58+
this->Output("output")
59+
.ParamType(REQUIRED)
60+
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
61+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
62+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
63+
this->Output("ep_recv_count")
64+
.ParamType(REQUIRED)
65+
.DataType({ge::DT_INT32, ge::DT_INT32})
66+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
67+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
68+
this->Attr("group_ep").String();
69+
this->Attr("ep_rank_size").Int();
70+
this->Attr("ep_rank_id").Int();
71+
this->Attr("moe_expert_num").Int();
72+
this->Attr("share_expert_num").Int();
73+
this->Attr("share_expert_rank_num").Int();
74+
this->Attr("quant_mode").Int();
75+
this->Attr("global_bs").Int();
76+
77+
this->MC2().HcclGroup({"group_ep"});
78+
this->AICore().AddConfig("ascend910_93");
79+
}
80+
};
81+
82+
OP_ADD(DispatchGmmCombineDecode);
83+
} // namespace ops

0 commit comments

Comments
 (0)