
Commit d412565

[Cherry-pick]bmm_transpose to v011dev (#3995)
### What this PR does / why we need it?

Add a custom op to accelerate the DeepSeek model. The fused op combines bmm and transpose into a single kernel and is applied to the MLA module. Cherry-picked from commit c68ddc1.

### Does this PR introduce _any_ user-facing change?

No

---------

Signed-off-by: hust17yixuan <[email protected]>
1 parent 6391f06 commit d412565
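
For orientation, a minimal sketch of the un-fused computation the fused kernel is assumed to replace is shown below. The layout convention (A as [m, batch, k], B as [batch, k, n]) is inferred from the shape checks in the tiling code further down this page; the helper name and the ATen-based reference path are illustrative assumptions, not code from this commit.

#include <ATen/ATen.h>

// Hedged reference path (assumption): the fused bmm + transpose kernel is
// taken to be equivalent to transposing A into bmm's [batch, m, k] layout,
// running a batched matmul, and transposing the result back to [m, batch, n].
// The fusion avoids materializing the intermediate contiguous copies.
at::Tensor bmm_transpose_reference(const at::Tensor &a,   // [m, batch, k]
                                   const at::Tensor &b) { // [batch, k, n]
    return at::bmm(a.transpose(0, 1), b)  // [batch, m, n]
        .transpose(0, 1)                  // [m, batch, n], still a view
        .contiguous();                    // materialize the transposed result
}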

File tree

15 files changed: 1736 additions, 13 deletions


.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -12,8 +12,8 @@ repos:
   - id: codespell
     args: [
       --toml, pyproject.toml,
-      '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/mla_preprocess/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
-      '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND'
+      '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
+      '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND,ND'
     ]
     additional_dependencies:
       - tomli

CMakeLists.txt

Lines changed: 23 additions & 3 deletions
@@ -55,15 +55,34 @@ include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
 file(GLOB KERNEL_FILES
     ${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/*.cpp)
 
-ascendc_library(vllm_ascend_kernels SHARED
+set(VLLM_ASCEND_CUSTOM_OP
     ${KERNEL_FILES}
     ${CMAKE_CURRENT_SOURCE_DIR}/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp
+)
+
+set(VLLM_ASCEND_CUSTOM_OP_EXCLUDE
+    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_kernel/batch_matmul_transpose_kernel.cpp
+)
+
+if(SOC_VERSION STREQUAL "ASCEND310P3")
+    list(REMOVE_ITEM VLLM_ASCEND_CUSTOM_OP ${VLLM_ASCEND_CUSTOM_OP_EXCLUDE})
+endif()
+
+ascendc_library(vllm_ascend_kernels SHARED
+    ${VLLM_ASCEND_CUSTOM_OP}
 )
 
 message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")
 
-file(GLOB VLLM_ASCEND_SRC
-    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp)
+if(SOC_VERSION STREQUAL "ASCEND310P3")
+    file(GLOB VLLM_ASCEND_SRC
+        ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp)
+else()
+    file(GLOB VLLM_ASCEND_SRC
+        ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_host/tiling/tiling_data.cpp)
+endif()
 
 include_directories(
     ${pybind11_INCLUDE_DIRS}
@@ -73,6 +92,7 @@ include_directories(
     ${ASCEND_HOME_PATH}/include
     ${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform
     ${ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform
+    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/batch_matmul_transpose/op_host
 )
 
 set(
Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
#include <iostream>
#include <string>
#include "acl/acl.h"
#include "kernel_tiling/kernel_tiling.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_data.h"
#include "common_tiling.h"


namespace bmm_trans {
using namespace pp_matmul;

std::unordered_map<c10::string_view, uint16_t> quantModeMap = {
    {"per_channel_symm", 0},
    {"per_channel_asymm", 1},
    {"per_token_symm", 2},
};

std::unordered_map<c10::string_view, uint16_t> formatModeMap = {
    {"ND", 0},
    {"NZ", 1},
};

std::unordered_map<c10::ScalarType, TensorDType> atType2tensorDType = {
    {at::ScalarType::BFloat16, TensorDType::TENSOR_DTYPE_BF16},
    {at::ScalarType::Half, TensorDType::TENSOR_DTYPE_FLOAT16}};

// batch size -> memory index
constexpr uint32_t MAX_CAPTURE_NUM = 1024;

template <typename MapType>
inline int GetModeVal(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
                      const char *mode_name)
{
    std::string modeStr(mode_name);
    c10::string_view mode_str = mode_opt.value_or(default_mode);
    auto it = mode_map.find(mode_str);
    // if input mode is unsupported, use default value
    TORCH_CHECK(it != mode_map.end(), modeStr, c10::str(": Unsupported mode value ", mode_str));
    return it->second;
}

std::tuple<at::Tensor, uint32_t> batch_matmul_transpose_tiling(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
                                                               c10::optional<c10::string_view> format_mode,
                                                               c10::optional<c10::string_view> quant_mode)
{
    auto tensorAShape = tensor_a.sizes();
    auto tensorBShape = tensor_b.sizes();
    auto tensorCShape = tensor_c.sizes();
    uint32_t n;
    uint32_t block_dim;

    //auto &platform = PlatformInfo::Instance();
    HardwareInfo hwInfo;
    std::map<c10::ScalarType, float> dTypeMap = {{at::ScalarType::Half, 2.0}, {at::ScalarType::BFloat16, 2.0}};

    at::ScalarType aType = tensor_a.scalar_type();
    at::ScalarType bType = tensor_b.scalar_type();
    at::ScalarType cType = tensor_c.scalar_type();
    TORCH_CHECK(aType == bType && bType == cType, "tensor type is not the same");
    TORCH_CHECK((aType == at::ScalarType::BFloat16) || (aType == at::ScalarType::Half),
                "tensor type only support half or bf16");

    TensorFormat formatMode = static_cast<TensorFormat>(GetModeVal(formatModeMap, format_mode, "ND", "format_mode"));
    MatMul::QuantMode quantMode =
        static_cast<MatMul::QuantMode>(GetModeVal(quantModeMap, quant_mode, "per_channel_symm", "quant_mode"));

    TORCH_CHECK(tensorAShape.size() == 3, "batch size is not same between srcTensor and dstTensor");
    if (formatMode == TensorFormat::TENSOR_FORMAT_ND) {
        TORCH_CHECK(tensorBShape.size() == 3, "tensor shape should be dim3 in ND format");
        TORCH_CHECK(tensorAShape[2] == tensorBShape[1], "tensor shape is wrong");
        n = tensorBShape[2];
    } else {
        TORCH_CHECK(tensorBShape.size() == 4, "tensor shape should be dim4 in nz format");
        TORCH_CHECK(tensorAShape[2] == tensorBShape[2], "tensor shape is wrong");
        n = tensorBShape[1] * tensorBShape[3];
    }
    TORCH_CHECK(tensorAShape[1] == tensorBShape[0], "tensor shape is wrong");

    OpShape opShape = {.batchSize = static_cast<uint32_t>(tensorAShape[1]),
                       .m = static_cast<uint32_t>(tensorAShape[0]),
                       .k = static_cast<uint32_t>(tensorAShape[2]),
                       .n = n};
    pp_matmul::PpMatmulTilingData matmulTilingData = {
        .opShape = opShape,
    };
    auto dType = atType2tensorDType[aType];
    MatMulInfo mmInfo = {.batchSize = opShape.batchSize,
                         .m = opShape.m,
                         .k = opShape.k,
                         .n = opShape.n,
                         .dtypeA = dType,
                         .dtypeB = dType,
                         .dtypeC = dType,
                         .formatB = formatMode,
                         .mmType = MatMul::MatMulType::MATMUL_EIN_SUM,
                         .inDtype = dTypeMap[aType],
                         .outDtype = dTypeMap[cType],
                         .quantMode = quantMode};
    GetPpMatmulTiling(mmInfo, hwInfo, block_dim, matmulTilingData);
    host_utils::PpMatmulTilingCheck(matmulTilingData);

    // tiling
    int32_t batchIdx = opShape.m - 1;
    uint32_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData);
    static auto global_tiling_data = at::empty(
        {tilingSize * MAX_CAPTURE_NUM}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device()));
    if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) {
        aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, &matmulTilingData,
                    tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
    } else {
        // Handle the case where batchIdx is out of range
        TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx);
    }
    at::Tensor tiling_tensor =
        at::from_blob(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, at::kByte);

    return std::make_tuple(tiling_tensor, block_dim);
}

}
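
To make the calling convention of the tiling helper above concrete, a hypothetical host-side wrapper is sketched below. The ND input layout (A as [m, batch, k], B as [batch, k, n]) follows the checks inside the function; the wrapper itself, its name, and its defaults are assumptions for illustration, not part of the commit, and it assumes the declaration of batch_matmul_transpose_tiling is visible.

#include <tuple>
#include <ATen/ATen.h>

// Hypothetical caller (not from this PR). Tensors are assumed to be half or
// bf16 and already resident on the NPU device.
std::tuple<at::Tensor, uint32_t> prepare_bmm_transpose_tiling(const at::Tensor &a,  // [m, batch, k]
                                                              const at::Tensor &b,  // [batch, k, n], ND layout
                                                              at::Tensor &c)        // output buffer, same dtype
{
    // "ND" and "per_channel_symm" mirror the defaults GetModeVal falls back to
    // when no mode string is supplied. The returned tiling tensor aliases slot
    // (m - 1) of the device-resident cache of MAX_CAPTURE_NUM tiling entries;
    // block_dim is the launch dimension computed by GetPpMatmulTiling.
    return bmm_trans::batch_matmul_transpose_tiling(a, b, c, "ND", "per_channel_symm");
}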
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@

// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef UTILS_COMMON_H
#define UTILS_COMMON_H

namespace host_utils {

constexpr uint32_t BLK_SIZE_ALIN_FOR_INT64 = 4;
constexpr uint32_t BLK_SIZE_ALIN_FOR_INT32 = 8;

inline uint64_t alinInt64Count(uint64_t count)
{
    return (count + BLK_SIZE_ALIN_FOR_INT64 - 1) / BLK_SIZE_ALIN_FOR_INT64 * BLK_SIZE_ALIN_FOR_INT64;
}

inline uint64_t alinInt32Count(uint64_t count)
{
    return (count + BLK_SIZE_ALIN_FOR_INT32 - 1) / BLK_SIZE_ALIN_FOR_INT32 * BLK_SIZE_ALIN_FOR_INT32;
}

template <typename T>
inline T CeilDiv(const T dividend, const T divisor)
{
    if (divisor == 0) {
        return UINT32_MAX;
    }
    return (dividend + divisor - 1) / divisor;
}

template <typename T>
inline T RoundUp(const T val, const T align = 16)
{
    if (align == 0 || val + align - 1 < val) {
        return 0;
    }
    return (val + align - 1) / align * align;
}

template <typename T>
inline T RoundDown(const T val, const T align = 16)
{
    if (align == 0) {
        return 0;
    }
    return val / align * align;
}
} // namespace host_utils
#endif // UTILS_COMMON_H
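
As a quick check on how these helpers behave, the snippet below evaluates them for a concrete tile size. The included header name is an assumption, since the new file's path is not shown on this page.

#include <cassert>
#include <cstdint>
#include "common.h"  // assumed filename for the host_utils header above

int main() {
    const uint32_t rows = 1000, tile = 128;
    // 1000 rows need 8 tiles of 128; the aligned sizes bracket the row count.
    assert(host_utils::CeilDiv(rows, tile) == 8u);
    assert(host_utils::RoundUp(rows, tile) == 1024u);
    assert(host_utils::RoundDown(rows, tile) == 896u);
    return 0;
}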
