@@ -532,7 +532,19 @@ def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
532532 Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
533533 // patch_size)` is the number of patches.
534534 """
535- return unfold (img , kernel_size = patch_size , stride = patch_size ).transpose (1 , 2 )
535+ b , c , h , w = img .shape
536+ p = patch_size
537+
538+ # Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions
539+ img = img .reshape (b , c , h // p , p , w // p , p )
540+
541+ # Permute to (B, H//p, W//p, C, p, p) using einsum
542+ # n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width
543+ img = torch .einsum ("nchpwq->nhwcpq" , img )
544+
545+ # Flatten to (B, L, C * p * p)
546+ img = img .reshape (b , - 1 , c * p * p )
547+ return img
536548
537549
538550def seq2img (seq : torch .Tensor , patch_size : int , shape : torch .Tensor ) -> torch .Tensor :
@@ -554,12 +566,26 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te
554566 Reconstructed image tensor of shape `(B, C, H, W)`.
555567 """
556568 if isinstance (shape , tuple ):
557- shape = shape [- 2 :]
569+ h , w = shape [- 2 :]
558570 elif isinstance (shape , torch .Tensor ):
559- shape = (int (shape [0 ]), int (shape [1 ]))
571+ h , w = (int (shape [0 ]), int (shape [1 ]))
560572 else :
561573 raise NotImplementedError (f"shape type { type (shape )} not supported" )
562- return fold (seq .transpose (1 , 2 ), shape , kernel_size = patch_size , stride = patch_size )
574+
575+ b , l , d = seq .shape
576+ p = patch_size
577+ c = d // (p * p )
578+
579+ # Reshape back to grid structure: (B, H//p, W//p, C, p, p)
580+ seq = seq .reshape (b , h // p , w // p , c , p , p )
581+
582+ # Permute back to image layout: (B, C, H//p, p, W//p, p)
583+ # n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width
584+ seq = torch .einsum ("nhwcpq->nchpwq" , seq )
585+
586+ # Final reshape to (B, C, H, W)
587+ seq = seq .reshape (b , c , h , w )
588+ return seq
563589
564590
565591class PRXTransformer2DModel (ModelMixin , ConfigMixin , AttentionMixin ):
0 commit comments