Skip to content

Commit f1d1a60

Browse files
committed
fix sub bug for ascend
1 parent f984eba commit f1d1a60

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,27 @@
7676
custom_code_at_the_beginning: |
7777
return dipu_add__tensor(self, other, -alpha);
7878
79+
- schema: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
80+
custom_code_at_the_beginning: |
81+
at::Tensor out = UnaryOpInferrer().infer_out(self);
82+
interface: diopiSubScalar(ctx, out, self, other, alpha)
83+
7984
- schema: "sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
80-
dummy_call_diopi: True
85+
ins: [selfTmp]
8186
custom_code_at_the_beginning: |
82-
at::native::sub_check(self, other);
83-
auto out = BinaryOpInferrer().infer_out(self, other);
84-
return dipu_add_out(self, other, -alpha, out);
87+
if (is_scalar_on_cpu(other)) {
88+
return dipu_sub_scalar(self, other.item(), alpha);
89+
}
90+
91+
at::Tensor selfTmp = self;
92+
if (is_scalar_on_cpu(selfTmp)) {
93+
selfTmp = selfTmp.to(other.device());
94+
}
95+
96+
at::native::sub_check(selfTmp, other);
97+
at::Tensor out = BinaryOpInferrer().infer_out(selfTmp, other);
98+
99+
interface: diopiSub(ctx, out, selfTmp, other, alpha)
85100

86101
- schema: "div.Scalar(Tensor self, Scalar other) -> Tensor"
87102
custom_code_at_the_beginning: |

0 commit comments

Comments
 (0)