@@ -88,6 +88,22 @@ def shape_of(tensor):
8888 return tensor .shape
8989 raise ValueError ("Unsupported type: {}" .format (type (tensor )))
9090
91+ @staticmethod
92+ def _is_no_bias (bias ):
93+ """Check if bias represents 'no bias' condition.
94+
95+ This handles both Python None and relax.op.null_value() expressions
96+ that might be used to represent missing bias parameters.
97+ """
98+ if bias is None :
99+ return True
100+
101+ # Check if this is a null_value expression
102+ if isinstance (bias , relax .Call ) and bias .op .name == "relax.null_value" :
103+ return True
104+
105+ return False
106+
91107 def retrieve_args (self , node : fx .Node ):
92108 return self ._retrieve_args (node .args )
93109
@@ -103,7 +119,7 @@ def _retrieve_args(self, node):
103119 elif isinstance (node , dict ):
104120 return {self ._retrieve_args (k ): self ._retrieve_args (v ) for k , v in node .items ()}
105121 elif node is None :
106- return relax . op . null_value ()
122+ return None
107123 else :
108124 return node
109125
@@ -758,7 +774,7 @@ def _conv_transpose1d_impl(
758774 )
759775 )
760776
761- if bias is None :
777+ if self . _is_no_bias ( bias ) :
762778 return conv1d_transpose
763779
764780 assert len (self .shape_of (bias )) == 1
@@ -812,7 +828,7 @@ def _conv_transpose2d_impl(
812828 )
813829 )
814830
815- if bias is None :
831+ if self . _is_no_bias ( bias ) :
816832 return conv2d_transpose
817833
818834 assert len (self .shape_of (bias )) == 1
@@ -864,7 +880,7 @@ def _conv1d_impl(
864880 )
865881 )
866882
867- if bias is None :
883+ if self . _is_no_bias ( bias ) :
868884 return conv1d
869885 assert len (self .shape_of (bias )) == 1
870886 bias = relax .op .reshape (bias , (1 , - 1 , 1 ))
@@ -913,7 +929,7 @@ def _conv2d_impl(
913929 )
914930 )
915931
916- if bias is None :
932+ if self . _is_no_bias ( bias ) :
917933 return conv2d
918934 assert len (self .shape_of (bias )) == 1
919935 bias = relax .op .reshape (bias , (1 , - 1 , 1 , 1 ))
@@ -962,7 +978,7 @@ def _conv3d_impl(
962978 )
963979 )
964980
965- if bias is None :
981+ if self . _is_no_bias ( bias ) :
966982 return conv3d
967983 assert len (self .shape_of (bias )) == 1
968984 bias = relax .op .reshape (bias , (1 , - 1 , 1 , 1 , 1 ))
0 commit comments