
Commit c12eb22

[feat] mlapo add bf16 no_quant support (#4852)
### What this PR does / why we need it?
This PR adds mlapo operation support for the bf16 no_quant mode.

### Does this PR introduce _any_ user-facing change?
This PR makes the quant-related parameters optional.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: chenjunyi <[email protected]>
1 parent c95c271 · commit c12eb22

12 files changed · +1510 −81 lines
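The mode switch at the heart of the change is driven by the weight dtype rather than a new flag. A minimal sketch of that check (the helper name is invented here; the real logic appears inline in the mla_preprocess.h diff below):

```cpp
#include <torch/extension.h>

// Sketch only: mirrors the dtype check this PR adds. bf16/fp16 weights
// select the no_quant path; anything else (e.g. int8) keeps the
// existing quantized-weight path.
inline uint32_t InferIsWeightQuantized(const at::Tensor &wdqkv) {
    const auto dtype = wdqkv.options().dtype();
    return (dtype == at::kBFloat16 || dtype == at::kHalf) ? 0u : 1u;
}
```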

csrc/mla_preprocess/op_host/mla_preprocess.h

Lines changed: 30 additions & 16 deletions
@@ -43,7 +43,6 @@ constexpr uint32_t L1_BIAS_SIZE = 2048;
 constexpr uint32_t L0C_SIZE = 128 * 1024;
 constexpr uint32_t CONCAT_SIZE = 512;
 
-constexpr uint32_t HIDDEN_STRATE = 7168;
 constexpr uint32_t HIDDEN_STRATE_ROPE = 192;
 constexpr uint32_t HIDDEN_STRATE_MM = 2112;
 constexpr uint32_t HIDDEN_STRATE_RMS = 1536;
@@ -122,6 +121,8 @@ struct PlatformInfo {
 };
 
 struct OpParam {
+    uint32_t isWeightQuantized;
+    uint32_t hiddenStateDim;
     uint32_t N;
     uint32_t headNum;
     int32_t cacheMode;
@@ -392,7 +393,7 @@ class MlaPreprocessTiling
 void MlaPreprocessTiling::RmsNormQuantTiling()
 {
     tilingData->rmsNumCore1 = platformInfo.coreNumAiv;
-    tilingData->rmsNumCol1 = HIDDEN_STRATE;
+    tilingData->rmsNumCol1 = opParam.hiddenStateDim;
     tilingData->rmsNumRow1 = opParam.N;
     tilingData->rmsQuantMin1 = -CONST_128;
     tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
@@ -508,9 +509,9 @@ void MlaPreprocessTiling::EinSumQuantTiling()
 void MlaPreprocessTiling::SetMlapoWorkSpace()
 {
     uint64_t s1wsFactor =
-        static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t),
+        static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(opParam.hiddenStateDim * sizeof(int8_t),
                                                                 opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
-                                                     : HIDDEN_STRATE * sizeof(int8_t));
+                                                     : opParam.hiddenStateDim * sizeof(int8_t));
     uint64_t workSizeS1 = s1wsFactor;
     uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
     uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
@@ -525,7 +526,8 @@ void MlaPreprocessTiling::SetMlapoWorkSpace()
     uint64_t pertokenWorkspace = static_cast<uint64_t>(opParam.N) * sizeof(float) * 2;
 
     uint64_t userWorkspaceSize;
-    if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
+    if (opParam.isWeightQuantized == 1 &&
+        (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT)) {
         userWorkspaceSize = 4 * maxWorkspaceSize + pertokenWorkspace;
     } else {
         userWorkspaceSize = 3 * maxWorkspaceSize;
@@ -554,21 +556,23 @@ void MlaPreprocessTiling::Init()
 {
     tilingData->numCore = platformInfo.coreNumAic;
     tilingData->n = opParam.N;
-
+    tilingData->hiddenStateDim = opParam.hiddenStateDim;
+    tilingData->isWeightQuantized = opParam.isWeightQuantized;
+    bool enDequant = (opParam.isWeightQuantized == 1);
     bool deqOnTheFly = false;
-    if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
+    if (enDequant && (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT)) {
         deqOnTheFly = true;
     }
 
     PpMatmulTilingApi mm1TilingApi(platformInfo,
-                                   1,                 // numBatch
-                                   opParam.N,         // m
-                                   HIDDEN_STRATE,     // k
-                                   HIDDEN_STRATE_MM,  // n
-                                   false,             // transA
-                                   true,              // transB
-                                   true,              // enDequant
-                                   deqOnTheFly);      // in bf16.cce?
+                                   1,                       // numBatch
+                                   opParam.N,               // m
+                                   opParam.hiddenStateDim,  // k
+                                   HIDDEN_STRATE_MM,        // n
+                                   false,                   // transA
+                                   true,                    // transB
+                                   enDequant,               // enDequant
+                                   deqOnTheFly);            // in bf16.cce?
     mm1TilingApi.GetTilingData(tilingData->mm1);
 
     PpMatmulTilingApi mm2TilingApi(platformInfo,
@@ -578,7 +582,7 @@ void MlaPreprocessTiling::Init()
                                    opParam.headNum * HIDDEN_STRATE_ROPE,  // n
                                    false,                                 // transA
                                    true,                                  // transB
-                                   true,                                  // enDequant
+                                   enDequant,                             // enDequant
                                    deqOnTheFly);                          // in bf16.cce?
     mm2TilingApi.GetTilingData(tilingData->mm2);
 
@@ -609,6 +613,8 @@ std::unordered_map<c10::string_view, uint16_t> cache_mode_map = {
 std::unordered_map<c10::string_view, uint16_t> quant_mode_map = {
     {"per_tensor_quant_asymm", 0},
     {"per_token_quant_symm", 1},
+    {"per_token_quant_asymm", 2},
+    {"no_quant", 3}
 };
 
 template <typename MapType>
@@ -623,6 +629,7 @@ inline int get_op_mode(const MapType &mode_map, c10::optional<c10::string_view>
 
 std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
     const at::Tensor &hiddenState,
+    const at::Tensor &wdqkv,
     const at::Tensor &wuk,
     c10::optional<c10::string_view> cache_mode,
     c10::optional<c10::string_view> quant_mode,
@@ -647,14 +654,21 @@ std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
 
     int32_t N = hiddenState.sizes()[0];
     int32_t headNum = wuk.sizes()[0];
+    uint32_t hiddenStateDim = hiddenState.sizes().back();
 
     OpParam opParam;
+    opParam.hiddenStateDim = hiddenStateDim;
     opParam.N = N;
     opParam.headNum = headNum;
     opParam.cacheMode = static_cast<int32_t>(cacheMode);
    opParam.quantMode = static_cast<QuantMode>(quantMode);
     opParam.inDtype = hiddenState.options().dtype();
     opParam.enableInnerOut = enable_inner_out;
+    if (wdqkv.options().dtype() == at::kBFloat16 || wdqkv.options().dtype() == at::kHalf) {
+        opParam.isWeightQuantized = 0;
+    } else {
+        opParam.isWeightQuantized = 1;
+    }
 
     MlaTilingData tilingData;
     MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);
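With unquantized weights, the dequant-on-the-fly workspace is no longer reserved. A condensed sketch of the sizing branch above (standalone function for illustration; only the branch shape comes from the patch):

```cpp
#include <cstdint>

// Illustrative restatement of SetMlapoWorkSpace()'s sizing logic: the
// dequant-on-the-fly path needs a fourth staging buffer plus the
// per-token scale workspace; the bf16 no_quant path sizes only three.
uint64_t UserWorkspaceSize(bool weightQuantized, bool bf16OrPerTokenSymm,
                           uint64_t maxWorkspaceSize, uint64_t pertokenWorkspace) {
    if (weightQuantized && bf16OrPerTokenSymm) {
        return 4 * maxWorkspaceSize + pertokenWorkspace;
    }
    return 3 * maxWorkspaceSize;
}
```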

csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h

Lines changed: 5 additions & 0 deletions
@@ -90,6 +90,11 @@ struct MlaTilingData {
     uint32_t esqHeadTail{0};
     uint32_t esqColLoop{0};
     uint32_t esqColTail{0};
+
+    // hidden state dimension
+    uint32_t hiddenStateDim{7168};
+
+    uint32_t isWeightQuantized{1};
 };
 
 #endif // MLAPREPROCESS_TILING_H
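The brace initializers matter for compatibility: a caller that never writes the new fields gets the pre-PR behavior. A minimal check, assuming MlaTilingData is used exactly as declared above:

```cpp
// hiddenStateDim defaults to the previously hard-coded 7168 and
// isWeightQuantized defaults to 1 (quantized weights), so untouched
// tiling data reproduces the legacy configuration.
bool DefaultsPreserveLegacyBehavior() {
    MlaTilingData td{};
    return td.hiddenStateDim == 7168 && td.isWeightQuantized == 1;
}
```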

csrc/mla_preprocess/op_kernel/mla_preprocess.h

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,6 @@ constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2; // high performance KV NZ format
 constexpr uint8_t CACHE_MODE_NZCACHE = 3;
 
 // pp matmul
-constexpr uint32_t HIDDTEN_STATE = 7168;
 constexpr uint32_t FLOAT_BLOCK_SIZE = 64;
 constexpr uint32_t HALF_BLOCK_SIZE = 64;
 constexpr uint32_t HALF_VECTOR_SIZE = 64;
@@ -103,6 +102,7 @@ constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
 constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
 constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
 constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;
+constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_3 = 281;
 constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER = 256 + 512;
 constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0_INNER = 257 + 512;
 constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0_INNER = 259 + 512;
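The new key value 281 is consistent with an encoding implied by the neighboring constants: a bf16 base of 256, plus the cache mode, plus the quant mode scaled by 8, with the inner variants adding 512. This decomposition is inferred from the values above, not stated anywhere in the patch:

```cpp
// Assumed encoding, reverse-engineered from the constants above;
// verify against the real tiling-key computation before relying on it.
constexpr uint32_t Bf16TilingKey(uint32_t cacheMode, uint32_t quantMode, bool inner = false) {
    return 256u + cacheMode + 8u * quantMode + (inner ? 512u : 0u);
}
static_assert(Bf16TilingKey(1, 0) == 257, "KEY_BF16_CACHEMODE_1_QUANTMODE_0");
static_assert(Bf16TilingKey(3, 0) == 259, "KEY_BF16_CACHEMODE_3_QUANTMODE_0");
static_assert(Bf16TilingKey(1, 3) == 281, "KEY_BF16_CACHEMODE_1_QUANTMODE_3");
```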

csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp

Lines changed: 25 additions & 8 deletions
@@ -16,6 +16,7 @@
 #include "mla_preprocess_mix_fp16.hpp"
 #include "mla_preprocess_mix_bf16.hpp"
 #include "mla_preprocess_mix_bf16_qdown.hpp"
+#include "mla_preprocess_mix_bf16_nq.hpp"
 
 #include "../op_host/tiling/mla_preprocess_tiling.h"
 
@@ -42,6 +43,7 @@ extern "C" __global__ __aicore__ void mla_preprocess(
 
     mlaTilingData.tilingKey = tilingData->tilingKey;
     mlaTilingData.n = tilingData->n;
+    mlaTilingData.hiddenStateDim = tilingData->hiddenStateDim;
 
     mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch;
     mlaTilingData.mm1.m = tilingData->mm1.m;
@@ -173,12 +175,12 @@ extern "C" __global__ __aicore__ void mla_preprocess(
         }
         case KEY_BF16_CACHEMODE_0_QUANTMODE_0: {
             MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
-                QuantMode::PER_TENSOR_ASYMM_QUANT>
+                                     QuantMode::PER_TENSOR_ASYMM_QUANT>
                 opBf16Cm0Qm0(mlaTilingData, tiling);
             opBf16Cm0Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
-                quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
-                bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
-                s1, s2, s3, s4, s5);
+                              quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                              bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+                              s1, s2, s3, s4, s5);
             if ASCEND_IS_AIC {
                 opBf16Cm0Qm0.ProcessCube();
             }
@@ -189,12 +191,12 @@ extern "C" __global__ __aicore__ void mla_preprocess(
         }
         case KEY_BF16_CACHEMODE_1_QUANTMODE_0: {
             MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
-                QuantMode::PER_TENSOR_ASYMM_QUANT>
+                                     QuantMode::PER_TENSOR_ASYMM_QUANT>
                 opBf16Cm1Qm0(mlaTilingData, tiling);
             opBf16Cm1Qm0.Init(hiddenState, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
-                quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
-                bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
-                s1, s2, s3, s4, s5);
+                              quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                              bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
+                              s1, s2, s3, s4, s5);
             if ASCEND_IS_AIC {
                 opBf16Cm1Qm0.ProcessCube();
             }
@@ -219,6 +221,21 @@ extern "C" __global__ __aicore__ void mla_preprocess(
             }
             break;
         }
+        case KEY_BF16_CACHEMODE_1_QUANTMODE_3: {
+            MLAPO_BF16_NQ::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>
+                opBf16Cm1Qm0(mlaTilingData, tiling);
+            opBf16Cm1Qm0.Init(hiddenState, wdqkv, gamma2, beta2,
+                              gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
+                              wuk, q, keycacheOut, q2, keycacheOut2,
+                              s1, s2, s3);
+            if ASCEND_IS_AIC {
+                opBf16Cm1Qm0.ProcessCube();
+            }
+            if ASCEND_IS_AIV {
+                opBf16Cm1Qm0.ProcessVector();
+            }
+            break;
+        }
         case KEY_BF16_CACHEMODE_0_QUANTMODE_0_INNER: {
             MLAPO_BF16_INNER::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
                                            QuantMode::PER_TENSOR_ASYMM_QUANT>

csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp

Lines changed: 8 additions & 9 deletions
@@ -2386,6 +2386,7 @@ class MLAOperation
         this->num_row = mlaParams_.n;
         this->epsilon_ = 1e-6;
         this->mlaParams = mlaParams_;
+        this->hiddenStateDim = mlaParams_.hiddenStateDim;
     }
 
     __aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR quantScale1Gm,
@@ -2692,6 +2693,7 @@ class MLAOperation
     uint32_t blockOffset;
     uint32_t perTaskNum;
     uint32_t resTaskNum;
+    uint32_t hiddenStateDim;
     MlaTilingData mlaParams;
 
     uint32_t num_core_;
@@ -2795,18 +2797,15 @@ MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, q
         uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
         uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
         uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
+        const uint32_t base_offset = hiddenStateDim * 6;
         AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
-        AscendC::LocalTensor<InDtype> scale_tensor =
-            buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
-        AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
-        AscendC::LocalTensor<float> res1_tensor =
-            buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
+        AscendC::LocalTensor<InDtype> scale_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(base_offset);
+        AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(base_offset + 32);
+        AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(base_offset + 64);
         AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
+            base_offset + 64 + num_col_align_f32 * 4);
         AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
-            HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
-            BUF_FACTOR * num_col_align_f32 * 4 + 64);
+            base_offset + 64 + num_col_align_f32 * 4 + BUF_FACTOR * num_col_align_f32 * 4 + 64);
         Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
     }
     FftsCrossCoreSync<PIPE_MTE3, 0>(QUANT1);
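This refactor replaces the three repeated `HIDDTEN_STATE * 2` byte offsets (three staging buffers of 16-bit elements) with a single `base_offset = hiddenStateDim * 6`; for the default 7168 dimension the two layouts are byte-identical:

```cpp
// Sanity check of the offset refactor: 3 buffers x (dim * 2 bytes)
// equals dim * 6, so the dynamic base_offset matches the old
// hard-coded sum when hiddenStateDim == 7168.
static_assert(7168u * 2u + 7168u * 2u + 7168u * 2u == 7168u * 6u,
              "base_offset preserves the previous UB layout");
```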
