
Commit ee54eac

Reimplement img2seq & seq2img in PRX to enable ONNX build without Col2Im (incompatible with TensorRT).
Parent commit: a1f36ee

1 file changed: +30 -4 lines changed


src/diffusers/models/transformers/transformer_prx.py

Lines changed: 30 additions & 4 deletions
@@ -532,7 +532,19 @@ def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
         Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
         // patch_size)` is the number of patches.
     """
-    return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
+    b, c, h, w = img.shape
+    p = patch_size
+
+    # Reshape to (B, C, H//p, p, W//p, p), separating grid and patch dimensions
+    img = img.reshape(b, c, h // p, p, w // p, p)
+
+    # Permute to (B, H//p, W//p, C, p, p) using einsum
+    # n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width
+    img = torch.einsum("nchpwq->nhwcpq", img)
+
+    # Flatten to (B, L, C * p * p)
+    img = img.reshape(b, -1, c * p * p)
+    return img


 def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
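
The reshape/einsum patchify emits patches in the same channel-major order as the Unfold call it replaces, so the two formulations return identical tensors. Below is a minimal, self-contained sketch of that equivalence; the helper names img2seq_reshape and img2seq_unfold are illustrative and not part of this commit.

import torch
import torch.nn.functional as F


def img2seq_reshape(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Mirrors the new reshape/einsum implementation in the hunk above.
    b, c, h, w = img.shape
    p = patch_size
    img = img.reshape(b, c, h // p, p, w // p, p)
    img = torch.einsum("nchpwq->nhwcpq", img)
    return img.reshape(b, -1, c * p * p)


def img2seq_unfold(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Mirrors the removed Unfold-based line.
    return F.unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)


x = torch.randn(2, 16, 32, 32)
# Both paths are pure reindexing of the same elements, so the match is exact.
assert torch.equal(img2seq_reshape(x, patch_size=2), img2seq_unfold(x, patch_size=2))
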
@@ -554,12 +566,26 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
         Reconstructed image tensor of shape `(B, C, H, W)`.
     """
     if isinstance(shape, tuple):
-        shape = shape[-2:]
+        h, w = shape[-2:]
     elif isinstance(shape, torch.Tensor):
-        shape = (int(shape[0]), int(shape[1]))
+        h, w = (int(shape[0]), int(shape[1]))
     else:
         raise NotImplementedError(f"shape type {type(shape)} not supported")
-    return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
+
+    b, l, d = seq.shape
+    p = patch_size
+    c = d // (p * p)
+
+    # Reshape back to grid structure: (B, H//p, W//p, C, p, p)
+    seq = seq.reshape(b, h // p, w // p, c, p, p)
+
+    # Permute back to image layout: (B, C, H//p, p, W//p, p)
+    # n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width
+    seq = torch.einsum("nhwcpq->nchpwq", seq)
+
+    # Final reshape to (B, C, H, W)
+    seq = seq.reshape(b, c, h, w)
+    return seq


 class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
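
Since the new seq2img uses only reshape, permute (via einsum), and view-style ops, the exported ONNX graph for the unpatchify step should no longer contain the Col2Im node that the commit message flags as incompatible with TensorRT. A rough way to check this on a toy export is sketched below; the Unpatchify wrapper, its fixed sizes, and the output file name are illustrative assumptions, not part of the diffusers API.

import onnx
import torch


class Unpatchify(torch.nn.Module):
    # Illustrative wrapper around the new seq2img logic with a fixed output size.
    def __init__(self, patch_size: int, height: int, width: int):
        super().__init__()
        self.p, self.h, self.w = patch_size, height, width

    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        b, _, d = seq.shape
        p = self.p
        c = d // (p * p)
        seq = seq.reshape(b, self.h // p, self.w // p, c, p, p)
        seq = torch.einsum("nhwcpq->nchpwq", seq)
        return seq.reshape(b, c, self.h, self.w)


module = Unpatchify(patch_size=2, height=64, width=64)
dummy = torch.randn(1, (64 // 2) * (64 // 2), 16 * 2 * 2)
torch.onnx.export(module, (dummy,), "unpatchify.onnx", opset_version=18)

# Verify that no Col2Im node appears in the exported graph.
graph = onnx.load("unpatchify.onnx").graph
assert not any(node.op_type == "Col2Im" for node in graph.node)
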
