Skip to content

Commit 6ccdb45

Browse files
authored
[TIR] Refactor division simplification in RewriteSimplifier (#18319)
* Refactor division simplification in RewriteSimplifier and add corresponding test This commit removes the specific case for rewriting division by a constant float in the RewriteSimplifier. Additionally, a new test is introduced to verify the behavior of float division simplification, ensuring that the division is correctly handled without the previous rewrite logic. * test fix * test fix * cifix * fix
1 parent 70c157d commit 6ccdb45

File tree

7 files changed

+170
-169
lines changed

7 files changed

+170
-169
lines changed

src/arith/rewrite_simplify.cc

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -774,13 +774,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
774774
// Pattern var for lanes in broadcast and ramp
775775
PVar<PrimExpr> lanes;
776776

777-
// x / 2.0 = x * 0.5
778-
if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
779-
ICHECK(op->dtype.is_float() || op->dtype.is_bfloat16() ||
780-
datatype::Registry::Global()->GetTypeRegistered(op->dtype.code()));
781-
return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
782-
}
783-
784777
// Vector rules
785778
if (op->dtype.is_scalable_or_fixed_length_vector()) {
786779
// NOTE: use div as the pattern also works for float.

tests/python/arith/test_arith_simplify.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import tvm.testing
2222
from tvm import tir
2323
from tvm.script import tir as T
24+
import tvm.ir
2425

2526

2627
def test_simplify_reshape_flattened_index():
@@ -144,5 +145,16 @@ def test_simplify_floor_mod_with_linear_offset():
144145
assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor2), 0)
145146

146147

148+
def test_simplify_float_division():
149+
# Test for the discussion:
150+
# https://discuss.tvm.apache.org/t/discuss-is-constant-division-to-multiplication-rewrite-in-tvm-necessary/18615
151+
ana = tvm.arith.Analyzer()
152+
x = tir.Var("x", "float32")
153+
ry = x / 27
154+
    # in the old version, the division would be rewritten into x * T.float32(1 / 27)
155+
sy = ana.rewrite_simplify(ry)
156+
tvm.ir.assert_structural_equal(ry, sy)
157+
158+
147159
if __name__ == "__main__":
148160
tvm.testing.main()

tests/python/relax/test_codegen_cudnn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def test_conv2d_offload(data_shape, weight_shape, dtype, with_bias, activation):
193193
out = get_result_with_relax_cudnn_offload(mod, args)
194194
ref = build_and_run(mod, args, "llvm", legalize=True)
195195
if dtype == "float16":
196-
tvm.testing.assert_allclose(out, ref, rtol=1e-1, atol=1e-1)
196+
        # FIXME(lei): tolerance is currently raised to 3e-1 to prevent flaky tests
197+
# see https://github.com/apache/tvm/pull/18319
198+
tvm.testing.assert_allclose(out, ref, rtol=3e-1, atol=3e-1)
197199
else:
198200
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
199201

tests/python/relax/test_op_create.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ def test_arange_infer_struct_info_shape_var():
661661
_check_inference(
662662
bb,
663663
relax.op.arange(start, stop, 2),
664-
relax.TensorStructInfo((T.cast(T.ceil((stop - start) * 0.5), "int64"),), "float32"),
664+
relax.TensorStructInfo((T.cast(T.ceil((stop - start) / 2), "int64"),), "float32"),
665665
)
666666
_check_inference(
667667
bb,

tests/python/relax/test_transform_legalize_ops_nn.py

Lines changed: 145 additions & 151 deletions
Large diffs are not rendered by default.

tests/python/relax/test_transform_legalize_ops_qdq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def quantize(
212212
"int8",
213213
T.max(
214214
T.min(
215-
T.round(A[v_i0, v_i1] * T.float32(0.5)) + T.float32(1),
215+
T.round(A[v_i0, v_i1] / T.float32(2)) + T.float32(1),
216216
T.float32(127),
217217
),
218218
T.float32(-128),
@@ -311,7 +311,7 @@ def quantize(
311311
"int8",
312312
T.max(
313313
T.min(
314-
T.round(A[v_i0, v_i1] * T.float16(0.5)) + T.float16(1),
314+
T.round(A[v_i0, v_i1] / T.float16(2)) + T.float16(1),
315315
T.float16(127),
316316
),
317317
T.float16(-128),

tests/python/relax/test_transform_legalize_ops_search_statistical.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ def mean(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)
627627
ax0, ax1 = T.axis.remap("SS", [i0, i1])
628628
T.reads(rxplaceholder_red[ax0, ax1])
629629
T.writes(T_divide[ax0, ax1])
630-
T_divide[ax0, ax1] = rxplaceholder_red[ax0, ax1] * T.float32(0.1)
630+
T_divide[ax0, ax1] = rxplaceholder_red[ax0, ax1] / T.float32(10)
631631
# fmt: on
632632

633633
mod = LegalizeOps()(Mean)
@@ -718,7 +718,7 @@ def std(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))
718718
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
719719
T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3])
720720
T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3])
721-
T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.0083333333333333332)
721+
T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] / T.float32(120.0)
722722
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
723723
with T.block("T_subtract"):
724724
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
@@ -743,7 +743,7 @@ def std(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))
743743
vi = T.axis.spatial(1, T.int64(0))
744744
T.reads(T_multiply_red[()])
745745
T.writes(T_divide_1[()])
746-
T_divide_1[()] = T_multiply_red[()] * T.float32(0.0083333333333333332)
746+
T_divide_1[()] = T_multiply_red[()] / T.float32(120.0)
747747
with T.block("compute"):
748748
vi = T.axis.spatial(1, T.int64(0))
749749
T.reads(T_divide_1[()])
@@ -881,7 +881,7 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6
881881
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
882882
T.reads(rxplaceholder_red[ax0, ax1, ax2, ax3])
883883
T.writes(T_divide_1[ax0, ax1, ax2, ax3])
884-
T_divide_1[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] * T.float32(0.10000000000000001)
884+
T_divide_1[ax0, ax1, ax2, ax3] = rxplaceholder_red[ax0, ax1, ax2, ax3] / T.float32(10.0)
885885
for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
886886
with T.block("T_subtract"):
887887
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
@@ -907,7 +907,7 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6
907907
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
908908
T.reads(T_multiply_red[ax0, ax1, ax2, ax3])
909909
T.writes(T_divide[ax0, ax1, ax2, ax3])
910-
T_divide[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] * T.float32(0.10000000000000001)
910+
T_divide[ax0, ax1, ax2, ax3] = T_multiply_red[ax0, ax1, ax2, ax3] / T.float32(10)
911911
# fmt: on
912912

913913
mod = LegalizeOps()(Variance)
@@ -1027,7 +1027,7 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6
10271027
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
10281028
T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3])
10291029
T.writes(T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3])
1030-
T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.10000000000000001)
1030+
T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] / T.float32(10)
10311031
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
10321032
with T.block("T_subtract"):
10331033
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
@@ -1053,7 +1053,7 @@ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int6
10531053
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
10541054
T.reads(T_multiply_red[v_ax0, v_ax1])
10551055
T.writes(T_divide[v_ax0, v_ax1])
1056-
T_divide[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] * T.float32(0.10000000000000001)
1056+
T_divide[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] / T.float32(10)
10571057

10581058
@R.function
10591059
def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((3, 4), dtype="float32"):

0 commit comments

Comments
 (0)