Skip to content

Commit eca3301

Browse files
committed
Fix issue with pytorch 1.8 nightly
1 parent c23bd48 commit eca3301

File tree

1 file changed

+36
-23
lines changed

1 file changed

+36
-23
lines changed

pthflops/ops.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,30 +27,40 @@ def string_to_shape(node_string, bias=False):
2727
return m if m is None else tuple(int(x) for x in m.groups()[0].split(','))
2828

2929

30-
def _parse_node_inputs(node):
30+
def _parse_node_inputs(node, version=2):
3131
inputs = {}
3232
inputs_names = []
3333
for idx, inp in enumerate(node.inputs()):
3434
inp = str(inp)
3535
curr_node_name = re.search(r'(.*) defined in ', inp).group(1)
36-
extracted_data = re.search(r'%' + curr_node_name + r' : Float\(([^%]*)\)[,| ]', inp)
36+
if version <= 2:
37+
extracted_data = re.search(r'%' + curr_node_name + r' : Float\(([^%]*)\)[,| ]', inp)
38+
elif version == 3:
39+
extracted_data = re.search(r'%' + curr_node_name + r' : Float\(((\d+, )+)', inp)
40+
3741
if extracted_data is not None:
3842
extracted_data = extracted_data.group(1)
3943
else:
4044
return _parse_node_inputs(
4145
list(node.inputs())[0].node()
4246
)
43-
inputs[curr_node_name] = re.findall(r'(\d+):', extracted_data)
47+
if version <= 2:
48+
inputs[curr_node_name] = re.findall(r'(\d+):', extracted_data)
49+
elif version == 3:
50+
inputs[curr_node_name] = re.findall(r'(\d+)', extracted_data)
4451
inputs[curr_node_name] = list(map(int, inputs[curr_node_name]))
4552
inputs_names.append(curr_node_name)
4653
return inputs, inputs_names
4754

4855

49-
def parse_node_info(node):
50-
inputs, inputs_names = _parse_node_inputs(node)
56+
def parse_node_info(node, version=2):
57+
inputs, inputs_names = _parse_node_inputs(node, version=version)
5158
node = str(node)
5259
node_name = re.search(r'%(.*) : ', node).group(1)
53-
out_size = re.search(r'Float\(\d+:(\d+),', node).group(1)
60+
if version == 2:
61+
out_size = re.search(r'Float\(\d+:(\d+),', node).group(1)
62+
elif version == 3:
63+
out_size = re.search(r'strides=\[(\d+),', node).group(1)
5464

5565
return node_name, inputs, inputs_names, int(out_size)
5666

@@ -75,8 +85,8 @@ def _count_convNd(node, version=2):
7585
out_ops = reduce(lambda x, y: x * y, out)
7686
bias_ops = 1 if string_to_shape(list(node.inputs())[0], True) is not None else 0
7787
f_in = inp[1]
78-
elif version == 2:
79-
node_name, inputs, inputs_names, out_ops = parse_node_info(node)
88+
elif version in [2, 3]:
89+
node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
8090
f_in = inputs[inputs_names[0]][1]
8191
bias_ops = 1 if len(inputs_names) == 3 else 0
8292

@@ -103,8 +113,8 @@ def _count_relu(node, version=2):
103113
"""
104114
if version == 1:
105115
inp = string_to_shape(list(node.inputs())[0])
106-
elif version == 2:
107-
node_name, inputs, inputs_names, out_ops = parse_node_info(node)
116+
elif version in [2, 3]:
117+
node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
108118
inp = inputs[inputs_names[0]]
109119
total_ops = 2 * reduce(lambda x, y: x * y, inp) # also count the comparison
110120
return total_ops
@@ -121,8 +131,8 @@ def _count_avgpool(node, version=2):
121131
if version == 1:
122132
out = string_to_shape(list(node.outputs())[0])
123133
out_ops = reduce(lambda x, y: x * y, out)
124-
elif version == 2:
125-
node_name, inputs, inputs_names, out_ops = parse_node_info(node)
134+
elif version in [2, 3]:
135+
node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
126136

127137
ops_add = reduce(lambda x, y: x * y, node['kernel_shape']) - 1
128138
ops_div = 1
@@ -142,8 +152,8 @@ def _count_globalavgpool(node, version=2):
142152
inp = string_to_shape(list(node.inputs())[0])
143153
out = string_to_shape(list(node.outputs())[0])
144154
out_ops = reduce(lambda x, y: x * y, out)
145-
elif version == 2:
146-
node_name, inputs, inputs_names, out_ops = parse_node_info(node)
155+
elif version in [2, 3]:
156+
node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
147157
inp = inputs[inputs_names[0]]
148158

149159
ops_add = reduce(lambda x, y: x * y, [inp[-2], inp[-1]]) - 1
@@ -163,8 +173,8 @@ def _count_maxpool(node, version=2):
163173
if version == 1:
164174
out = string_to_shape(list(node.outputs())[0])
165175
out_ops = reduce(lambda x, y: x * y, out)
166-
elif version == 2:
167-
node_name, inputs, inputs_names, out_ops = parse_node_info(node)
176+
elif version in [2, 3]:
177+
node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
168178

169179
ops_add = reduce(lambda x, y: x * y, node['kernel_shape']) - 1
170180
total_ops = ops_add * out_ops
@@ -184,8 +194,8 @@ def _count_bn(node, version=2):
184194
inp = string_to_shape(list(node.inputs())[1])
185195
else:
186196
inp = string_to_shape(list(node.inputs())[0])
187-
elif version == 2:
188-
node_name, inputs, inputs_names, out_ops = parse_node_info(node)
197+
elif version in [2, 3]:
198+
node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
189199
inp = inputs[inputs_names[0]]
190200

191201
total_ops = reduce(lambda x, y: x * y, inp) * 2
@@ -205,8 +215,8 @@ def _count_linear(node, version=2):
205215
out = string_to_shape(list(node.outputs())[0])
206216
f_in = inp[1]
207217
out_ops = reduce(lambda x, y: x * y, out)
208-
elif version == 2:
209-
node_name, inputs, inputs_names, out_ops = parse_node_info(node)
218+
elif version in [2, 3]:
219+
node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
210220
inp = inputs[inputs_names[0]]
211221
f_in = inputs[inputs_names[0]][1]
212222

@@ -224,8 +234,8 @@ def _count_add_mul(node, version=2):
224234
"""
225235
if version == 1:
226236
inp = string_to_shape(list(node.inputs())[0])
227-
elif version == 2:
228-
node_name, inputs, inputs_names, out_ops = parse_node_info(node)
237+
elif version in [2, 3]:
238+
node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
229239
inp = inputs[inputs_names[0]]
230240
return reduce(lambda x, y: x * y, inp)
231241

@@ -296,8 +306,11 @@ def count_ops(model, input, custom_ops={}, ignore_layers=[], print_readable=True
296306
trace, _ = torch.jit._get_trace_graph(model, input, *args)
297307
graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
298308

299-
if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'):
309+
if LooseVersion(torch.__version__) >= LooseVersion('1.6.0') and \
310+
LooseVersion(torch.__version__) < LooseVersion('1.8.0'):
300311
version = 2
312+
else:
313+
version = 3
301314
else:
302315
# PyTorch 1.3 and bellow
303316
trace, _ = torch.jit.get_trace_graph(model, input, *args)

0 commit comments

Comments
 (0)