Skip to content

Commit bec381e

Browse files
authored
[ascend]Optimize moe (#203)
1 parent d75e16f commit bec381e

File tree

14 files changed

+231
-160
lines changed

14 files changed

+231
-160
lines changed

dlinfer/framework/lmdeploy_ext/dynamo/graph_mode_patch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def DeepseekV2Attention_forward(
8787
key_states,
8888
value_states,
8989
past_key_value[0],
90-
past_key_value[0][..., :nope_size],
90+
past_key_value[1],
9191
attn_metadata,
9292
k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
9393
v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],

dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -633,15 +633,26 @@ def infer_result(self, x, num_experts):
633633
)
634634

635635

636-
class MoeInitRouting(Operator):
636+
class AclNnMoeGatingTopkSoftmax(Operator):
    """Graph operator registered under the name "AclNnMoeGatingTopkSoftmax".

    Shape inference yields two tensors that share one shape: the leading
    dimensions of ``x`` followed by ``topk`` as the last dimension.
    """

    def __init__(self):
        super().__init__("AclNnMoeGatingTopkSoftmax")

    def infer_result(self, x, topk):
        """Return placeholder outputs for ``x`` and ``topk``.

        The first result keeps the dtype of ``x``; the second is int32.
        """
        # Replace only the trailing dimension of x with topk.
        out_shape = (*x.shape[:-1], topk)
        same_dtype_out = x.new_empty(out_shape)
        int32_out = x.new_empty(out_shape, dtype=torch.int32)
        return (same_dtype_out, int32_out)
645+
646+
647+
class AclNnMoeInitRouting(Operator):
637648
def __init__(self):
638649
super().__init__("AclNnMoeInitRouting")
639650

640-
def infer_result(self, x, row_ids, topk_ids, active_num, num_experts):
651+
def infer_result(self, x, topk_ids, num_experts):
641652
return (
642653
x.repeat_interleave(topk_ids.size(1), dim=0),
643-
row_ids.flatten(),
644654
topk_ids.flatten(),
655+
topk_ids.new_empty((num_experts,)),
645656
)
646657

647658

dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_infer_param.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,10 +652,17 @@ class PrepareMoeParam:
652652
numExperts: int = 0
653653

654654

655+
@dataclass
656+
class AclNnMoeGatingTopkSoftmaxParam:
    # Parameter record for the AclNnMoeGatingTopkSoftmax operation; every
    # field has a default so the record can be created empty and filled in.

    # Operation instance name.
    name: str = ""
    # Size of the last output dimension (k of the top-k selection).
    topk: int = 0
    # Integer mode value forwarded unchanged to the runtime operation.
    renorm: int = 0
    # Boolean flag forwarded unchanged to the runtime operation.
    outputSoftmaxResultFlag: bool = False
661+
662+
655663
@dataclass
656664
class AclNnMoeInitRoutingParam:
657665
name: str = ""
658-
activeNum: int = 10240
659666
numExperts: int = 0
660667

661668

dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,14 +1014,22 @@ def PrepareMoe(name, x, num_experts):
10141014
op.set_output([f"{name}__0", f"{name}__1", f"{name}__2", f"{name}__3"])
10151015
return op
10161016

1017-
def AclNnMoeInitRouting(name, x, row_ids, topk_ids, active_num, num_experts):
1017+
def AclNnMoeGatingTopkSoftmax(name, x, topk):
    """Describe one "AclNnMoeGatingTopkSoftmaxOperation" node.

    The node takes the single input ``x`` and exposes two outputs named
    ``{name}__0`` and ``{name}__1``. Only ``name`` and ``topk`` are set on the
    parameter record; its remaining fields keep their defaults.
    """
    op = Operation(name, "AclNnMoeGatingTopkSoftmaxOperation")
    # The parameter record is a dataclass, so fields can be passed as keywords.
    param = infer_param.AclNnMoeGatingTopkSoftmaxParam(name=name, topk=topk)
    output_names = [f"{name}__{index}" for index in range(2)]

    op.set_input([x])
    op.set_param(param)
    op.set_output(output_names)
    return op
