
Commit e98c3c9

h1074112368 authored and wangxiyuan committed
mlapo add qdown output (vllm-project#4707)
### What this PR does / why we need it?
This PR adds support in the mlapo (MLA preprocess) operation for returning the qdown output.

### Does this PR introduce _any_ user-facing change?
The mlapo operation gains a new `enable_inner_out` input.

### How was this patch tested?
CI passed with new and existing tests.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

---------

Signed-off-by: h1074112368 <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
1 parent 2648724 commit e98c3c9

File tree

8 files changed: +3136 −26 lines


csrc/mla_preprocess/op_host/mla_preprocess.h

Lines changed: 6 additions & 12 deletions
```diff
@@ -127,6 +127,7 @@ struct OpParam {
     int32_t cacheMode;
     QuantMode quantMode;
     caffe2::TypeMeta inDtype;
+    bool enableInnerOut;
 };
 
 class PpMatmulTilingApi
@@ -540,7 +541,8 @@ void MlaPreprocessTiling::SetMlapoWorkSpace()
 
 void MlaPreprocessTiling::SetTilingKey()
 {
-    uint64_t tilingKey = (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
+    uint64_t tilingKey = (static_cast<uint64_t>(opParam.enableInnerOut)) << 9;
+    tilingKey |= (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
 
     tilingKey |= static_cast<uint64_t>(opParam.cacheMode);
     tilingKey |= (static_cast<uint64_t>(opParam.quantMode) << 3);
@@ -619,21 +621,12 @@ inline int get_op_mode(const MapType &mode_map, c10::optional<c10::string_view>
     return it->second;
 }
 
-// std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
-//     const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
-//     const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
-//     const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
-//     const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
-//     const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
-//     const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
-//     const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
-//     c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, at::Tensor &q_out0,
-//     at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
 std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
     const at::Tensor &hiddenState,
     const at::Tensor &wuk,
     c10::optional<c10::string_view> cache_mode,
-    c10::optional<c10::string_view> quant_mode
+    c10::optional<c10::string_view> quant_mode,
+    bool enable_inner_out
 )
 {
     auto cacheMode = get_op_mode(cache_mode_map, cache_mode, "krope_ctkv", "cache_mode");
@@ -661,6 +654,7 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
     opParam.cacheMode = static_cast<int32_t>(cacheMode);
     opParam.quantMode = static_cast<QuantMode>(quantMode);
     opParam.inDtype = hiddenState.options().dtype();
+    opParam.enableInnerOut = enable_inner_out;
 
     MlaTilingData tilingData;
     MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);
```
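For context on the host-side change: `SetTilingKey` packs each field into its own bit range, and `enableInnerOut` now occupies bit 9, just above the dtype bit. The following is a minimal standalone sketch, not code from this patch; `MakeTilingKey` is a hypothetical helper that mirrors the composition shown above and checks one worked value against the kernel-side key constants.

```cpp
#include <cstdint>

// Hypothetical helper reproducing the bit packing in SetTilingKey().
constexpr uint64_t MakeTilingKey(bool enableInnerOut, bool isBF16,
                                 uint32_t cacheMode, uint32_t quantMode)
{
    uint64_t key = static_cast<uint64_t>(enableInnerOut) << 9;  // new inner-out bit
    key |= static_cast<uint64_t>(isBF16) << 8;                  // dtype bit
    key |= static_cast<uint64_t>(quantMode) << 3;               // quant mode, bits 3..7
    key |= cacheMode;                                           // cache mode, bits 0..2
    return key;
}

// Worked example: bf16 input, cacheMode 1, PER_TENSOR_ASYMM_QUANT (0), inner
// out enabled -> (1 << 9) | (1 << 8) | 1 == 769 == 257 + 512, matching
// KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER in the kernel header below.
static_assert(MakeTilingKey(true, true, 1, 0) == 257 + 512, "tiling key layout");
```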

csrc/mla_preprocess/op_kernel/mla_preprocess.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -103,6 +103,9 @@ constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
 constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
 constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
 constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;
+constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER = 256 + 512;
+constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER = 257 + 512;
+constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER = 259 + 512;
 
 enum class QuantMode : int32_t {
     PER_TENSOR_ASYMM_QUANT = 0,
```
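Read against `SetTilingKey` above, the `+ 512` in each new constant is exactly bit 9, the `enableInnerOut` flag. A quick consistency sketch, assuming that bit assignment (`kInnerOutBit` is a hypothetical name, not from the source):

```cpp
#include <cstdint>

constexpr uint32_t kInnerOutBit = 1u << 9;  // 512, mirrors the `<< 9` in SetTilingKey

// Each *_INNER key is its base key with the inner-out bit set.
static_assert((256u | kInnerOutBit) == 256 + 512, "CACHEMODE_0 inner key");
static_assert((257u | kInnerOutBit) == 257 + 512, "CACHEMODE_1 inner key");
static_assert((259u | kInnerOutBit) == 259 + 512, "CACHEMODE_3 inner key");
```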

csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp

Lines changed: 52 additions & 1 deletion
```diff
@@ -15,6 +15,7 @@
 
 #include "mla_preprocess_mix_fp16.hpp"
 #include "mla_preprocess_mix_bf16.hpp"
+#include "mla_preprocess_mix_bf16_qdown.hpp"
 
 #include "../op_host/tiling/mla_preprocess_tiling.h"
 
@@ -23,7 +24,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
     GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
     GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
     GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
-    GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR workspace, GM_ADDR tiling)
+    GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR innerOut, GM_ADDR workspace, GM_ADDR tiling)
 {
 #if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
     PRELOAD(2);
@@ -218,6 +219,54 @@ extern "C" __global__ __aicore__ void mla_preprocess(
             }
             break;
         }
+        case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
+            MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+                                           QuantMode::PER_TENSOR_ASYMM_QUANT>
+                opBf16Cm0Qm0Inner(mlaTilingData, tiling);
+            opBf16Cm0Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+                quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+                s1, s2, s3, s4, s5, innerOut);
+            if ASCEND_IS_AIC {
+                opBf16Cm0Qm0Inner.ProcessCube();
+            }
+            if ASCEND_IS_AIV {
+                opBf16Cm0Qm0Inner.ProcessVector();
+            }
+            break;
+        }
+        case KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER: {
+            MLAPO_BF16_INNER::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+                                           QuantMode::PER_TENSOR_ASYMM_QUANT>
+                opBf16Cm1Qm0Inner(mlaTilingData, tiling);
+            opBf16Cm1Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+                quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+                s1, s2, s3, s4, s5, innerOut);
+            if ASCEND_IS_AIC {
+                opBf16Cm1Qm0Inner.ProcessCube();
+            }
+            if ASCEND_IS_AIV {
+                opBf16Cm1Qm0Inner.ProcessVector();
+            }
+            break;
+        }
+        case KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER: {
+            MLAPO_BF16_INNER::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
+                                           QuantMode::PER_TENSOR_ASYMM_QUANT>
+                opBf16Cm3Qm0Inner(mlaTilingData, tiling);
+            opBf16Cm3Qm0Inner.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
+                quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+                s1, s2, s3, s4, s5, innerOut);
+            if ASCEND_IS_AIC {
+                opBf16Cm3Qm0Inner.ProcessCube();
+            }
+            if ASCEND_IS_AIV {
+                opBf16Cm3Qm0Inner.ProcessVector();
+            }
+            break;
+        }
         default: {
             break;
         }
@@ -256,6 +305,7 @@ extern void mla_preprocess_impl(
     void* keycache_out,
     void* q2,
     void* keycache_out2,
+    void* inner_out,
     void* workspace,
     void* tiling,
     const uint32_t block_dim)
@@ -288,6 +338,7 @@ extern void mla_preprocess_impl(
         keycache_out,
         q2,
         keycache_out2,
+        inner_out,
         workspace,
         tiling);
 }
```
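To make the `switch` dispatch above easier to follow, here is a hypothetical decoder that inverts the key layout implied by `SetTilingKey`: cache mode in the low 3 bits, quant mode from bit 3, the bf16 flag at bit 8, and the new inner-out flag at bit 9. The struct and function names are mine, and the field widths are inferred from the shifts, not taken from the source.

```cpp
#include <cstdint>
#include <cstdio>

// Field layout inferred from SetTilingKey; names are illustrative only.
struct TilingKeyFields {
    uint32_t cacheMode;   // bits 0..2
    uint32_t quantMode;   // bits 3..7
    bool bf16;            // bit 8
    bool innerOut;        // bit 9: routes to the MLAPO_BF16_INNER cases above
};

constexpr TilingKeyFields DecodeTilingKey(uint64_t key)
{
    return {static_cast<uint32_t>(key & 0x7u),
            static_cast<uint32_t>((key >> 3) & 0x1Fu),
            ((key >> 8) & 0x1u) != 0,
            ((key >> 9) & 0x1u) != 0};
}

int main()
{
    // 259 + 512 is KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER from the header above.
    constexpr TilingKeyFields f = DecodeTilingKey(259 + 512);
    std::printf("cacheMode=%u quantMode=%u bf16=%d innerOut=%d\n",
                f.cacheMode, f.quantMode, f.bf16 ? 1 : 0, f.innerOut ? 1 : 0);
    // Expected output: cacheMode=3 quantMode=0 bf16=1 innerOut=1
    return 0;
}
```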
