import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms

# ---------- Core Swin Components ----------

class PatchEmbed(nn.Module):
    """Split the image into non-overlapping patches and linearly embed each one."""

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=48):
        super().__init__()
        # A strided convolution is equivalent to cutting patches and applying a shared linear projection.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)                  # (B, embed_dim, H/patch, W/patch)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        x = self.norm(x)
        return x

def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into (num_windows*B, window_size, window_size, C) windows."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """Merge windows back into a (B, H, W, C) feature map; inverse of window_partition."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
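
# Shape check for the configuration used below: a 32x32 CIFAR-10 image with
# patch_size=4 gives an 8x8 grid of tokens; window_size=4 then yields
# (8/4) * (8/4) = 4 windows per image, each containing 4*4 = 16 tokens.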

class WindowAttention(nn.Module):
    """Multi-head self-attention applied independently within each local window."""

    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B_, N, C = x.shape
        # Project to queries, keys, and values and split into heads:
        # shape after permute is (3, B_, num_heads, N, head_dim).
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)

        # Scaled dot-product attention within each window.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj(out)
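
# Note: unlike the reference Swin Transformer, this simplified attention module
# omits the learned relative position bias that is normally added to the
# attention logits; attention here depends only on token content.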

class SwinTransformerBlock(nn.Module):
    """Pre-norm transformer block with (optionally shifted) window attention and an MLP."""

    def __init__(self, dim, input_resolution, num_heads, window_size=4, shift_size=0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = window_size
        self.shift_size = shift_size

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size, num_heads)
        self.norm2 = nn.LayerNorm(dim)

        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        shortcut = x
        x = x.view(B, H, W, C)

        # Cyclically shift the feature map so that the next window partition
        # straddles the previous block's window boundaries (shifted windows).
        if self.shift_size > 0:
            shifted_x = torch.roll(x, (-self.shift_size, -self.shift_size), (1, 2))
        else:
            shifted_x = x

        # Partition into windows and run self-attention inside each window.
        windows = window_partition(shifted_x, self.window_size)
        windows = windows.view(-1, self.window_size * self.window_size, C)

        attn_windows = self.attn(self.norm1(windows))
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # Undo the cyclic shift.
        if self.shift_size > 0:
            x = torch.roll(shifted_x, (self.shift_size, self.shift_size), (1, 2))
        else:
            x = shifted_x

        # Residual connections around attention and around the MLP.
        x = shortcut + x.view(B, H * W, C)
        x = x + self.mlp(self.norm2(x))
        return x
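
# Note: the reference Swin implementation also applies an attention mask after the
# cyclic shift so that tokens rolled across the image border cannot attend to each
# other; this compact example skips that mask for simplicity.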

# ---------- Final Network ----------

class SwinTinyNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # 32x32 input with 4x4 patches -> an 8x8 grid of 48-dim tokens.
        self.patch_embed = PatchEmbed(img_size=32, patch_size=4, in_chans=3, embed_dim=48)
        # One regular-window block followed by one shifted-window block.
        self.block1 = SwinTransformerBlock(dim=48, input_resolution=(8, 8), num_heads=3, window_size=4, shift_size=0)
        self.block2 = SwinTransformerBlock(dim=48, input_resolution=(8, 8), num_heads=3, window_size=4, shift_size=2)
        self.norm = nn.LayerNorm(48)
        self.fc = nn.Linear(48, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.norm(x)
        x = x.mean(dim=1)  # global average pooling over the token dimension
        x = self.fc(x)
        return F.log_softmax(x, dim=1)
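
# Shape flow: (B, 3, 32, 32) image -> (B, 64, 48) patch tokens -> two Swin blocks
# -> mean over the 64 tokens -> (B, 48) -> linear head -> (B, 10) log-probabilities.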

# ---------- Training and Testing ----------

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break

def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            if args.dry_run:
                break

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

# ---------- Main ----------

def main():
    parser = argparse.ArgumentParser(description='Swin Transformer CIFAR10 Example')
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--test-batch-size', type=int, default=1000)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--gamma', type=float, default=0.7)
    parser.add_argument('--dry-run', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--log-interval', type=int, default=10)
    parser.add_argument('--save-model', action='store_true')
    args = parser.parse_args()

    # Use an available accelerator (CUDA, MPS, etc.) if present, otherwise fall back to CPU.
    use_accel = torch.accelerator.is_available()
    device = torch.accelerator.current_accelerator() if use_accel else torch.device("cpu")
    print(f"Using device: {device}")

    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../data', train=True, download=True, transform=transform),
        batch_size=args.batch_size, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../data', train=False, transform=transform),
        batch_size=args.test_batch_size, shuffle=False)

    model = SwinTinyNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=3, gamma=args.gamma)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "swin_cifar10.pt")

if __name__ == '__main__':
    main()
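
# Example invocation (the script name is illustrative, not given in the source;
# CIFAR-10 is downloaded to ../data on first use):
#   python swin_cifar10.py --epochs 10 --save-model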