Merged
18 changes: 6 additions & 12 deletions csrc/mla_preprocess/op_host/mla_preprocess.h
@@ -127,6 +127,7 @@ struct OpParam {
     int32_t cacheMode;
     QuantMode quantMode;
     caffe2::TypeMeta inDtype;
+    bool enableInnerOut;
 };

class PpMatmulTilingApi
@@ -540,7 +541,8 @@ void MlaPreprocessTiling::SetMlapoWorkSpace()

 void MlaPreprocessTiling::SetTilingKey()
 {
-    uint64_t tilingKey = (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
+    uint64_t tilingKey = (static_cast<uint64_t>(opParam.enableInnerOut)) << 9;
+    tilingKey |= (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
 
     tilingKey |= static_cast<uint64_t>(opParam.cacheMode);
     tilingKey |= (static_cast<uint64_t>(opParam.quantMode) << 3);
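Taken together with the new KEY_*_INNER constants in op_kernel/mla_preprocess.h, the key packs cache mode into bits 0..2, quant mode into bits 3..7, the bf16 flag into bit 8, and now enableInnerOut into bit 9. A minimal standalone sketch of that layout (MakeTilingKey is illustrative only, not part of this PR):

// Sketch of the tiling-key bit layout implied by SetTilingKey();
// cross-checked against the KEY_* constants below (259 + 512 == 771).
#include <cstdint>

constexpr uint64_t MakeTilingKey(bool innerOut, bool isBf16,
                                 uint64_t quantMode, uint64_t cacheMode)
{
    return (static_cast<uint64_t>(innerOut) << 9) |  // new in this PR
           (static_cast<uint64_t>(isBf16) << 8) |
           (quantMode << 3) |
           cacheMode;
}

// bf16, cacheMode 3, quantMode 0, inner out enabled:
static_assert(MakeTilingKey(true, true, 0, 3) == 259 + 512, "key layout check");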
@@ -619,21 +621,12 @@ inline int get_op_mode(const MapType &mode_map, c10::optional<c10::string_view>
     return it->second;
 }

-// std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
-//     const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
-//     const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
-//     const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
-//     const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
-//     const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
-//     const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
-//     const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
-//     c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, at::Tensor &q_out0,
-//     at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
 std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
     const at::Tensor &hiddenState,
     const at::Tensor &wuk,
     c10::optional<c10::string_view> cache_mode,
-    c10::optional<c10::string_view> quant_mode
+    c10::optional<c10::string_view> quant_mode,
+    bool enable_inner_out
 )
 {
     auto cacheMode = get_op_mode(cache_mode_map, cache_mode, "krope_ctkv", "cache_mode");
@@ -661,6 +654,7 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
     opParam.cacheMode = static_cast<int32_t>(cacheMode);
     opParam.quantMode = static_cast<QuantMode>(quantMode);
     opParam.inDtype = hiddenState.options().dtype();
+    opParam.enableInnerOut = enable_inner_out;

     MlaTilingData tilingData;
     MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);
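For orientation, a call to the extended tiling entry point might look like this (a sketch with placeholder tensor names; the actual binding layer is outside this diff):

// Hypothetical call site; hiddenState and wuk are placeholder at::Tensor values.
auto [tilingBuf, workspaceBuf, blockDim] = mla_preprocess_tiling(
    hiddenState,
    wuk,
    /*cache_mode=*/"krope_ctkv",    // the default resolved by get_op_mode above
    /*quant_mode=*/c10::nullopt,    // fall back to the quant-mode default
    /*enable_inner_out=*/true);     // sets bit 9 of the tiling key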
3 changes: 3 additions & 0 deletions csrc/mla_preprocess/op_kernel/mla_preprocess.h
@@ -103,6 +103,9 @@ constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
 constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
 constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
 constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;
+constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER = 256 + 512;
+constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER = 257 + 512;
+constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER = 259 + 512;

 enum class QuantMode : int32_t {
     PER_TENSOR_ASYMM_QUANT = 0,
53 changes: 52 additions & 1 deletion csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
@@ -15,6 +15,7 @@

#include "mla_preprocess_mix_fp16.hpp"
#include "mla_preprocess_mix_bf16.hpp"
#include "mla_preprocess_mix_bf16_qdown.hpp"

#include "../op_host/tiling/mla_preprocess_tiling.h"

@@ -23,7 +24,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
     GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
     GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
     GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
-    GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR workspace, GM_ADDR tiling)
+    GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR innerOut, GM_ADDR workspace, GM_ADDR tiling)
 {
 #if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
     PRELOAD(2);
@@ -218,6 +219,54 @@ extern "C" __global__ __aicore__ void mla_preprocess(
             }
             break;
         }
+        case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
+            MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+                                           QuantMode::PER_TENSOR_ASYMM_QUANT>
+                opBf16Cm0Qm0Inner(mlaTilingData, tiling);
+            opBf16Cm0Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+                quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+                s1, s2, s3, s4, s5, innerOut);
+            if ASCEND_IS_AIC {
+                opBf16Cm0Qm0Inner.ProcessCube();
+            }
+            if ASCEND_IS_AIV {
+                opBf16Cm0Qm0Inner.ProcessVector();
+            }
+            break;
+        }
+        case KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER: {
+            MLAPO_BF16_INNER::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+                                           QuantMode::PER_TENSOR_ASYMM_QUANT>
+                opBf16Cm1Qm0Inner(mlaTilingData, tiling);
+            opBf16Cm1Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+                quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+                s1, s2, s3, s4, s5, innerOut);
+            if ASCEND_IS_AIC {
+                opBf16Cm1Qm0Inner.ProcessCube();
+            }
+            if ASCEND_IS_AIV {
+                opBf16Cm1Qm0Inner.ProcessVector();
+            }
+            break;
+        }
+        case KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER: {
+            MLAPO_BF16_INNER::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+                                           QuantMode::PER_TENSOR_ASYMM_QUANT>
+                opBf16Cm3Qm0Inner(mlaTilingData, tiling);
+            opBf16Cm3Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+                quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+                s1, s2, s3, s4, s5, innerOut);
+            if ASCEND_IS_AIC {
+                opBf16Cm3Qm0Inner.ProcessCube();
+            }
+            if ASCEND_IS_AIV {
+                opBf16Cm3Qm0Inner.ProcessVector();
+            }
+            break;
+        }
Comment on lines +222 to +269 (review comment from a Contributor, severity: high):

There is significant code duplication across the new case statements for _INNER keys. The logic inside case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER, case KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER, and case KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER is identical except for the CacheMode template parameter (0, 1, 3) for MLAPO_BF16_INNER::MLAOperation.

This duplication makes the code harder to read and maintain. Any future changes to this logic will need to be applied in three places, increasing the risk of errors.

To reduce duplication, you could use a helper function templated on the CacheMode to encapsulate the common logic. For example:

template <int CacheMode>
__aicore__ void ProcessInnerOp(MlaTilingData& mlaTilingData, GM_ADDR tiling, /* other params */) {
    MLAPO_BF16_INNER::MLAOperation<__bf16, CacheMode, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
                                   QuantMode::PER_TENSOR_ASYMM_QUANT>
        op(mlaTilingData, tiling);
    op.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
            quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
            bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
            s1, s2, s3, s4, s5, innerOut);
    if ASCEND_IS_AIC {
        op.ProcessCube();
    }
    if ASCEND_IS_AIV {
        op.ProcessVector();
    }
}

// Then in the switch statement:
case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER:
    ProcessInnerOp<0>(mlaTilingData, tiling, ...);
    break;
case KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER:
    ProcessInnerOp<1>(mlaTilingData, tiling, ...);
    break;
case KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER:
    ProcessInnerOp<3>(mlaTilingData, tiling, ...);
    break;

Since adding a new function might be a larger change, you could also use a macro to achieve a similar result within the current function body.
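For completeness, the macro variant alluded to above could look roughly like this (an untested sketch that folds the three identical case bodies into one expansion; MLAPO_BF16_INNER_CASE is a name invented here):

// Untested sketch; CM is the cache-mode constant (0, 1, or 3).
#define MLAPO_BF16_INNER_CASE(CM)                                              \
    case KEY_BF16_CACHEMODE_##CM##_QUANTMODE_0_INNER: {                        \
        MLAPO_BF16_INNER::MLAOperation<__bf16, CM, DataFormat::NZ,             \
                                       DataFormat::NZ, DataFormat::ND,         \
                                       QuantMode::PER_TENSOR_ASYMM_QUANT>      \
            op(mlaTilingData, tiling);                                         \
        op.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2,  \
                beta2, quantScale2, quantOffset2, gamma3, sin1, cos1, sin2,    \
                cos2, keycache, slotMapping, wuq, bias2, wuk, descale1,        \
                descale2, ctkvScale, qnopeScale, q, keycacheOut, q2,           \
                keycacheOut2, s1, s2, s3, s4, s5, innerOut);                   \
        if ASCEND_IS_AIC {                                                     \
            op.ProcessCube();                                                  \
        }                                                                      \
        if ASCEND_IS_AIV {                                                     \
            op.ProcessVector();                                                \
        }                                                                      \
        break;                                                                 \
    }

// In the dispatch switch:
//     MLAPO_BF16_INNER_CASE(0)
//     MLAPO_BF16_INNER_CASE(1)
//     MLAPO_BF16_INNER_CASE(3)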

         default: {
             break;
         }
@@ -256,6 +305,7 @@ extern void mla_preprocess_impl(
     void* keycache_out,
     void* q2,
     void* keycache_out2,
+    void* inner_out,
     void* workspace,
     void* tiling,
     const uint32_t block_dim)
@@ -288,6 +338,7 @@ extern void mla_preprocess_impl(
         keycache_out,
         q2,
         keycache_out2,
+        inner_out,
         workspace,
         tiling);
 }