 from sparseml.pytorch.optim import ScheduledModifierManager
 from sparseml.pytorch.utils import SparsificationGroupLogger
 from sparseml.pytorch.utils import GradSampler
+from sparseml.pytorch.sparsification.quantization import QuantizationModifier
 import torchvision.transforms.functional as F

 from utils.torch_utils import is_parallel
@@ -51,7 +52,16 @@ def check_download_sparsezoo_weights(path):


 class SparseMLWrapper(object):
-    def __init__(self, model, checkpoint_recipe, train_recipe, steps_per_epoch=-1, one_shot=False):
+    def __init__(
+        self,
+        model,
+        checkpoint_recipe,
+        train_recipe,
+        train_mode=False,
+        epoch=-1,
+        steps_per_epoch=-1,
+        one_shot=False,
+    ):
         self.enabled = bool(train_recipe)
         self.model = model.module if is_parallel(model) else model
         self.checkpoint_manager = ScheduledModifierManager.from_yaml(checkpoint_recipe) if checkpoint_recipe else None
@@ -62,21 +72,47 @@ def __init__(self, model, checkpoint_recipe, train_recipe, steps_per_epoch=-1, o
         self.one_shot = one_shot
         self.train_recipe = train_recipe

-        if self.one_shot:
-            self._apply_one_shot()
-
-    def state_dict(self):
-        manager = (ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager)
-                   if self.checkpoint_manager and self.enabled else self.manager)
+        self.apply_checkpoint_structure(train_mode, epoch, one_shot)

-        return {
-            'recipe': str(manager) if self.enabled else None,
-        }
+    def state_dict(self, final_epoch):
+        if self.enabled or self.checkpoint_manager:
+            compose_recipes = self.checkpoint_manager and self.enabled and final_epoch
+            return {
+                'checkpoint_recipe': str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager))
+                if compose_recipes else str(self.checkpoint_manager),
+                'train_recipe': str(self.manager) if not final_epoch else None
+            }
+        else:
+            return {
+                'checkpoint_recipe': None,
+                'train_recipe': None
+            }

-    def apply_checkpoint_structure(self):
+    def apply_checkpoint_structure(self, train_mode, epoch, one_shot=False):
         if self.checkpoint_manager:
+            # if checkpoint recipe has a QAT modifier and this is a transfer learning
+            # run then remove the QAT modifier from the manager
+            if train_mode:
+                qat_idx = next((
+                    idx for idx, mod in enumerate(self.checkpoint_manager.modifiers)
+                    if isinstance(mod, QuantizationModifier)), -1
+                )
+                if qat_idx >= 0:
+                    _ = self.checkpoint_manager.modifiers.pop(qat_idx)
+
             self.checkpoint_manager.apply_structure(self.model, math.inf)

+        if train_mode and epoch > 0 and self.enabled:
+            self.manager.apply_structure(self.model, epoch)
+        elif one_shot:
+            if self.enabled:
+                self.manager.apply(self.model)
+                _LOGGER.info(f"Applied recipe {self.train_recipe} in one-shot manner")
+            else:
+                _LOGGER.info(f"Training recipe for one-shot application not recognized by the manager. Got recipe: "
+                             f"{self.train_recipe}"
+                             )
+
     def initialize(
         self,
         start_epoch,
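A minimal sketch of how the reworked constructor and `state_dict(final_epoch)` might be driven from a training script; the surrounding names (`ckpt`, `opt.recipe`, `start_epoch`, `epoch`, `epochs`, the save path) are illustrative assumptions, not part of this diff:

```python
import torch

# Assumed inputs: `model` is the network being trained, `ckpt` is the loaded
# checkpoint dict (which may carry a 'checkpoint_recipe'), and `opt.recipe`
# is the recipe path passed on the command line.
wrapper = SparseMLWrapper(
    model,
    checkpoint_recipe=ckpt.get('checkpoint_recipe'),
    train_recipe=opt.recipe,
    train_mode=True,
    epoch=start_epoch,
)

# ... training loop runs here ...

# On save, the recipes are serialized next to the weights. Only on the final
# epoch are the checkpoint and training recipes composed into one staged
# 'checkpoint_recipe'; before that, the in-progress recipe stays under
# 'train_recipe' so a resumed run can keep applying it.
ckpt.update(wrapper.state_dict(final_epoch=(epoch + 1 == epochs)))
torch.save(ckpt, 'last.pt')  # hypothetical save path
```

With `one_shot=True`, the constructor now applies the training recipe immediately through `apply_checkpoint_structure`, which replaces the removed `_apply_one_shot` helper.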
@@ -144,9 +180,9 @@ def check_lr_override(self, scheduler, rank):
     def check_epoch_override(self, epochs, rank):
         # Override num epochs if recipe explicitly modifies epoch range
         if self.enabled and self.manager.epoch_modifiers and self.manager.max_epochs:
+            epochs = self.manager.max_epochs or epochs  # override num_epochs
             if rank in [0, -1]:
                 self.logger.info(f'Overriding number of epochs from SparseML manager to {epochs}')
-            epochs = self.manager.max_epochs + self.start_epoch or epochs  # override num_epochs

         return epochs

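For illustration, a rough sketch of the corrected override, assuming a recipe whose epoch modifiers end at epoch 240 and a command-line request for 300 epochs; the recipe path is a placeholder:

```python
from sparseml.pytorch.optim import ScheduledModifierManager

manager = ScheduledModifierManager.from_yaml('recipe.yaml')  # assumed recipe path
epochs = 300                                                 # e.g. value from --epochs
if manager.epoch_modifiers and manager.max_epochs:
    # new behavior: the recipe's max_epochs wins outright (240 here),
    # replacing the old max_epochs + start_epoch arithmetic
    epochs = manager.max_epochs or epochs
```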
@@ -195,15 +231,6 @@ def dataloader():
                         imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
                 yield [imgs], {}, targets
         return dataloader
-
-    def _apply_one_shot(self):
-        if self.manager is not None:
-            self.manager.apply(self.model)
-            _LOGGER.info(f"Applied recipe {self.train_recipe} in one-shot manner")
-        else:
-            _LOGGER.info(f"Training recipe for one-shot application not recognized by the manager. Got recipe: "
-                         f"{self.train_recipe}"
-                         )

     def save_sample_inputs_outputs(
         self,
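And a sketch of the consuming side, assuming a checkpoint saved in the new format: the serialized 'checkpoint_recipe' rebuilds the pruned/quantized graph before the stored weights are loaded. The checkpoint path and the 'model' key holding a state dict are assumptions about the surrounding checkpoint layout, not shown in this diff:

```python
import math
import torch
from sparseml.pytorch.optim import ScheduledModifierManager

ckpt = torch.load('last.pt', map_location='cpu')  # hypothetical checkpoint path
if ckpt.get('checkpoint_recipe'):
    # re-apply the recipe's structural changes (pruning masks, QAT wrappers)
    # so the module graph matches the saved tensors
    manager = ScheduledModifierManager.from_yaml(ckpt['checkpoint_recipe'])
    manager.apply_structure(model, math.inf)      # model: the freshly built network
model.load_state_dict(ckpt['model'])              # assumed key holding the weights
```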