|
76 | 76 | custom_code_at_the_beginning: | |
77 | 77 | return dipu_add__tensor(self, other, -alpha); |
78 | 78 |
|
| 79 | +- schema: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor |
| 80 | + custom_code_at_the_beginning: | |
| 81 | + at::Tensor out = UnaryOpInferrer().infer_out(self); |
| 82 | + interface: diopiSubScalar(ctx, out, self, other, alpha) |
| 83 | + |
79 | 84 | - schema: "sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" |
80 | | - dummy_call_diopi: True |
| 85 | + ins: [selfTmp] |
81 | 86 | custom_code_at_the_beginning: | |
82 | | - at::native::sub_check(self, other); |
83 | | - auto out = BinaryOpInferrer().infer_out(self, other); |
84 | | - return dipu_add_out(self, other, -alpha, out); |
| 87 | + if (is_scalar_on_cpu(other)) { |
| 88 | + return dipu_sub_scalar(self, other.item(), alpha); |
| 89 | + } |
| 90 | +
|
| 91 | + at::Tensor selfTmp = self; |
| 92 | + if (is_scalar_on_cpu(selfTmp)) { |
| 93 | + selfTmp = selfTmp.to(other.device()); |
| 94 | + } |
| 95 | +
|
| 96 | + at::native::sub_check(selfTmp, other); |
| 97 | + at::Tensor out = BinaryOpInferrer().infer_out(selfTmp, other); |
| 98 | +
|
| 99 | + interface: diopiSub(ctx, out, selfTmp, other, alpha) |
85 | 100 |
|
86 | 101 | - schema: "div.Scalar(Tensor self, Scalar other) -> Tensor" |
87 | 102 | custom_code_at_the_beginning: | |
|
0 commit comments