@@ -335,19 +335,24 @@ def ops_to_not_decompose(
335335 function that returns True when an op should not be decomposed.
336336
337337 """
338- ops_to_not_decompose_if_quant_op = [
338+ ops_to_not_decompose_if_quant_op = {
339339 torch .ops .aten .hardsigmoid .default ,
340340 torch .ops .aten .hardswish .default ,
341341 torch .ops .aten .linear .default ,
342- ]
342+ }
343+ ops_to_not_decompose_if_fp = {
344+ torch .ops .aten .linear .default ,
345+ }
346+ ops_to_not_decompose_always = {
347+ torch .ops .aten .eye .default ,
348+ torch .ops .aten .linspace .default ,
349+ torch .ops .aten .logit .default ,
350+ }
343351
344352 def filter_fn (node : torch .fx .Node ) -> bool :
345- """Return True to keep selected ops intact inside quantized regions.
346-
347- The predicate holds when the target is in
348- ``ops_to_not_decompose_if_quant_op`` and all inputs/outputs are
349- quantize/dequantize ops, indicating a quantized activation that
350- should not be decomposed.
353+ """Filter function applied to ops in 'ops_to_not_decompose'.
354+ Returns True if the op should not be decomposed.
355+ If this function returns True, the partitioner *must* accept the node, or the lowering fails.
351356
352357 Args:
353358 node (torch.fx.Node): FX node to evaluate.
@@ -356,6 +361,12 @@ def filter_fn(node: torch.fx.Node) -> bool:
356361 bool: True to keep the op intact; otherwise, False.
357362
358363 """
364+ if (
365+ self .tosa_spec .support_float ()
366+ and node .target in ops_to_not_decompose_if_fp
367+ ):
368+ return True
369+
359370 dq = (
360371 torch .ops .quantized_decomposed .dequantize_per_tensor .default ,
361372 torch .ops .quantized_decomposed .dequantize_per_channel .default ,
@@ -394,11 +405,11 @@ def filter_fn(node: torch.fx.Node) -> bool:
394405 # By default, do not decompose the operator
395406 return True
396407
397- ops_to_not_decompose = [
398- torch . ops . aten . eye . default ,
399- torch . ops . aten . linspace . default ,
400- torch . ops . aten . logit . default ,
401- ] + ops_to_not_decompose_if_quant_op
408+ ops_to_not_decompose = list (
409+ ops_to_not_decompose_always
410+ | ops_to_not_decompose_if_quant_op
411+ | ops_to_not_decompose_if_fp
412+ )
402413
403414 if not self .tosa_spec .is_U55_subset :
404415 # Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d
0 commit comments