@@ -24,19 +24,7 @@ static diopiError_t transpose(diopiContextHandle_t& ctx, DiopiTensor& in, DiopiT
2424 return diopiSuccess;
2525}
2626
27- // static diopiError_t calTensordiopiMemoryFormat_t(const DiopiTensor& tensor, diopiMemoryFormat_t& memoryFormatOut) {
28- // if (tensor.isContiguous(diopiMemoryFormat_t::ChannelsLast)) {
29- // memoryFormatOut = diopiMemoryFormat_t::ChannelsLast;
30- // } else if (tensor.isContiguous(diopiMemoryFormat_t::ChannelsLast3d)) {
31- // memoryFormatOut = diopiMemoryFormat_t::ChannelsLast3d;
32- // } else if (tensor.isContiguous(diopiMemoryFormat_t::Contiguous)) {
33- // memoryFormatOut = diopiMemoryFormat_t::Contiguous;
34- // } else {
35- // return diopiNoImplement;
36- // }
37- // return diopiSuccess;
38- // }
39- static diopiError_t getPermuteOrder (const DiopiTensor& src, std::vector<int32_t >& orderOut, std::vector<int32_t >& reverseOrder) {
27+ diopiError_t getPermuteOrder (const DiopiTensor& src, std::vector<int32_t >& orderOut, std::vector<int32_t >& reverseOrder) {
4028 if (src.isContiguous ()) {
4129 orderOut.resize (src.dim ());
4230 for (int i = 0 ; i < src.dim (); ++i) {
@@ -59,6 +47,7 @@ static diopiError_t getPermuteOrder(const DiopiTensor& src, std::vector<int32_t>
5947 stridesSizes[i] = std::pair<int , int >(inputStrides[i], inputSizes[i]);
6048 }
6149
50+ // shape:2,3,4,5 stride:60,1,15,3 -> orderOut: 0,3,1,2, reverseOrder: 0,2,3,1
6251 sort (stridesSizes.begin (), stridesSizes.end (), [](std::pair<int , int > a, std::pair<int , int > b) { return a.first > b.first ; });
6352 for (int i = 0 ; i < dim; ++i) {
6453 auto pair = stridesSizes[i];
@@ -83,73 +72,6 @@ static diopiError_t getPermuteOrder(const DiopiTensor& src, std::vector<int32_t>
8372 return diopiSuccess;
8473}
8574
86- static diopiError_t calOrderAndSrcMemoryFormat (const DiopiTensor& src, diopiMemoryFormat_t destMemoryFormat, diopiMemoryFormat_t& srcMemoryFormatOut,
87- std::vector<int32_t >& orderOut, std::vector<int32_t >& reverseOrder) {
88- if (src.isContiguous (destMemoryFormat)) {
89- srcMemoryFormatOut = destMemoryFormat;
90- orderOut.resize (src.dim ());
91- for (int i = 0 ; i < src.dim (); ++i) {
92- orderOut[i] = i;
93- }
94- reverseOrder = orderOut;
95- return diopiSuccess;
96- }
97- if (src.isContiguous (diopiMemoryFormat_t::ChannelsLast1d) && destMemoryFormat == diopiMemoryFormat_t::Contiguous) {
98- if (src.dim () != 3 ) {
99- setLastErrorString (" the dim of the tensor should be 4, but now is %d." , src.dim ());
100- return diopiNoImplement;
101- }
102- srcMemoryFormatOut = diopiMemoryFormat_t::ChannelsLast1d;
103- orderOut = {0 , 2 , 1 };
104- reverseOrder = {0 , 2 , 1 };
105- } else if (src.isContiguous (diopiMemoryFormat_t::Contiguous) && destMemoryFormat == diopiMemoryFormat_t::ChannelsLast1d) {
106- if (src.dim () != 3 ) {
107- setLastErrorString (" the dim of the tensor should be 4, but now is %d." , src.dim ());
108- return diopiNoImplement;
109- }
110- srcMemoryFormatOut = diopiMemoryFormat_t::Contiguous;
111- orderOut = {0 , 2 , 1 };
112- reverseOrder = {0 , 2 , 1 };
113- } else if (src.isContiguous (diopiMemoryFormat_t::ChannelsLast) && destMemoryFormat == diopiMemoryFormat_t::Contiguous) {
114- if (src.dim () != 4 ) {
115- setLastErrorString (" the dim of the tensor should be 4, but now is %d." , src.dim ());
116- return diopiNoImplement;
117- }
118- srcMemoryFormatOut = diopiMemoryFormat_t::ChannelsLast;
119- orderOut = {0 , 3 , 1 , 2 };
120- reverseOrder = {0 , 2 , 3 , 1 };
121- } else if (src.isContiguous (diopiMemoryFormat_t::Contiguous) && destMemoryFormat == diopiMemoryFormat_t::ChannelsLast) {
122- if (src.dim () != 4 ) {
123- setLastErrorString (" the dim of the tensor should be 4, but now is %d." , src.dim ());
124- return diopiNoImplement;
125- }
126- srcMemoryFormatOut = diopiMemoryFormat_t::Contiguous;
127- orderOut = {0 , 2 , 3 , 1 };
128- reverseOrder = {0 , 3 , 1 , 2 };
129- } else if (src.isContiguous (diopiMemoryFormat_t::Contiguous) && destMemoryFormat == diopiMemoryFormat_t::ChannelsLast3d) {
130- if (src.dim () != 5 ) {
131- setLastErrorString (" the dim of the tensor should be 5, but now is %d." , src.dim ());
132- return diopiNoImplement;
133- }
134- srcMemoryFormatOut = diopiMemoryFormat_t::Contiguous;
135- orderOut = {0 , 2 , 3 , 4 , 1 };
136- reverseOrder = {0 , 4 , 1 , 2 , 3 };
137- } else if (src.isContiguous (diopiMemoryFormat_t::ChannelsLast3d) && destMemoryFormat == diopiMemoryFormat_t::Contiguous) {
138- if (src.dim () != 5 ) {
139- setLastErrorString (" the dim of the tensor should be 5, but now is %d." , src.dim ());
140- return diopiNoImplement;
141- }
142- srcMemoryFormatOut = diopiMemoryFormat_t::ChannelsLast3d;
143- orderOut = {0 , 4 , 1 , 2 , 3 };
144- reverseOrder = {0 , 2 , 3 , 4 , 1 };
145- } else {
146- // convert to contiguous format
147- srcMemoryFormatOut = diopiMemoryFormat_t::Preserve;
148- return diopiSuccess;
149- }
150- return diopiSuccess;
151- }
152-
15375diopiError_t calCnnlLayout (diopiMemoryFormat_t memoryFormat, int64_t dim, cnnlTensorLayout_t& cnnlLayout) {
15476 switch (memoryFormat) {
15577 case diopiMemoryFormat_t::ChannelsLast1d:
@@ -234,68 +156,61 @@ diopiError_t contiguous(diopiContextHandle_t ctx, DiopiTensor& src, diopiMemoryF
234156
235157 int64_t dim = src.dim ();
236158 DIOPI_CHECK (dim <= 8 , " only support less than 8d tensor currently" );
237- diopiMemoryFormat_t srcMemoryFormat;
238- std::vector<int32_t > order;
239- std::vector<int32_t > reverseOrder;
240159 DiopiTensor dest;
241- DIOPI_CALL (calOrderAndSrcMemoryFormat (src, memoryFormat, srcMemoryFormat, order, reverseOrder));
242- if (srcMemoryFormat == diopiMemoryFormat_t::Preserve) {
243- DIOPI_CALL (clone (ctx, src, dest, memoryFormat));
244- src = dest;
245- return diopiSuccess;
246- }
247- dest = requiresTensor (ctx, src.shape (), src.dtype (), memoryFormat);
248- // set CNNL_LAYOUT_ARRAY because NLC->NCL failed ( no layout NCL);
249- cnnlTensorLayout_t srcLayout = CNNL_LAYOUT_ARRAY;
250- cnnlTensorLayout_t destLayout = CNNL_LAYOUT_ARRAY;
251-
252- std::vector<int64_t > olderDestStride = dest.stride ();
253- std::vector<int64_t > olderDestShape = dest.shape ();
254- if (memoryFormat != diopiMemoryFormat_t::Contiguous) {
255- DIOPI_CALL (permuteTensor (dest, order));
256- } else {
257- DIOPI_CALL (permuteTensor (src, reverseOrder));
258- }
259- DIOPI_CALL (transpose (ctx, src, dest, srcLayout, destLayout, order));
260- // recovery the shape
261- dest.asStrided (olderDestShape, olderDestStride);
160+ DIOPI_CALL (clone (ctx, src, dest, memoryFormat));
262161 src = dest;
263162 return diopiSuccess;
264163}
265164
266- // inplace contiguous
267- diopiError_t contiguousOut (diopiContextHandle_t ctx, DiopiTensor& src, DiopiTensor& dest) {
165+ diopiError_t permuteCopy (diopiContextHandle_t ctx, DiopiTensor& src, DiopiTensor& dest) {
166+ // using input permute + output permute + cnnltranspose to copy
268167 DIOPI_CHECK (src.shape () == dest.shape (), " src's shape should be the same as dest's" );
269168 int64_t dim = src.dim ();
270169 DIOPI_CHECK (dim <= 8 , " only support less than 8d tensor currently" );
271- std::vector<int32_t > order (dim, 0 );
272- std::vector<int32_t > reverseOrder (dim, 0 );
170+ bool srcIsContiguous = src.isContiguous ();
171+ bool destIsContiguous = dest.isContiguous ();
172+ std::vector<int32_t > inputOrder (dim, 0 );
173+ std::vector<int32_t > inputBackOrder (dim, 0 ); // permuteTensor(input,inputBackOrder)->contiguous
174+ std::vector<int32_t > outputOrder (dim, 0 );
175+ std::vector<int32_t > outputBackOrder (dim, 0 ); // permuteTensor(output,outputBackOrder)->contiguous
176+ std::vector<int32_t > inputToOutputOrder (dim, 0 ); // into cnnltranspose
177+
178+ // input shape:2,3,4,5 stride:60,1,15,3 -> inputBackOrder: 0,2,3,1, inputOrder: 0,3,1,2
179+ // output shape:2,3,4,5 stride:60,20,1,4 -> outputBackOrder: 0,1,3,2, outputOrder: 0,1,3,2
180+ // inputToOutputOrder: 0,2,1,3
181+
182+ getPermuteOrder (src, inputOrder, inputBackOrder);
183+ getPermuteOrder (dest, outputOrder, outputBackOrder);
273184
274- if (src.isContiguous ()) {
275- getPermuteOrder (dest, reverseOrder, order);
276- } else {
277- getPermuteOrder (src, order, reverseOrder);
278- }
279- // set CNNL_LAYOUT_ARRAY because NLC->NCL failed ( no layout NCL);
280185 cnnlTensorLayout_t srcLayout = CNNL_LAYOUT_ARRAY;
281186 cnnlTensorLayout_t destLayout = CNNL_LAYOUT_ARRAY;
282187
283188 std::vector<int64_t > olderDestStride = dest.stride ();
284189 std::vector<int64_t > olderDestShape = dest.shape ();
285190 std::vector<int64_t > olderSrcStride = src.stride ();
286191 std::vector<int64_t > olderSrcShape = src.shape ();
287- // if (destMemoryFormat != diopiMemoryFormat_t::Contiguous) {
288- if (src.isContiguous ()) {
289- DIOPI_CALL (permuteTensor (dest, order));
290- } else {
291- DIOPI_CALL (permuteTensor (src, reverseOrder));
192+
193+ // permute to get contiguous tensor
194+ if (!destIsContiguous) {
195+ DIOPI_CALL (permuteTensor (dest, outputBackOrder));
196+ }
197+
198+ if (!srcIsContiguous) {
199+ DIOPI_CALL (permuteTensor (src, inputBackOrder));
200+ }
201+
202+ for (int i = 0 ; i < dim; ++i) {
203+ inputToOutputOrder[i] = inputOrder[outputBackOrder[i]];
292204 }
293- DIOPI_CALL (transpose (ctx, src, dest, srcLayout, destLayout, order));
205+
206+ DIOPI_CALL (transpose (ctx, src, dest, srcLayout, destLayout, inputToOutputOrder));
207+
294208 // recovery the shape and strides
295- // if (destMemoryFormat != diopiMemoryFormat_t::Contiguous) {
296- if (src.isContiguous ()) {
209+ if (!destIsContiguous) {
297210 dest.asStrided (olderDestShape, olderDestStride);
298- } else {
211+ }
212+
213+ if (!srcIsContiguous) {
299214 src.asStrided (olderSrcShape, olderSrcStride);
300215 }
301216 return diopiSuccess;
0 commit comments