Skip to content

Commit 6b6a314

Browse files
authored
Arm backend: Fix decomposition logic for FP (#15694)
If the tosa_spec supports floating point, incorrect quantization is not a problem, and we can always not decompose the ops_to_not_decompose. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Erik Lundell <[email protected]>
1 parent 747fc6f commit 6b6a314

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

backends/arm/tosa/partitioner.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -335,19 +335,24 @@ def ops_to_not_decompose(
335335
function that returns True when an op should not be decomposed.
336336
337337
"""
338-
ops_to_not_decompose_if_quant_op = [
338+
ops_to_not_decompose_if_quant_op = {
339339
torch.ops.aten.hardsigmoid.default,
340340
torch.ops.aten.hardswish.default,
341341
torch.ops.aten.linear.default,
342-
]
342+
}
343+
ops_to_not_decompose_if_fp = {
344+
torch.ops.aten.linear.default,
345+
}
346+
ops_to_not_decompose_always = {
347+
torch.ops.aten.eye.default,
348+
torch.ops.aten.linspace.default,
349+
torch.ops.aten.logit.default,
350+
}
343351

344352
def filter_fn(node: torch.fx.Node) -> bool:
345-
"""Return True to keep selected ops intact inside quantized regions.
346-
347-
The predicate holds when the target is in
348-
``ops_to_not_decompose_if_quant_op`` and all inputs/outputs are
349-
quantize/dequantize ops, indicating a quantized activation that
350-
should not be decomposed.
353+
"""Filter function applied to ops in 'ops_to_not_decompose'.
354+
Returns True if the op should not be decomposed.
355+
If this function returns True, the partitioner *must* accept the node, or the lowering fails.
351356
352357
Args:
353358
node (torch.fx.Node): FX node to evaluate.
@@ -356,6 +361,12 @@ def filter_fn(node: torch.fx.Node) -> bool:
356361
bool: True to keep the op intact; otherwise, False.
357362
358363
"""
364+
if (
365+
self.tosa_spec.support_float()
366+
and node.target in ops_to_not_decompose_if_fp
367+
):
368+
return True
369+
359370
dq = (
360371
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
361372
torch.ops.quantized_decomposed.dequantize_per_channel.default,
@@ -394,11 +405,11 @@ def filter_fn(node: torch.fx.Node) -> bool:
394405
# By default, do not decompose the operator
395406
return True
396407

397-
ops_to_not_decompose = [
398-
torch.ops.aten.eye.default,
399-
torch.ops.aten.linspace.default,
400-
torch.ops.aten.logit.default,
401-
] + ops_to_not_decompose_if_quant_op
408+
ops_to_not_decompose = list(
409+
ops_to_not_decompose_always
410+
| ops_to_not_decompose_if_quant_op
411+
| ops_to_not_decompose_if_fp
412+
)
402413

403414
if not self.tosa_spec.is_U55_subset:
404415
# Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d

0 commit comments

Comments
 (0)