Commit 01e42ea

Fix vl on 310p device (#230)
1 parent 9290484 commit 01e42ea

1 file changed: 10 additions, 0 deletions

dlinfer/vendor/ascend/torch_npu_ops.py

Lines changed: 10 additions & 0 deletions
@@ -527,6 +527,16 @@ def linear(
             bias=bias,
         )
     else:
+        # on 310p, the weight is transposed to nz format in llm part on graph mode,
+        # but in vl part, eager mode is used.
+        # we need to reshape it back to nd.
+        if (
+            len(weight.shape) == 4
+            and weight.shape[0] == 1
+            and weight.shape[1] * weight.shape[3] == x.shape[-1]
+        ):
+            weight = weight.permute(0, 2, 1, 3)
+            weight = weight.reshape(weight.shape[1], -1)
         out = torch.nn.functional.linear(x, weight, bias)
     return out
 
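
The shape logic in the added lines can be checked on CPU with a small round trip. The sketch below is illustrative only: `pack_blocked` and `restore_nd` are hypothetical helpers that are not part of dlinfer, the block size of 16 is an assumption, and the packed tensor is merely a 4-D layout that satisfies the commit's shape check, not a claim about how torch_npu actually produces the nz format. `restore_nd` repeats the same condition and the same `permute`/`reshape` calls as the diff.

```python
import torch


def pack_blocked(weight_2d: torch.Tensor, block: int = 16) -> torch.Tensor:
    """Hypothetical packer: (out, in) -> (1, in // block, out, block).

    Assumption for illustration only; not the real nz format cast.
    """
    out_features, in_features = weight_2d.shape
    assert in_features % block == 0, "in_features must be a multiple of block"
    w = weight_2d.reshape(1, out_features, in_features // block, block)
    return w.permute(0, 2, 1, 3).contiguous()


def restore_nd(weight: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Same check and permute/reshape as the lines added in the commit."""
    if (
        len(weight.shape) == 4
        and weight.shape[0] == 1
        and weight.shape[1] * weight.shape[3] == x.shape[-1]
    ):
        # (1, in // block, out, block) -> (1, out, in // block, block)
        weight = weight.permute(0, 2, 1, 3)
        # Collapse the last two dims back into in_features: (out, in)
        weight = weight.reshape(weight.shape[1], -1)
    return weight


torch.manual_seed(0)
x = torch.randn(2, 64)   # batch of 2, in_features = 64
w = torch.randn(32, 64)  # ordinary 2-D weight: (out_features, in_features)

packed = pack_blocked(w)           # 4-D layout that triggers the check
restored = restore_nd(packed, x)   # back to (32, 64)

out_ref = torch.nn.functional.linear(x, w)
out = torch.nn.functional.linear(x, restored)
```

A weight that is already 2-D fails the `len(weight.shape) == 4` test and passes through unchanged, so the added branch should not affect the ordinary eager path.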
