Skip to content

Commit 72cad39

Browse files
authored
Squeezenet reshape outputs fix (#10222)
@AyushExel Signed-off-by: Glenn Jocher <[email protected]> Signed-off-by: Glenn Jocher <[email protected]>
1 parent 40bb803 commit 72cad39

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

utils/torch_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def reshape_classifier_output(model, n=1000):
8282
elif nn.Conv2d in types:
8383
i = types.index(nn.Conv2d) # nn.Conv2d index
8484
if m[i].out_channels != n:
85-
m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias)
85+
m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
8686

8787

8888
@contextmanager

0 commit comments

Comments
 (0)