Skip to content

Commit fcb7872

Browse files
junjiang-labcopybara-github
authored andcommitted
Add decomp for tfl.fill.
PiperOrigin-RevId: 802639397
1 parent 9af9286 commit fcb7872

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,18 @@ def _aten_cat_decomp(tensors, dim=0):
243243
return torch.ops.tfl.concatenation(processed_tensors, dim)
244244

245245

246+
@register_decomp(torch.ops.aten.full.default)
247+
def _aten_full_decomp(
248+
size,
249+
fill_value,
250+
dtype=None,
251+
layout=None,
252+
device=None,
253+
pin_memory=None,
254+
):
255+
return torch.ops.tfl.fill(tuple(size), fill_value)
256+
257+
246258
@register_decomp(torch.ops.aten.full_like.default)
247259
def _aten_full_like_decomp(
248260
x,

ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,11 @@ def _tfl_batch_matmul_lowering(
7272
@lower(torch.ops.tfl.add.default)
7373
def _tfl_add_lowering(
7474
lctx: LoweringContext,
75-
lhs: ir.Value,
75+
lhs: ir.Value | int | float,
7676
rhs: ir.Value | int | float,
7777
fused_activation_function: str = "NONE",
7878
) -> ir.Value:
79+
lhs = lowering_utils.convert_to_ir_value(lhs)
7980
rhs = lowering_utils.convert_to_ir_value(rhs)
8081
return _ir_operation(
8182
"tfl.add",

ai_edge_torch/odml_torch/experimental/torch_tfl/test/test_torch_tfl_impls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ def _assert_export_and_close(
179179
("aten_cat_2", torch.ops.aten.cat.default, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0, 10))], 0,), dict()),
180180
("aten_cat_3", torch.ops.aten.cat.default, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (0,))], 0,), dict()),
181181
("aten_cat_4", torch.ops.aten.cat.default, ([rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10))],), dict()),
182+
("aten_full_0", torch.ops.aten.full.default, ([10, 10], 0.123,), dict()),
183+
("aten_full_1", torch.ops.aten.full.default, ([10, 10], 123,), dict()),
182184
("aten_full_like_0", torch.ops.aten.full_like.default, (rnd(torch.float32, (10, 10)), 0.123,), dict()),
183185
("aten_full_like_1", torch.ops.aten.full_like.default, (rnd(torch.int64, (10, 10)), 123,), dict()),
184186
("aten_view_0", torch.ops.aten.view.default, (rnd(torch.float32, (10, 10)), [1, 100],), dict()),

0 commit comments

Comments
 (0)