Commit 818c742

[ascend] feat: support moe in graph mode (#98)
1 parent 966455d commit 818c742

32 files changed (+929, -24 lines)

dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py

Lines changed: 53 additions & 0 deletions
```diff
@@ -37,6 +37,14 @@ def infer_result(self, x, weight, bias):
         return out


+class AllReduce(Operator):
+    def __init__(self):
+        super().__init__("AllReduce")
+
+    def infer_result(self, x, reduce_type):
+        return torch.ops._c10d_functional.all_reduce.default(x, reduce_type, "0")
+
+
 class Add(Operator):
     def __init__(self):
         super().__init__("Add")
```
```diff
@@ -353,3 +361,48 @@ def __init__(self):

     def infer_result(self, x1, x2, axis):
         return torch.ops.aten.embedding.default(x1, x2, axis)
+
+
+class Softmax(Operator):
+    def __init__(self):
+        super().__init__("Softmax")
+
+    def infer_result(self, x, dim):
+        return torch.softmax(x, dim=dim)
+
+
+class Sort(Operator):
+    def __init__(self):
+        super().__init__("Sort")
+
+    def infer_result(self, x, topk):
+        value, index = torch.topk(x, topk)
+        return value, index
+
+
+class Slice(Operator):
+    def __init__(self):
+        super().__init__("Slice")
+
+    def infer_result(self, x, dim, offsets, size):
+        return torch.ops.aten.slice.Tensor(
+            x, dim, offsets[dim], offsets[dim] + size[dim], 1
+        )
+
+
+class AclNnSlice(Operator):
+    def __init__(self):
+        super().__init__("AclNnSlice")
+
+    def infer_result(self, x, dim, start, end, step):
+        return torch.ops.aten.slice.Tensor(x, dim, start, end, step)
+
+
+class IndexSelect(Operator):
+    def __init__(self):
+        super().__init__("IndexSelect")
+
+    def infer_result(self, x, dim, index):
+        indices = [None] * len(x.shape)
+        indices[dim] = index
+        return torch.ops.aten.index.Tensor(x, indices)
```
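These infer_result methods run at trace time to tell DICP each op's output shape and dtype. Below is a quick standalone sanity check, not part of the commit, of the shape arithmetic behind IndexSelect and Sort, using eager PyTorch on small tensors:

```python
import torch

x = torch.randn(4, 8)

# IndexSelect.infer_result places the index tensor at position `dim` in an
# aten.index call, so that axis's length becomes the number of indices.
index = torch.tensor([0, 2, 5])
out = torch.ops.aten.index.Tensor(x, [None, index])  # dim = 1
assert out.shape == (4, 3)

# Sort.infer_result maps onto torch.topk: a (values, indices) pair whose
# last dimension is k.
values, idx = torch.topk(x, 2)
assert values.shape == idx.shape == (4, 2)
```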

dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_infer_param.py

Lines changed: 53 additions & 1 deletion
```diff
@@ -435,7 +435,7 @@ class LinearParallelParam:
     rankSize: int = 0
     rankRoot: int = 0
     hasResidual: bool = False
-    backend: str = "hccl"
+    backend: str = "lccl"
     commMode: CommMode = CommMode.COMM_MULTI_PROCESS
     rankTableFile: str = ""
     parallelType: ParallelType = ParallelType.LINEAR_ALL_REDUCE
```
```diff
@@ -446,6 +446,58 @@ class LinearParallelParam:
     commDomain: str = ""


+class AllReducQuantType(IntEnum):
+    QUANT_TYPE_UNDEFINED = 0
+    QUANT_TYPE_PER_TENSOR = 1
+    QUANT_TYPE_PER_CHANNEL = 2
+    QUANT_TYPE_MAX = 3
+
+
+@dataclass
+class AllReduceParam:
+    rank: int = 0
+    rankSize: int = 0
+    rankRoot: int = 0
+    allReduceType: str = "sum"
+    backend: str = "lccl"
+    quantType: AllReducQuantType = AllReducQuantType.QUANT_TYPE_UNDEFINED
+    rankTableFile: str = ""
+    outDataType: AclDataType = AclDataType.ACL_DT_UNDEFINED
+    commMode: CommMode = CommMode.COMM_MULTI_PROCESS
+    commDomain: str = ""
+
+
+@dataclass
+class SortParam:
+    num: int = 0
+
+
+@dataclass
+class SoftmaxParam:
+    axes: list[int] = field(default_factory=list)
+
+
+@dataclass
+class SliceParam:
+    offsets: list[int] = field(default_factory=list)
+    size: list[int] = field(default_factory=list)
+
+
+@dataclass
+class AclNnSliceParam:
+    name: str = ""
+    dim: int = 0
+    start: int = 0
+    end: int = 0
+    step: int = 0
+
+
+@dataclass
+class IndexSelectParam:
+    name: str = ""
+    dim: int = 0
+
+
 def custom_asdict_factory(data):
     def convert_value(obj):
         if isinstance(obj, IntEnum):
```
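These dataclasses are serialized to JSON and handed to the C++ runtime. A standalone sketch of that round trip follows; the diff only shows the IntEnum branch of custom_asdict_factory, so the stand-in body here is an assumption:

```python
import json
from dataclasses import dataclass, asdict
from enum import IntEnum

class AllReducQuantType(IntEnum):
    QUANT_TYPE_UNDEFINED = 0

@dataclass
class AllReduceParam:  # trimmed copy of the commit's dataclass
    rank: int = 0
    rankSize: int = 0
    allReduceType: str = "sum"
    backend: str = "lccl"
    quantType: AllReducQuantType = AllReducQuantType.QUANT_TYPE_UNDEFINED

def custom_asdict_factory(data):
    # Stand-in for the real factory: flatten IntEnum members to plain ints
    # so json.dumps emits numbers the C++ runtime can parse.
    def convert_value(obj):
        return int(obj) if isinstance(obj, IntEnum) else obj
    return {k: convert_value(v) for k, v in data}

param = AllReduceParam(rank=0, rankSize=2)
print(json.dumps(asdict(param, dict_factory=custom_asdict_factory)))
# {"rank": 0, "rankSize": 2, "allReduceType": "sum", "backend": "lccl", "quantType": 0}
```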

dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py

Lines changed: 73 additions & 1 deletion
```diff
@@ -54,6 +54,7 @@ def LinearAllReduce(name, x, weight, bias):
         param.rankRoot = 0
         param.hasResidual = False
         param.parallelType = infer_param.ParallelType.LINEAR_ALL_REDUCE
+        param.backend = "lccl"

         if bias:
             op.set_input([x, weight, bias])
@@ -63,7 +64,19 @@ def LinearAllReduce(name, x, weight, bias):
         op.set_output([name])
         return op

-    @staticmethod
+    def AllReduce(name, x, reduce_type):
+        op = Operation(name, "AllReduceOperation")
+        param = infer_param.AllReduceParam()
+        param.rank = dist.get_rank()
+        param.rankSize = dist.get_world_size()
+        param.rankRoot = 0
+        param.allReduceType = reduce_type
+        param.backend = "lccl"
+        op.set_input([x])
+        op.set_param(param)
+        op.set_output([name])
+        return op
+
     def Add(name, x, y):
         op = Operation(name, "ElewiseOperation")
         param = infer_param.ElewiseParam()
@@ -537,3 +550,62 @@ def Gather(name, x1, x2, axis):
         op.set_param(param)
         op.set_output([name])
         return op
+
+    def Softmax(name, x, dim):
+        op = Operation(name, "AclNnSoftmaxOperation")
+        param = infer_param.SoftmaxParam()
+        param.name = name
+        if not isinstance(dim, list):
+            dim = [dim]
+        param.axes = dim
+
+        op.set_input([x])
+        op.set_param(param)
+        op.set_output([name])
+        return op
+
+    def Sort(name, x, topk):
+        op = Operation(name, "AclNnTopkOperation")
+        param = infer_param.SortParam()
+        param.num = topk
+
+        op.set_input([x])
+        op.set_param(param)
+        op.set_output([f"{name}__0", f"{name}__1"])
+        return op
+
+    def Slice(name, x, dim, offsets, size):
+        op = Operation(name, "SliceOperation")
+        param = infer_param.SliceParam()
+        param.offsets = offsets
+        param.size = size
+
+        op.set_input([x])
+        op.set_param(param)
+        op.set_output([name])
+        return op
+
+    def AclNnSlice(name, x, dim, start, end, step):
+        op = Operation(name, "AclNnSliceOperation")
+        param = infer_param.AclNnSliceParam()
+        param.name = name
+        param.dim = dim
+        param.start = start
+        param.end = end
+        param.step = step
+
+        op.set_input([x])
+        op.set_param(param)
+        op.set_output([name])
+        return op
+
+    def IndexSelect(name, x, dim, index):
+        op = Operation(name, "AclNnIndexSelectOperation")
+        param = infer_param.IndexSelectParam()
+        param.name = name
+        param.dim = dim
+
+        op.set_input([x, index])
+        op.set_param(param)
+        op.set_output([name])
+        return op
```
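Each builder here assembles one ATB graph node: pick an operation type, fill the matching infer_param dataclass, and wire input/output tensor names. A minimal sketch with a stand-in Operation class (the real one lives in the codegen package) showing why Sort registers two outputs:

```python
class Operation:  # stand-in; the real class lives in the codegen runtime glue
    def __init__(self, name, op_type):
        self.op_name, self.op_type = name, op_type
        self.inputs, self.outputs, self.param = [], [], None

    def set_input(self, names): self.inputs = names
    def set_param(self, param): self.param = param
    def set_output(self, names): self.outputs = names

name = "sort_0"
op = Operation(name, "AclNnTopkOperation")
op.set_input(["router_logits"])  # hypothetical input tensor name
op.set_output([f"{name}__0", f"{name}__1"])  # topk values and indices
assert op.outputs == ["sort_0__0", "sort_0__1"]
```

ATB's TopK kernel yields two tensors, so the node registers two output names that downstream ops reference individually.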

dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -32,6 +32,7 @@ target_include_directories(
     ${CMAKE_CURRENT_SOURCE_DIR}
     ${TORCH_NPU_INCLUDE_DIRS}
     ${CANN_INCLUDE_DIRS}
+    ${CANN_INCLUDE_DIRS}/aclnn
     ${ATB_INCLUDE_DIRS}
 )
```

dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/model.cpp

Lines changed: 5 additions & 5 deletions
```diff
@@ -109,7 +109,6 @@ atb::Tensor Model::CreateInternalTensorFromDesc(const atb::TensorDesc& tensorDes
 Model::Model(const std::string& modelId, const std::string& modelPath) : modelId_(modelId), modelPath_(modelPath) {
     auto st = BuildGraph();
     DICP_LOG_IF(st != atb::NO_ERROR, ERROR) << modelId_ << " init graph:\n" << graph_.ToString();
-
     graph_.Init();
     DICP_LOG(INFO) << modelId_ << " init graph:\n" << graph_.ToString();
 }
@@ -249,7 +248,6 @@ void Model::BuildNodeVariantPack(int nodeId) {
         if (needReshape) {
             node.inTensorReshapeFuncs.at(i)(node.inTensors.at(i)->desc.shape, inTensorDescs.at(i).shape);
             node.variantPack.inTensors.at(i).desc.shape = inTensorDescs.at(i).shape;
-            node.inTensors.at(i)->desc.shape = inTensorDescs.at(i).shape;
         }
         DICP_LOG(INFO) << modelId_ << " nodes[" << nodeId << "] inTensors[" << i << "]:" << tensor_utils::TensorToString(node.variantPack.inTensors.at(i));
     }
@@ -265,7 +263,7 @@ void Model::BuildNodeVariantPack(int nodeId) {
     for (size_t i = 0; i < node.outTensors.size(); ++i) {
         if (hasInplaceOutputs && node.inplaceIndices.count(i) > 0) {
             auto inputIdx = node.inplaceIndices[i];
-            node.variantPack.outTensors.at(i) = *node.inTensors.at(inputIdx);
+            node.variantPack.outTensors.at(i) = node.variantPack.inTensors.at(inputIdx);
             *node.outTensors.at(i) = node.variantPack.outTensors.at(i);
             continue;
         }
@@ -494,7 +492,8 @@ void Model::SetupUnsqueezeReshape(const nlohmann::json& reshapeInput, atb::Reshape
     func = [=](const atb::Dims& oldShape, atb::Dims& newShape) {
         std::vector<int64_t> dimValues(oldShape.dims, oldShape.dims + oldShape.dimNum);
         for (const auto& d : dims) {
-            dimValues.insert(dimValues.begin() + d, 1);
+            int offset = d < 0 ? d + oldShape.dimNum + 1 : d;
+            dimValues.insert(dimValues.begin() + offset, 1);
         }
         newShape.dimNum = dimValues.size();
         std::copy(dimValues.begin(), dimValues.end(), newShape.dims);
@@ -506,7 +505,8 @@ void Model::SetupSqueezeReshape(const nlohmann::json& reshapeInput, atb::Reshape
     func = [=](const atb::Dims& oldShape, atb::Dims& newShape) {
         std::vector<int64_t> dimValues(oldShape.dims, oldShape.dims + oldShape.dimNum);
         for (const auto& d : dims) {
-            dimValues.erase(dimValues.begin() + d);
+            int offset = d < 0 ? d + oldShape.dimNum : d;
+            dimValues.erase(dimValues.begin() + offset);
         }
         newShape.dimNum = dimValues.size();
         std::copy(dimValues.begin(), dimValues.end(), newShape.dims);
```
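The reshape fixes above normalize negative dims before inserting or erasing: unsqueeze indexes into a list that ends up one slot longer than the old shape, hence the extra + 1. A worked check in plain Python mirroring the two lambdas:

```python
def unsqueeze(shape, d):
    out = list(shape)
    offset = d + len(shape) + 1 if d < 0 else d  # new list is one slot longer
    out.insert(offset, 1)
    return out

def squeeze(shape, d):
    out = list(shape)
    offset = d + len(shape) if d < 0 else d
    del out[offset]
    return out

assert unsqueeze([4, 8], -1) == [4, 8, 1]  # matches torch.unsqueeze(x, -1)
assert squeeze([4, 1, 8], -2) == [4, 8]    # matches torch.squeeze(x, -2)
```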

dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/gt_scalar_operation.cpp

Lines changed: 4 additions & 2 deletions
```diff
@@ -22,9 +22,11 @@ AclNnGtScalarOperation::~AclNnGtScalarOperation() {
 atb::Status AclNnGtScalarOperation::InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const {
     DICP_LOG(INFO) << opName_ << " infer shape start";
     outTensorDescs.at(0).format = inTensorDescs.at(0).format;
-    outTensorDescs.at(0).shape.dimNum = NUM1;
-    outTensorDescs.at(0).shape.dims[0] = 1;
+    outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum;
     outTensorDescs.at(0).dtype = aclDataType::ACL_BOOL;
+    for (size_t i = 0; i < outTensorDescs.at(0).shape.dimNum; ++i) {
+        outTensorDescs.at(0).shape.dims[i] = inTensorDescs.at(0).shape.dims[i];
+    }
     DICP_LOG(INFO) << opName_ << " infer shape end";
     return 0;
 }
```
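The corrected InferShape makes GtScalar elementwise: the output keeps the input's full shape and only the dtype changes to bool, matching eager PyTorch:

```python
import torch

x = torch.randn(2, 3, 4)
out = torch.gt(x, 0.5)  # elementwise greater-than against a scalar
assert out.shape == x.shape and out.dtype == torch.bool
```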
dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/index_select_operation.cpp (new file)

Lines changed: 51 additions & 0 deletions
```cpp
#include "index_select_operation.h"

#include "aclnnop/aclnn_index_select.h"
#include "utils/log.h"

namespace dicp {

const int NUM1 = 1;
const int NUM2 = 2;

AclNnIndexSelectOperation::AclNnIndexSelectOperation(const std::string& name, int64_t dim) : AclNnOperation(name), dim_(dim) {}

AclNnIndexSelectOperation::~AclNnIndexSelectOperation() {}

atb::Status AclNnIndexSelectOperation::InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const {
    DICP_LOG(INFO) << opName_ << " infer shape start";
    outTensorDescs.at(0).format = inTensorDescs.at(0).format;
    outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum;
    outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype;

    for (size_t i = 0; i < outTensorDescs.at(0).shape.dimNum; ++i) {
        outTensorDescs.at(0).shape.dims[i] = inTensorDescs.at(0).shape.dims[i];
    }
    outTensorDescs.at(0).shape.dims[dim_] = inTensorDescs.at(1).shape.dims[0];
    DICP_LOG(INFO) << opName_ << " infer shape end";
    return 0;
}

uint32_t AclNnIndexSelectOperation::GetInputNum() const { return NUM2; }

uint32_t AclNnIndexSelectOperation::GetOutputNum() const { return NUM1; }

int AclNnIndexSelectOperation::SetAclNnWorkspaceExecutor(uint64_t& workspaceSize) {
    DICP_LOG(INFO) << opName_ << " AclNnIndexSelectGetWorkspaceSize start";

    int ret = aclnnIndexSelectGetWorkspaceSize(
        aclInTensors_.at(0).tensor, dim_, aclInTensors_.at(1).tensor, aclOutTensors_.at(0).tensor, &workspaceSize, &aclExecutor_);
    DICP_LOG(INFO) << opName_ << " AclNnIndexSelectGetWorkspaceSize end, ret:" << ret << ", workspaceSize:" << workspaceSize
                   << ", aclExecutor:" << aclExecutor_;

    return ret;
}

int AclNnIndexSelectOperation::CallAclExecute(uint8_t* workspace, uint64_t workspaceSize, aclOpExecutor* aclExecutor, aclrtStream stream) {
    DICP_LOG(INFO) << opName_ << " AclNnIndexSelect start";
    int ret = aclnnIndexSelect(workspace, workspaceSize, aclExecutor, stream);
    DICP_LOG(INFO) << opName_ << " AclNnIndexSelect end, ret:" << ret;
    return ret;
}

}  // namespace dicp
```
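A quick PyTorch cross-check of the InferShape logic above: index_select preserves every input dimension except dim, which takes the length of the 1-D index tensor.

```python
import torch

x = torch.randn(4, 8, 16)
index = torch.tensor([1, 3, 5, 7, 2])
out = torch.index_select(x, 1, index)
assert out.shape == (4, 5, 16)  # dim 1 becomes len(index); others unchanged
```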
dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/index_select_operation.h (new file)

Lines changed: 36 additions & 0 deletions
```cpp
#pragma once

#include "acl_nn_operation.h"

namespace dicp {

class AclNnIndexSelectOperation : public AclNnOperation {
public:
    explicit AclNnIndexSelectOperation(const std::string& name, int64_t dim);
    ~AclNnIndexSelectOperation() override;
    atb::Status InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const override;
    uint32_t GetInputNum() const override;
    uint32_t GetOutputNum() const override;

private:
    int64_t dim_;
    int SetAclNnWorkspaceExecutor(uint64_t& workspaceSize) override;
    int CallAclExecute(uint8_t* workspace, uint64_t workspaceSize, aclOpExecutor* aclExecutor, aclrtStream stream) override;
};

inline atb::Operation* AclNnIndexSelectOperationCreate(const nlohmann::json& paramJson) {
    std::string opName;
    int64_t dim = 0;
    if (paramJson.contains("name")) {
        opName = paramJson["name"].get<std::string>();
    }
    if (paramJson.contains("dim")) {
        dim = paramJson["dim"].get<int64_t>();
    }
    DICP_LOG(INFO) << "AclNnIndexSelectOperation: name: " << opName << " dim:" << dim;
    atb::Operation* op = new AclNnIndexSelectOperation(opName, dim);
    return op;
}

}  // namespace dicp
```
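For reference, a sketch of the per-operation param JSON this creator consumes; the field names come from IndexSelectParam in this commit, while the surrounding node-level layout is an assumption.

```python
import json

# Hypothetical node payload; field names follow IndexSelectParam above.
param_json = json.loads('{"name": "index_select_0", "dim": 1}')
name = param_json.get("name", "")
dim = param_json.get("dim", 0)
assert (name, dim) == ("index_select_0", 1)
```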
