Skip to content

Commit 2e5ac7d

Browse files
authored
[Relax][PyTorch] improve the check for no bias situation (#18374)
* [#18373] improve the check for no bias situation * [#18373] refactor the _normalize_python_tuple
1 parent 43e9c27 commit 2e5ac7d

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

python/tvm/relax/block_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,10 @@ def _normalize_python_tuple(self, expr: Union[Expr, Sequence[Expr]]):
299299
"""
300300
if isinstance(expr, (list, tuple)):
301301
return Tuple([self._normalize_python_tuple(element) for element in expr])
302+
elif expr is None:
303+
from . import op
304+
305+
return op.null_value()
302306
else:
303307
return expr
304308

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,22 @@ def shape_of(tensor):
8888
return tensor.shape
8989
raise ValueError("Unsupported type: {}".format(type(tensor)))
9090

91+
@staticmethod
92+
def _is_no_bias(bias):
93+
"""Check if bias represents 'no bias' condition.
94+
95+
This handles both Python None and relax.op.null_value() expressions
96+
that might be used to represent missing bias parameters.
97+
"""
98+
if bias is None:
99+
return True
100+
101+
# Check if this is a null_value expression
102+
if isinstance(bias, relax.Call) and bias.op.name == "relax.null_value":
103+
return True
104+
105+
return False
106+
91107
def retrieve_args(self, node: fx.Node):
92108
return self._retrieve_args(node.args)
93109

@@ -103,7 +119,7 @@ def _retrieve_args(self, node):
103119
elif isinstance(node, dict):
104120
return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()}
105121
elif node is None:
106-
return relax.op.null_value()
122+
return None
107123
else:
108124
return node
109125

@@ -758,7 +774,7 @@ def _conv_transpose1d_impl(
758774
)
759775
)
760776

761-
if bias is None:
777+
if self._is_no_bias(bias):
762778
return conv1d_transpose
763779

764780
assert len(self.shape_of(bias)) == 1
@@ -812,7 +828,7 @@ def _conv_transpose2d_impl(
812828
)
813829
)
814830

815-
if bias is None:
831+
if self._is_no_bias(bias):
816832
return conv2d_transpose
817833

818834
assert len(self.shape_of(bias)) == 1
@@ -864,7 +880,7 @@ def _conv1d_impl(
864880
)
865881
)
866882

867-
if bias is None:
883+
if self._is_no_bias(bias):
868884
return conv1d
869885
assert len(self.shape_of(bias)) == 1
870886
bias = relax.op.reshape(bias, (1, -1, 1))
@@ -913,7 +929,7 @@ def _conv2d_impl(
913929
)
914930
)
915931

916-
if bias is None:
932+
if self._is_no_bias(bias):
917933
return conv2d
918934
assert len(self.shape_of(bias)) == 1
919935
bias = relax.op.reshape(bias, (1, -1, 1, 1))
@@ -962,7 +978,7 @@ def _conv3d_impl(
962978
)
963979
)
964980

965-
if bias is None:
981+
if self._is_no_bias(bias):
966982
return conv3d
967983
assert len(self.shape_of(bias)) == 1
968984
bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))

0 commit comments

Comments
 (0)