@@ -27,30 +27,40 @@ def string_to_shape(node_string, bias=False):
2727 return m if m is None else tuple (int (x ) for x in m .groups ()[0 ].split (',' ))
2828
2929
30- def _parse_node_inputs (node ):
30+ def _parse_node_inputs (node , version = 2 ):
3131 inputs = {}
3232 inputs_names = []
3333 for idx , inp in enumerate (node .inputs ()):
3434 inp = str (inp )
3535 curr_node_name = re .search (r'(.*) defined in ' , inp ).group (1 )
36- extracted_data = re .search (r'%' + curr_node_name + r' : Float\(([^%]*)\)[,| ]' , inp )
36+ if version <= 2 :
37+ extracted_data = re .search (r'%' + curr_node_name + r' : Float\(([^%]*)\)[,| ]' , inp )
38+ elif version == 3 :
39+ extracted_data = re .search (r'%' + curr_node_name + r' : Float\(((\d+, )+)' , inp )
40+
3741 if extracted_data is not None :
3842 extracted_data = extracted_data .group (1 )
3943 else :
4044 return _parse_node_inputs (
4145 list (node .inputs ())[0 ].node ()
4246 )
43- inputs [curr_node_name ] = re .findall (r'(\d+):' , extracted_data )
47+ if version <= 2 :
48+ inputs [curr_node_name ] = re .findall (r'(\d+):' , extracted_data )
49+ elif version == 3 :
50+ inputs [curr_node_name ] = re .findall (r'(\d+)' , extracted_data )
4451 inputs [curr_node_name ] = list (map (int , inputs [curr_node_name ]))
4552 inputs_names .append (curr_node_name )
4653 return inputs , inputs_names
4754
4855
49- def parse_node_info (node ):
50- inputs , inputs_names = _parse_node_inputs (node )
56+ def parse_node_info (node , version = 2 ):
57+ inputs , inputs_names = _parse_node_inputs (node , version = version )
5158 node = str (node )
5259 node_name = re .search (r'%(.*) : ' , node ).group (1 )
53- out_size = re .search (r'Float\(\d+:(\d+),' , node ).group (1 )
60+ if version == 2 :
61+ out_size = re .search (r'Float\(\d+:(\d+),' , node ).group (1 )
62+ elif version == 3 :
63+ out_size = re .search (r'strides=\[(\d+),' , node ).group (1 )
5464
5565 return node_name , inputs , inputs_names , int (out_size )
5666
@@ -75,8 +85,8 @@ def _count_convNd(node, version=2):
7585 out_ops = reduce (lambda x , y : x * y , out )
7686 bias_ops = 1 if string_to_shape (list (node .inputs ())[0 ], True ) is not None else 0
7787 f_in = inp [1 ]
78- elif version == 2 :
79- node_name , inputs , inputs_names , out_ops = parse_node_info (node )
88+ elif version in [ 2 , 3 ] :
89+ node_name , inputs , inputs_names , out_ops = parse_node_info (node , version = version )
8090 f_in = inputs [inputs_names [0 ]][1 ]
8191 bias_ops = 1 if len (inputs_names ) == 3 else 0
8292
@@ -103,8 +113,8 @@ def _count_relu(node, version=2):
103113 """
104114 if version == 1 :
105115 inp = string_to_shape (list (node .inputs ())[0 ])
106- elif version == 2 :
107- node_name , inputs , inputs_names , out_ops = parse_node_info (node )
116+ elif version in [ 2 , 3 ] :
117+ node_name , inputs , inputs_names , out_ops = parse_node_info (node , version = version )
108118 inp = inputs [inputs_names [0 ]]
109119 total_ops = 2 * reduce (lambda x , y : x * y , inp ) # also count the comparison
110120 return total_ops
@@ -121,8 +131,8 @@ def _count_avgpool(node, version=2):
121131 if version == 1 :
122132 out = string_to_shape (list (node .outputs ())[0 ])
123133 out_ops = reduce (lambda x , y : x * y , out )
124- elif version == 2 :
125- node_name , inputs , inputs_names , out_ops = parse_node_info (node )
134+ elif version in [ 2 , 3 ] :
135+ node_name , inputs , inputs_names , out_ops = parse_node_info (node , version = version )
126136
127137 ops_add = reduce (lambda x , y : x * y , node ['kernel_shape' ]) - 1
128138 ops_div = 1
@@ -142,8 +152,8 @@ def _count_globalavgpool(node, version=2):
142152 inp = string_to_shape (list (node .inputs ())[0 ])
143153 out = string_to_shape (list (node .outputs ())[0 ])
144154 out_ops = reduce (lambda x , y : x * y , out )
145- elif version == 2 :
146- node_name , inputs , inputs_names , out_ops = parse_node_info (node )
155+ elif version in [ 2 , 3 ] :
156+ node_name , inputs , inputs_names , out_ops = parse_node_info (node , version = version )
147157 inp = inputs [inputs_names [0 ]]
148158
149159 ops_add = reduce (lambda x , y : x * y , [inp [- 2 ], inp [- 1 ]]) - 1
@@ -163,8 +173,8 @@ def _count_maxpool(node, version=2):
163173 if version == 1 :
164174 out = string_to_shape (list (node .outputs ())[0 ])
165175 out_ops = reduce (lambda x , y : x * y , out )
166- elif version == 2 :
167- node_name , inputs , inputs_names , out_ops = parse_node_info (node )
176+ elif version in [ 2 , 3 ] :
177+ node_name , inputs , inputs_names , out_ops = parse_node_info (node , version = version )
168178
169179 ops_add = reduce (lambda x , y : x * y , node ['kernel_shape' ]) - 1
170180 total_ops = ops_add * out_ops
@@ -184,8 +194,8 @@ def _count_bn(node, version=2):
184194 inp = string_to_shape (list (node .inputs ())[1 ])
185195 else :
186196 inp = string_to_shape (list (node .inputs ())[0 ])
187- elif version == 2 :
188- node_name , inputs , inputs_names , out_ops = parse_node_info (node )
197+ elif version in [ 2 , 3 ] :
198+ node_name , inputs , inputs_names , out_ops = parse_node_info (node , version = version )
189199 inp = inputs [inputs_names [0 ]]
190200
191201 total_ops = reduce (lambda x , y : x * y , inp ) * 2
@@ -205,8 +215,8 @@ def _count_linear(node, version=2):
205215 out = string_to_shape (list (node .outputs ())[0 ])
206216 f_in = inp [1 ]
207217 out_ops = reduce (lambda x , y : x * y , out )
208- elif version == 2 :
209- node_name , inputs , inputs_names , out_ops = parse_node_info (node )
218+ elif version in [ 2 , 3 ] :
219+ node_name , inputs , inputs_names , out_ops = parse_node_info (node , version = version )
210220 inp = inputs [inputs_names [0 ]]
211221 f_in = inputs [inputs_names [0 ]][1 ]
212222
@@ -224,8 +234,8 @@ def _count_add_mul(node, version=2):
224234 """
225235 if version == 1 :
226236 inp = string_to_shape (list (node .inputs ())[0 ])
227- elif version == 2 :
228- node_name , inputs , inputs_names , out_ops = parse_node_info (node )
237+ elif version in [ 2 , 3 ] :
238+ node_name , inputs , inputs_names , out_ops = parse_node_info (node , version = version )
229239 inp = inputs [inputs_names [0 ]]
230240 return reduce (lambda x , y : x * y , inp )
231241
@@ -296,8 +306,11 @@ def count_ops(model, input, custom_ops={}, ignore_layers=[], print_readable=True
296306 trace , _ = torch .jit ._get_trace_graph (model , input , * args )
297307 graph = torch .onnx ._optimize_trace (trace , torch .onnx .OperatorExportTypes .ONNX )
298308
299- if LooseVersion (torch .__version__ ) >= LooseVersion ('1.6.0' ):
309+ if LooseVersion (torch .__version__ ) >= LooseVersion ('1.6.0' ) and \
310+ LooseVersion (torch .__version__ ) < LooseVersion ('1.8.0' ):
300311 version = 2
312+ else :
313+ version = 3
301314 else :
302315 # PyTorch 1.3 and bellow
303316 trace , _ = torch .jit .get_trace_graph (model , input , * args )
0 commit comments