Skip to content

Commit fe1f1b4

Browse files
committed
fix lint
1 parent 03aed97 commit fe1f1b4

38 files changed

+11947
-11949
lines changed
Lines changed: 123 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,123 +1,123 @@
1-
#include <iostream>
2-
#include <string>
3-
#include "acl/acl.h"
4-
#include "kernel_tiling/kernel_tiling.h"
5-
#include "tiling/platform/platform_ascendc.h"
6-
#include "tiling/tiling_data.h"
7-
#include "common_tiling.h"
8-
9-
10-
namespace bmm_trans {
11-
using namespace pp_matmul;
12-
13-
std::unordered_map<c10::string_view, uint16_t> quantModeMap = {
14-
{"per_channel_symm", 0},
15-
{"per_channel_asymm", 1},
16-
{"per_token_symm", 2},
17-
};
18-
19-
std::unordered_map<c10::string_view, uint16_t> formatModeMap = {
20-
{"ND", 0},
21-
{"NZ", 1},
22-
};
23-
24-
std::unordered_map<c10::ScalarType, TensorDType> atType2tensorDType = {
25-
{at::ScalarType::BFloat16, TensorDType::TENSOR_DTYPE_BF16},
26-
{at::ScalarType::Half, TensorDType::TENSOR_DTYPE_FLOAT16}};
27-
28-
// batch size -> memory index
29-
constexpr uint32_t MAX_CAPTURE_NUM = 1024;
30-
31-
template <typename MapType>
32-
inline int GetModeVal(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
33-
const char *mode_name)
34-
{
35-
std::string modeStr(mode_name);
36-
c10::string_view mode_str = mode_opt.value_or(default_mode);
37-
auto it = mode_map.find(mode_str);
38-
// if input mode is unsupported, use default value
39-
TORCH_CHECK(it != mode_map.end(), modeStr, c10::str(": Unsupported mode value ", mode_str));
40-
return it->second;
41-
}
42-
43-
std::tuple<at::Tensor, uint32_t> batch_matmul_transpose_tiling(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
44-
c10::optional<c10::string_view> format_mode,
45-
c10::optional<c10::string_view> quant_mode)
46-
{
47-
auto tensorAShape = tensor_a.sizes();
48-
auto tensorBShape = tensor_b.sizes();
49-
auto tensorCShape = tensor_c.sizes();
50-
uint32_t n;
51-
uint32_t block_dim;
52-
53-
//auto &platform = PlatformInfo::Instance();
54-
HardwareInfo hwInfo;
55-
std::map<c10::ScalarType, float> dTypeMap = {{at::ScalarType::Half, 2.0}, {at::ScalarType::BFloat16, 2.0}};
56-
57-
at::ScalarType aType = tensor_a.scalar_type();
58-
at::ScalarType bType = tensor_b.scalar_type();
59-
at::ScalarType cType = tensor_c.scalar_type();
60-
TORCH_CHECK(aType == bType && bType == cType, "tensor type is not the same");
61-
TORCH_CHECK((aType == at::ScalarType::BFloat16) || (aType == at::ScalarType::Half),
62-
"tensor type only support half or bf16");
63-
64-
TensorFormat formatMode = static_cast<TensorFormat>(GetModeVal(formatModeMap, format_mode, "ND", "format_mode"));
65-
MatMul::QuantMode quantMode =
66-
static_cast<MatMul::QuantMode>(GetModeVal(quantModeMap, quant_mode, "per_channel_symm", "quant_mode"));
67-
68-
TORCH_CHECK(tensorAShape.size() == 3, "batch size is not same between srcTensor and dstTensor");
69-
if (formatMode == TensorFormat::TENSOR_FORMAT_ND) {
70-
TORCH_CHECK(tensorBShape.size() == 3, "tensor shape should be dim3 in ND format");
71-
TORCH_CHECK(tensorAShape[2] == tensorBShape[1], "tensor shape is wrong");
72-
n = tensorBShape[2];
73-
} else {
74-
TORCH_CHECK(tensorBShape.size() == 4, "tensor shape should be dim4 in nz format");
75-
TORCH_CHECK(tensorAShape[2] == tensorBShape[2], "tensor shape is wrong");
76-
n = tensorBShape[1] * tensorBShape[3];
77-
}
78-
TORCH_CHECK(tensorAShape[1] == tensorBShape[0], "tensor shape is wrong");
79-
80-
OpShape opShape = {.batchSize = static_cast<uint32_t>(tensorAShape[1]),
81-
.m = static_cast<uint32_t>(tensorAShape[0]),
82-
.k = static_cast<uint32_t>(tensorAShape[2]),
83-
.n = n};
84-
pp_matmul::PpMatmulTilingData matmulTilingData = {
85-
.opShape = opShape,
86-
};
87-
auto dType = atType2tensorDType[aType];
88-
MatMulInfo mmInfo = {.batchSize = opShape.batchSize,
89-
.m = opShape.m,
90-
.k = opShape.k,
91-
.n = opShape.n,
92-
.dtypeA = dType,
93-
.dtypeB = dType,
94-
.dtypeC = dType,
95-
.formatB = formatMode,
96-
.mmType = MatMul::MatMulType::MATMUL_EIN_SUM,
97-
.inDtype = dTypeMap[aType],
98-
.outDtype = dTypeMap[cType],
99-
.quantMode = quantMode};
100-
GetPpMatmulTiling(mmInfo, hwInfo, block_dim, matmulTilingData);
101-
host_utils::PpMatmulTilingCheck(matmulTilingData);
102-
103-
// tiling
104-
int32_t batchIdx = opShape.m - 1;
105-
uint32_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData);
106-
static auto global_tiling_data = at::empty(
107-
{tilingSize * MAX_CAPTURE_NUM}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device()));
108-
if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) {
109-
aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, &matmulTilingData,
110-
tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
111-
} else {
112-
// Handle the case where batchIdx is out of range
113-
TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx);
114-
}
115-
at::Tensor tiling_tensor =
116-
at::from_blob(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, at::kByte);
117-
118-
return std::make_tuple(tiling_tensor, block_dim);
119-
120-
}
121-
122-
}
123-
1+
#include <iostream>
2+
#include <string>
3+
#include "acl/acl.h"
4+
#include "kernel_tiling/kernel_tiling.h"
5+
#include "tiling/platform/platform_ascendc.h"
6+
#include "tiling/tiling_data.h"
7+
#include "common_tiling.h"
8+
9+
10+
namespace bmm_trans {
11+
using namespace pp_matmul;
12+
13+
std::unordered_map<c10::string_view, uint16_t> quantModeMap = {
14+
{"per_channel_symm", 0},
15+
{"per_channel_asymm", 1},
16+
{"per_token_symm", 2},
17+
};
18+
19+
std::unordered_map<c10::string_view, uint16_t> formatModeMap = {
20+
{"ND", 0},
21+
{"NZ", 1},
22+
};
23+
24+
std::unordered_map<c10::ScalarType, TensorDType> atType2tensorDType = {
25+
{at::ScalarType::BFloat16, TensorDType::TENSOR_DTYPE_BF16},
26+
{at::ScalarType::Half, TensorDType::TENSOR_DTYPE_FLOAT16}};
27+
28+
// batch size -> memory index
29+
constexpr uint32_t MAX_CAPTURE_NUM = 1024;
30+
31+
template <typename MapType>
32+
inline int GetModeVal(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
33+
const char *mode_name)
34+
{
35+
std::string modeStr(mode_name);
36+
c10::string_view mode_str = mode_opt.value_or(default_mode);
37+
auto it = mode_map.find(mode_str);
38+
// if input mode is unsupported, use default value
39+
TORCH_CHECK(it != mode_map.end(), modeStr, c10::str(": Unsupported mode value ", mode_str));
40+
return it->second;
41+
}
42+
43+
std::tuple<at::Tensor, uint32_t> batch_matmul_transpose_tiling(const at::Tensor &tensor_a, const at::Tensor &tensor_b, at::Tensor &tensor_c,
44+
c10::optional<c10::string_view> format_mode,
45+
c10::optional<c10::string_view> quant_mode)
46+
{
47+
auto tensorAShape = tensor_a.sizes();
48+
auto tensorBShape = tensor_b.sizes();
49+
auto tensorCShape = tensor_c.sizes();
50+
uint32_t n;
51+
uint32_t block_dim;
52+
53+
//auto &platform = PlatformInfo::Instance();
54+
HardwareInfo hwInfo;
55+
std::map<c10::ScalarType, float> dTypeMap = {{at::ScalarType::Half, 2.0}, {at::ScalarType::BFloat16, 2.0}};
56+
57+
at::ScalarType aType = tensor_a.scalar_type();
58+
at::ScalarType bType = tensor_b.scalar_type();
59+
at::ScalarType cType = tensor_c.scalar_type();
60+
TORCH_CHECK(aType == bType && bType == cType, "tensor type is not the same");
61+
TORCH_CHECK((aType == at::ScalarType::BFloat16) || (aType == at::ScalarType::Half),
62+
"tensor type only support half or bf16");
63+
64+
TensorFormat formatMode = static_cast<TensorFormat>(GetModeVal(formatModeMap, format_mode, "ND", "format_mode"));
65+
MatMul::QuantMode quantMode =
66+
static_cast<MatMul::QuantMode>(GetModeVal(quantModeMap, quant_mode, "per_channel_symm", "quant_mode"));
67+
68+
TORCH_CHECK(tensorAShape.size() == 3, "batch size is not same between srcTensor and dstTensor");
69+
if (formatMode == TensorFormat::TENSOR_FORMAT_ND) {
70+
TORCH_CHECK(tensorBShape.size() == 3, "tensor shape should be dim3 in ND format");
71+
TORCH_CHECK(tensorAShape[2] == tensorBShape[1], "tensor shape is wrong");
72+
n = tensorBShape[2];
73+
} else {
74+
TORCH_CHECK(tensorBShape.size() == 4, "tensor shape should be dim4 in nz format");
75+
TORCH_CHECK(tensorAShape[2] == tensorBShape[2], "tensor shape is wrong");
76+
n = tensorBShape[1] * tensorBShape[3];
77+
}
78+
TORCH_CHECK(tensorAShape[1] == tensorBShape[0], "tensor shape is wrong");
79+
80+
OpShape opShape = {.batchSize = static_cast<uint32_t>(tensorAShape[1]),
81+
.m = static_cast<uint32_t>(tensorAShape[0]),
82+
.k = static_cast<uint32_t>(tensorAShape[2]),
83+
.n = n};
84+
pp_matmul::PpMatmulTilingData matmulTilingData = {
85+
.opShape = opShape,
86+
};
87+
auto dType = atType2tensorDType[aType];
88+
MatMulInfo mmInfo = {.batchSize = opShape.batchSize,
89+
.m = opShape.m,
90+
.k = opShape.k,
91+
.n = opShape.n,
92+
.dtypeA = dType,
93+
.dtypeB = dType,
94+
.dtypeC = dType,
95+
.formatB = formatMode,
96+
.mmType = MatMul::MatMulType::MATMUL_EIN_SUM,
97+
.inDtype = dTypeMap[aType],
98+
.outDtype = dTypeMap[cType],
99+
.quantMode = quantMode};
100+
GetPpMatmulTiling(mmInfo, hwInfo, block_dim, matmulTilingData);
101+
host_utils::PpMatmulTilingCheck(matmulTilingData);
102+
103+
// tiling
104+
int32_t batchIdx = opShape.m - 1;
105+
uint32_t tilingSize = sizeof(pp_matmul::PpMatmulTilingData);
106+
static auto global_tiling_data = at::empty(
107+
{tilingSize * MAX_CAPTURE_NUM}, at::TensorOptions().dtype(at::kByte).device(tensor_a.options().device()));
108+
if (batchIdx >= 0 && batchIdx < MAX_CAPTURE_NUM) {
109+
aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, &matmulTilingData,
110+
tilingSize, ACL_MEMCPY_HOST_TO_DEVICE);
111+
} else {
112+
// Handle the case where batchIdx is out of range
113+
TORCH_CHECK(false, "batchIdx is out of range: ", batchIdx);
114+
}
115+
at::Tensor tiling_tensor =
116+
at::from_blob(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * batchIdx), tilingSize, at::kByte);
117+
118+
return std::make_tuple(tiling_tensor, block_dim);
119+
120+
}
121+
122+
}
123+
Lines changed: 57 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,57 @@
1-
2-
// Licensed under the BSD 3-Clause License (the "License");
3-
// you may not use this file except in compliance with the License.
4-
// You may obtain a copy of the License at
5-
//
6-
// Unless required by applicable law or agreed to in writing, software
7-
// distributed under the License is distributed on an "AS IS" BASIS,
8-
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9-
// See the License for the specific language governing permissions and
10-
// limitations under the License.
11-
12-
#ifndef UTILS_COMMON_H
13-
#define UTILS_COMMON_H
14-
15-
namespace host_utils {
16-
17-
constexpr uint32_t BLK_SIZE_ALIN_FOR_INT64 = 4;
18-
constexpr uint32_t BLK_SIZE_ALIN_FOR_INT32 = 8;
19-
20-
inline uint64_t alinInt64Count(uint64_t count)
21-
{
22-
return (count + BLK_SIZE_ALIN_FOR_INT64 - 1) / BLK_SIZE_ALIN_FOR_INT64 * BLK_SIZE_ALIN_FOR_INT64;
23-
}
24-
25-
inline uint64_t alinInt32Count(uint64_t count)
26-
{
27-
return (count + BLK_SIZE_ALIN_FOR_INT32 - 1) / BLK_SIZE_ALIN_FOR_INT32 * BLK_SIZE_ALIN_FOR_INT32;
28-
}
29-
30-
template <typename T>
31-
inline T CeilDiv(const T dividend, const T divisor)
32-
{
33-
if (divisor == 0) {
34-
return UINT32_MAX;
35-
}
36-
return (dividend + divisor - 1) / divisor;
37-
}
38-
39-
template <typename T>
40-
inline T RoundUp(const T val, const T align = 16)
41-
{
42-
if (align == 0 || val + align - 1 < val) {
43-
return 0;
44-
}
45-
return (val + align - 1) / align * align;
46-
}
47-
48-
template <typename T>
49-
inline T RoundDown(const T val, const T align = 16)
50-
{
51-
if (align == 0) {
52-
return 0;
53-
}
54-
return val / align * align;
55-
}
56-
} // namespace host_utils
57-
#endif // UTILS_COMMON_H
1+
2+
// Licensed under the BSD 3-Clause License (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// Unless required by applicable law or agreed to in writing, software
7+
// distributed under the License is distributed on an "AS IS" BASIS,
8+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
// See the License for the specific language governing permissions and
10+
// limitations under the License.
11+
12+
#ifndef UTILS_COMMON_H
13+
#define UTILS_COMMON_H
14+
15+
namespace host_utils {
16+
17+
constexpr uint32_t BLK_SIZE_ALIN_FOR_INT64 = 4;
18+
constexpr uint32_t BLK_SIZE_ALIN_FOR_INT32 = 8;
19+
20+
inline uint64_t alinInt64Count(uint64_t count)
21+
{
22+
return (count + BLK_SIZE_ALIN_FOR_INT64 - 1) / BLK_SIZE_ALIN_FOR_INT64 * BLK_SIZE_ALIN_FOR_INT64;
23+
}
24+
25+
inline uint64_t alinInt32Count(uint64_t count)
26+
{
27+
return (count + BLK_SIZE_ALIN_FOR_INT32 - 1) / BLK_SIZE_ALIN_FOR_INT32 * BLK_SIZE_ALIN_FOR_INT32;
28+
}
29+
30+
template <typename T>
31+
inline T CeilDiv(const T dividend, const T divisor)
32+
{
33+
if (divisor == 0) {
34+
return UINT32_MAX;
35+
}
36+
return (dividend + divisor - 1) / divisor;
37+
}
38+
39+
template <typename T>
40+
inline T RoundUp(const T val, const T align = 16)
41+
{
42+
if (align == 0 || val + align - 1 < val) {
43+
return 0;
44+
}
45+
return (val + align - 1) / align * align;
46+
}
47+
48+
template <typename T>
49+
inline T RoundDown(const T val, const T align = 16)
50+
{
51+
if (align == 0) {
52+
return 0;
53+
}
54+
return val / align * align;
55+
}
56+
} // namespace host_utils
57+
#endif // UTILS_COMMON_H

0 commit comments

Comments
 (0)