1026+
1027+
def AclNnMoeInitRouting(name, x, topk_ids, num_experts):
10181028
op = Operation(name, "AclNnMoeInitRoutingOperation")
10191029
param = infer_param.AclNnMoeInitRoutingParam()
10201030
param.name = name
1021-
param.activeNum = active_num
10221031
param.numExperts = num_experts
1023-
1024-
op.set_input([x, row_ids, topk_ids])
1032+
op.set_input([x, topk_ids])
10251033
op.set_param(param)
10261034
op.set_output([f"{name}__0", f"{name}__1", f"{name}__2"])
10271035
return op
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#include "moe_gating_topk_softmax.h"
2+
3+
#include <cstddef>
4+
5+
#include "aclnnop/aclnn_moe_gating_top_k_softmax_v2.h"
6+
#include "third_party/acl/inc/acl/acl_base.h"
7+
#include "utils/log.h"
8+
9+
namespace dicp {
10+
11+
// Small integer constants used for the tensor counts of this operation.
constexpr int NUM1 = 1;
constexpr int NUM2 = 2;
constexpr int NUM3 = 3;
14+
15+
AclNnMoeGatingTopkSoftmaxOperation::AclNnMoeGatingTopkSoftmaxOperation(const std::string& name, int64_t topk, int64_t renorm, bool outputSoftmaxResultFlag)
16+
: AclNnOperation(name), topk_(topk), renorm_(renorm), outputSoftmaxResultFlag_(outputSoftmaxResultFlag) {}
17+
18+
// Nothing to release here beyond what the members and base class handle themselves.
AclNnMoeGatingTopkSoftmaxOperation::~AclNnMoeGatingTopkSoftmaxOperation() = default;
19+
20+
// Fills in both output descriptors from the single input descriptor.
//
// Output 0 copies the input's format, dtype and rank; every dimension matches
// the input except the last one, which becomes topk_. Output 1 has the same
// format and shape as output 0 but is always ACL_INT32.
// Always returns 0.
atb::Status AclNnMoeGatingTopkSoftmaxOperation::InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs,
                                                           atb::SVector<atb::TensorDesc>& outTensorDescs) const {
    DICP_LOG(INFO) << opName_ << " infer shape start";

    const auto& inputDesc = inTensorDescs.at(0);

    auto& firstOut = outTensorDescs.at(0);
    firstOut.format = inputDesc.format;
    firstOut.shape.dimNum = inputDesc.shape.dimNum;
    firstOut.dtype = inputDesc.dtype;
    const auto firstRank = firstOut.shape.dimNum;
    for (size_t axis = 0; axis < firstRank; ++axis) {
        // Only the trailing axis is resized; all leading axes are inherited.
        const bool isLastAxis = (axis == firstRank - 1);
        firstOut.shape.dims[axis] = isLastAxis ? topk_ : inputDesc.shape.dims[axis];
    }

    auto& secondOut = outTensorDescs.at(1);
    secondOut.format = firstOut.format;
    secondOut.shape.dimNum = firstOut.shape.dimNum;
    secondOut.dtype = aclDataType::ACL_INT32;
    const auto secondRank = secondOut.shape.dimNum;
    for (size_t axis = 0; axis < secondRank; ++axis) {
        secondOut.shape.dims[axis] = firstOut.shape.dims[axis];
    }

    DICP_LOG(INFO) << opName_ << " infer shape end";
    return 0;
}
41+
42+
// Reports the number of input tensors of this operation (NUM1).
uint32_t AclNnMoeGatingTopkSoftmaxOperation::GetInputNum() const { return NUM1; }
43+
44+
// Reports the number of output tensors of this operation (NUM2).
uint32_t AclNnMoeGatingTopkSoftmaxOperation::GetOutputNum() const { return NUM2; }
45+
46+
int AclNnMoeGatingTopkSoftmaxOperation::SetAclNnWorkspaceExecutor(uint64_t& workspaceSize) {
47+
DICP_LOG(INFO) << opName_ << " aclnnMoeGatingTopKSoftmaxV2GetWorkspaceSize start";
48+
49+
int ret = aclnnMoeGatingTopKSoftmaxV2GetWorkspaceSize(aclInTensors_.at(0).tensor,
50+
nullptr,
51+
topk_,
52+
renorm_,
53+
outputSoftmaxResultFlag_,
54+
aclOutTensors_.at(0).tensor,
55+
aclOutTensors_.at(1).tensor,
56+
nullptr,
57+
&workspaceSize,
58+
&aclExecutor_);
59+
60+
DICP_LOG(INFO) << opName_ << " aclnnMoeGatingTopKSoftmaxV2GetWorkspaceSize end, ret:" << ret << ", workspaceSize:" << workspaceSize
61+
<< ", aclExecutor:" << aclExecutor_;
62+
63+
return ret;
64+
}
65+
66+
int AclNnMoeGatingTopkSoftmaxOperation::CallAclExecute(uint8_t* workspace, uint64_t workspaceSize, aclOpExecutor* aclExecutor, aclrtStream stream) {
67+
DICP_LOG(INFO) << opName_ << " aclnnMoeGatingTopKSoftmaxV2 start";
68+
int ret = aclnnMoeGatingTopKSoftmaxV2(workspace, workspaceSize, aclExecutor, stream);
69+
DICP_LOG(INFO) << opName_ << " aclnnMoeGatingTopKSoftmaxV2 end, ret:" << ret;
70+
return ret;
71+
}
72+
73+
// Factory used by the operation registry: builds an
// AclNnMoeGatingTopkSoftmaxOperation from its JSON parameter object.
//
// Recognised keys: "name" (string), "topk" (int64), "renorm" (int64) and
// "outputSoftmaxResultFlag" (bool). Every key is optional; a missing key
// leaves the corresponding value at its default below.
// Returns a newly allocated operation; the caller takes ownership.
atb::Operation* AclNnMoeGatingTopkSoftmaxOperationCreate(const nlohmann::json& paramJson) {
    std::string opName;
    // Defaults mirror the Python-side AclNnMoeGatingTopkSoftmaxParam
    // (topk=0, renorm=0, outputSoftmaxResultFlag=False). Without them an absent
    // key would leave the local uninitialized and reading it below would be
    // undefined behaviour.
    int64_t topk = 0;
    int64_t renorm = 0;
    bool outputSoftmaxResultFlag = false;

    if (paramJson.contains("name")) {
        opName = paramJson["name"].get<std::string>();
    }
    if (paramJson.contains("topk")) {
        topk = paramJson["topk"].get<int64_t>();
    }
    if (paramJson.contains("renorm")) {
        renorm = paramJson["renorm"].get<int64_t>();
    }
    if (paramJson.contains("outputSoftmaxResultFlag")) {
        outputSoftmaxResultFlag = paramJson["outputSoftmaxResultFlag"].get<bool>();
    }

    DICP_LOG(INFO) << "AclNnMoeGatingTopkSoftmaxOperation: name: " << opName << " topk:" << topk << " renorm:" << renorm
                   << " outputSoftmaxResultFlag:" << outputSoftmaxResultFlag;
    atb::Operation* op = new AclNnMoeGatingTopkSoftmaxOperation(opName, topk, renorm, outputSoftmaxResultFlag);
    return op;
}
94+
95+
REGISTER_OPERATION(AclNnMoeGatingTopkSoftmaxOperation, AclNnMoeGatingTopkSoftmaxOperationCreate);
96+
97+
} // namespace dicp
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#pragma once
2+
3+
#include "ops/aclnn_ops/acl_nn_operation.h"
4+
5+
namespace dicp {
6+
7+
class AclNnMoeGatingTopkSoftmaxOperation : public AclNnOperation {
8+
public:
9+
explicit AclNnMoeGatingTopkSoftmaxOperation(const std::string& name, int64_t topk, int64_t renorm, bool outputSoftmaxResultFlag);
10+
~AclNnMoeGatingTopkSoftmaxOperation() override;
11+
atb::Status InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const override;
12+
uint32_t GetInputNum() const override;
13+
uint32_t GetOutputNum() const override;
14+
15+
private:
16+
int64_t topk_;
17+
int64_t renorm_;
18+
bool outputSoftmaxResultFlag_;
19+
int SetAclNnWorkspaceExecutor(uint64_t& workspaceSize) override;
20+
int CallAclExecute(uint8_t* workspace, uint64_t workspaceSize, aclOpExecutor* aclExecutor, aclrtStream stream) override;
21+
};
22+
23+
} // namespace dicp

dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/moe_init_routing_operation.cpp

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
#include "moe_init_routing_operation.h"
22

3-
#include "aclnnop/aclnn_moe_init_routing.h"
4-
// #include "aclnnop/aclnn_moe_init_routing_v2.h"
3+
#include "aclnnop/aclnn_moe_init_routing_v2.h"
54
#include "utils/log.h"
65

76
namespace dicp {
87

98
const int NUM1 = 1;
9+
const int NUM2 = 2;
1010
const int NUM3 = 3;
1111

12-
AclNnMoeInitRoutingOperation::AclNnMoeInitRoutingOperation(const std::string& name, int64_t activeNum, int64_t numExperts)
13-
: AclNnOperation(name), activeNum_(activeNum), numExperts_(numExperts) {}
12+
AclNnMoeInitRoutingOperation::AclNnMoeInitRoutingOperation(const std::string& name, int64_t numExperts) : AclNnOperation(name), numExperts_(numExperts) {}
1413

1514
AclNnMoeInitRoutingOperation::~AclNnMoeInitRoutingOperation() {}
1615

1716
atb::Status AclNnMoeInitRoutingOperation::InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const {
1817
DICP_LOG(INFO) << opName_ << " infer shape start";
19-
auto seqLength = inTensorDescs.at(2).shape.dims[0];
20-
auto topk = inTensorDescs.at(2).shape.dims[1];
18+
auto seqLength = inTensorDescs.at(0).shape.dims[0];
19+
auto topk = inTensorDescs.at(1).shape.dims[1];
20+
activeNum_ = seqLength * topk;
2121

2222
outTensorDescs.at(0).format = inTensorDescs.at(0).format;
2323
outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum;
@@ -30,60 +30,61 @@ atb::Status AclNnMoeInitRoutingOperation::InferShape(const atb::SVector<atb::Ten
3030
outTensorDescs.at(1).dtype = inTensorDescs.at(1).dtype;
3131
outTensorDescs.at(1).shape.dims[0] = seqLength * topk;
3232

33-
outTensorDescs.at(2).format = inTensorDescs.at(2).format;
33+
outTensorDescs.at(2).format = inTensorDescs.at(1).format;
3434
outTensorDescs.at(2).shape.dimNum = NUM1;
35-
outTensorDescs.at(2).dtype = inTensorDescs.at(2).dtype;
36-
outTensorDescs.at(2).shape.dims[0] = seqLength * topk;
35+
outTensorDescs.at(2).dtype = inTensorDescs.at(1).dtype;
36+
outTensorDescs.at(2).shape.dims[0] = numExperts_;
3737

3838
DICP_LOG(INFO) << opName_ << " infer shape end";
3939
return 0;
4040
}
4141

42-
uint32_t AclNnMoeInitRoutingOperation::GetInputNum() const { return NUM3; }
42+
uint32_t AclNnMoeInitRoutingOperation::GetInputNum() const { return NUM2; }
4343

4444
uint32_t AclNnMoeInitRoutingOperation::GetOutputNum() const { return NUM3; }
4545

4646
int AclNnMoeInitRoutingOperation::SetAclNnWorkspaceExecutor(uint64_t& workspaceSize) {
47-
DICP_LOG(INFO) << opName_ << " aclnnMoeInitRoutingGetWorkspaceSize start";
48-
49-
int ret = aclnnMoeInitRoutingGetWorkspaceSize(aclInTensors_.at(0).tensor,
50-
aclInTensors_.at(1).tensor,
51-
aclInTensors_.at(2).tensor,
52-
activeNum_,
53-
aclOutTensors_.at(0).tensor,
54-
aclOutTensors_.at(1).tensor,
55-
aclOutTensors_.at(2).tensor,
56-
&workspaceSize,
57-
&aclExecutor_);
58-
59-
DICP_LOG(INFO) << opName_ << " aclnnMoeInitRoutingGetWorkspaceSize end, ret:" << ret << ", workspaceSize:" << workspaceSize
47+
DICP_LOG(INFO) << opName_ << " aclnnMoeInitRoutingV2GetWorkspaceSize start";
48+
49+
int ret = aclnnMoeInitRoutingV2GetWorkspaceSize(aclInTensors_.at(0).tensor,
50+
aclInTensors_.at(1).tensor,
51+
activeNum_,
52+
0,
53+
numExperts_,
54+
0,
55+
1,
56+
false,
57+
aclOutTensors_.at(0).tensor,
58+
aclOutTensors_.at(1).tensor,
59+
aclOutTensors_.at(2).tensor,
60+
nullptr,
61+
&workspaceSize,
62+
&aclExecutor_);
63+
64+
DICP_LOG(INFO) << opName_ << " aclnnMoeInitRoutingV2GetWorkspaceSize end, ret:" << ret << ", workspaceSize:" << workspaceSize
6065
<< ", aclExecutor:" << aclExecutor_;
6166

6267
return ret;
6368
}
6469

6570
int AclNnMoeInitRoutingOperation::CallAclExecute(uint8_t* workspace, uint64_t workspaceSize, aclOpExecutor* aclExecutor, aclrtStream stream) {
66-
DICP_LOG(INFO) << opName_ << " aclnnMoeInitRouting start";
67-
int ret = aclnnMoeInitRouting(workspace, workspaceSize, aclExecutor, stream);
68-
DICP_LOG(INFO) << opName_ << " aclnnMoeInitRouting end, ret:" << ret;
71+
DICP_LOG(INFO) << opName_ << " aclnnMoeInitRoutingV2 start";
72+
int ret = aclnnMoeInitRoutingV2(workspace, workspaceSize, aclExecutor, stream);
73+
DICP_LOG(INFO) << opName_ << " aclnnMoeInitRoutingV2 end, ret:" << ret;
6974
return ret;
7075
}
7176

7277
atb::Operation* AclNnMoeInitRoutingOperationCreate(const nlohmann::json& paramJson) {
7378
std::string opName;
74-
int64_t activeNum;
7579
int64_t numExperts;
7680
if (paramJson.contains("name")) {
7781
opName = paramJson["name"].get<std::string>();
7882
}
79-
if (paramJson.contains("activeNum")) {
80-
activeNum = paramJson["activeNum"].get<int64_t>();
81-
}
8283
if (paramJson.contains("numExperts")) {
8384
numExperts = paramJson["numExperts"].get<int64_t>();
8485
}
85-
DICP_LOG(INFO) << "AclNnMoeInitRoutingOperation: name: " << opName << " activeNum:" << activeNum << " numExperts:" << numExperts;
86-
atb::Operation* op = new AclNnMoeInitRoutingOperation(opName, activeNum, numExperts);
86+
DICP_LOG(INFO) << "AclNnMoeInitRoutingOperation: name: " << opName << " numExperts:" << numExperts;
87+
atb::Operation* op = new AclNnMoeInitRoutingOperation(opName, numExperts);
8788
return op;
8889
}
8990

dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/moe_init_routing_operation.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ namespace dicp {
66

77
class AclNnMoeInitRoutingOperation : public AclNnOperation {
88
public:
9-
explicit AclNnMoeInitRoutingOperation(const std::string& name, int64_t activeNum, int64_t numExperts);
9+
explicit AclNnMoeInitRoutingOperation(const std::string& name, int64_t numExperts);
1010
~AclNnMoeInitRoutingOperation() override;
1111
atb::Status InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const override;
1212
uint32_t GetInputNum() const override;
1313
uint32_t GetOutputNum() const override;
1414

1515
private:
16-
int64_t activeNum_;
16+
mutable int64_t activeNum_;
1717
int64_t numExperts_;
1818
int SetAclNnWorkspaceExecutor(uint64_t& workspaceSize) override;
1919
int CallAclExecute(uint8_t* workspace, uint64_t workspaceSize, aclOpExecutor* aclExecutor, aclrtStream stream) override;

dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/moe_token_permute_operation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "aclnnop/aclnn_moe_token_permute.h"
44
#include "utils/log.h"
5+
#include "utils/tensor_utils.h"
56

67
namespace dicp {
78

dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/utils/tensor_utils.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@ template <aclDataType T>
139139
void copyAndPrint(const atb::Tensor tensor, int64_t tensorSize) {
140140
using vectorT = typename aclDataTypeMap<T>::type;
141141
std::vector<vectorT> resultData(tensorSize, 0);
142-
auto ret =
143-
aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), tensor.deviceData, tensorSize * sizeof(float16_t), ACL_MEMCPY_DEVICE_TO_HOST);
142+
auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(vectorT), tensor.deviceData, tensorSize * sizeof(vectorT), ACL_MEMCPY_DEVICE_TO_HOST);
144143
for (int64_t i = 0; i < tensorSize; ++i) {
145144
DICP_LOG(INFO) << "data[" << i << "]: " << resultData[i];
146145
}

0 commit comments

Comments
 (0)