This repository was archived by the owner on Jun 4, 2025. It is now read-only.

Commit d5807bf — Fixes for quant transfer learn (#76) (#77)

1 parent a065b8d · commit d5807bf

File tree: export.py · train.py · utils/sparse.py

3 files changed: 62 additions, 29 deletions

export.py — 11 additions, 6 deletions

@@ -432,7 +432,7 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
     except Exception as e:
         LOGGER.info(f'\n{prefix} export failure: {e}')

-def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
+def create_checkpoint(epoch, final_epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
     pickle = not sparseml_wrapper.qat_active(math.inf if epoch < 0 else epoch)  # qat does not support pickled exports
     ckpt_model = deepcopy(model.module if is_parallel(model) else model).float()
     yaml = ckpt_model.yaml
@@ -445,7 +445,7 @@ def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
             'yaml': yaml,
             'hyp': model.hyp,
             **ema.state_dict(pickle),
-            **sparseml_wrapper.state_dict(),
+            **sparseml_wrapper.state_dict(final_epoch),
             **kwargs}

 def load_checkpoint(
@@ -469,6 +469,10 @@ def load_checkpoint(
     weights = attempt_download(weights) or check_download_sparsezoo_weights(weights)
     ckpt = torch.load(weights[0] if isinstance(weights, list) or isinstance(weights, tuple)
                       else weights, map_location="cpu")  # load checkpoint
+
+    # temporary fix until SparseML and ZooModels are updated
+    ckpt['checkpoint_recipe'] = ckpt.get('recipe') or ckpt.get('checkpoint_recipe')
+
     pickled = isinstance(ckpt['model'], nn.Module)
     train_type = type_ == 'train'
     ensemble_type = type_ == 'ensemble'
@@ -500,21 +504,22 @@ def load_checkpoint(
     # load sparseml recipe for applying pruning and quantization
     checkpoint_recipe = train_recipe = None
     if resume:
-        train_recipe = ckpt.get('recipe')
-    elif recipe or ckpt.get('recipe'):
-        train_recipe, checkpoint_recipe = recipe, ckpt.get('recipe')
+        train_recipe, checkpoint_recipe = ckpt.get('train_recipe'), ckpt.get('checkpoint_recipe')
+    elif recipe or ckpt.get('checkpoint_recipe'):
+        train_recipe, checkpoint_recipe = recipe, ckpt.get('checkpoint_recipe')

     sparseml_wrapper = SparseMLWrapper(
         model.model if val_type else model,
         checkpoint_recipe,
         train_recipe,
+        train_mode=train_type,
+        epoch=ckpt['epoch'],
         one_shot=one_shot,
         steps_per_epoch=max_train_steps,
     )
     exclude_anchors = not ensemble_type and (cfg or hyp.get('anchors')) and not resume
     loaded = False

-    sparseml_wrapper.apply_checkpoint_structure()
     if train_type:
         # initialize the recipe for training and restore the weights before if no quantized weights
         quantized_state_dict = any([name.endswith('.zero_point') for name in state_dict.keys()])
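
Taken together, these export.py changes replace the single 'recipe' key that older checkpoints carried with separate 'checkpoint_recipe' and 'train_recipe' entries, plus a temporary shim that maps the legacy key forward. A minimal sketch of inspecting a checkpoint under the new layout, assuming a local last.pt produced by this branch (the path is hypothetical):

# Minimal sketch of the new checkpoint layout; assumes a local 'last.pt'
# produced by this branch (the path is hypothetical).
import torch

ckpt = torch.load("last.pt", map_location="cpu")

# Legacy checkpoints stored a single 'recipe' key; the shim added to
# load_checkpoint maps it onto the new 'checkpoint_recipe' key.
ckpt['checkpoint_recipe'] = ckpt.get('recipe') or ckpt.get('checkpoint_recipe')

# 'checkpoint_recipe' holds stages already applied to the weights;
# 'train_recipe' is populated only on intermediate (resumable) epochs.
print('checkpoint recipe present:', ckpt['checkpoint_recipe'] is not None)
print('train recipe present:', ckpt.get('train_recipe') is not None)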

train.py — 3 additions, 2 deletions

@@ -141,6 +141,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         model,
         None,
         opt.recipe,
+        train_mode=True,
         steps_per_epoch=opt.max_train_steps,
         one_shot=opt.one_shot,
     )
@@ -314,7 +315,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             "date": datetime.now().isoformat(),
         }
         ckpt = create_checkpoint(
-            -1, model, optimizer, ema, sparseml_wrapper, **ckpt_extras
+            -1, True, model, optimizer, ema, sparseml_wrapper, **ckpt_extras
         )
         one_shot_checkpoint_name = w / "checkpoint-one-shot.pt"
         torch.save(ckpt, one_shot_checkpoint_name)
@@ -486,7 +487,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 'best_fitness': best_fitness,
                 'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
                 'date': datetime.now().isoformat()}
-            ckpt = create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **ckpt_extras)
+            ckpt = create_checkpoint(epoch, final_epoch, model, optimizer, ema, sparseml_wrapper, **ckpt_extras)

             # Save last, best and delete
             torch.save(ckpt, last)
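
train.py now threads a final_epoch flag into every create_checkpoint call: the one-shot path passes True outright, and the epoch loop passes its own final-epoch test. A self-contained sketch of that flag, assuming final_epoch is derived as in upstream YOLOv5 (epoch + 1 == epochs) and using a hypothetical epoch count:

# Self-contained sketch of the final_epoch flag train.py threads into
# create_checkpoint; 'epochs' is a hypothetical value, and the real loop
# builds full checkpoints rather than printing.
epochs = 3
for epoch in range(epochs):
    final_epoch = epoch + 1 == epochs
    # intermediate epoch -> recipes kept separate so training can resume;
    # final epoch (and the one-shot path, which passes True) -> recipes
    # composed into a single checkpoint recipe for deployment.
    kind = "final" if final_epoch else "intermediate"
    print(f"epoch {epoch}: saving {kind} checkpoint")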

utils/sparse.py — 48 additions, 21 deletions

@@ -9,6 +9,7 @@
 from sparseml.pytorch.optim import ScheduledModifierManager
 from sparseml.pytorch.utils import SparsificationGroupLogger
 from sparseml.pytorch.utils import GradSampler
+from sparseml.pytorch.sparsification.quantization import QuantizationModifier
 import torchvision.transforms.functional as F

 from utils.torch_utils import is_parallel
@@ -51,7 +52,16 @@ def check_download_sparsezoo_weights(path):


 class SparseMLWrapper(object):
-    def __init__(self, model, checkpoint_recipe, train_recipe, steps_per_epoch=-1, one_shot=False):
+    def __init__(
+        self,
+        model,
+        checkpoint_recipe,
+        train_recipe,
+        train_mode=False,
+        epoch=-1,
+        steps_per_epoch=-1,
+        one_shot=False,
+    ):
         self.enabled = bool(train_recipe)
         self.model = model.module if is_parallel(model) else model
         self.checkpoint_manager = ScheduledModifierManager.from_yaml(checkpoint_recipe) if checkpoint_recipe else None
@@ -62,21 +72,47 @@ def __init__(self, model, checkpoint_recipe, train_recipe, steps_per_epoch=-1, one_shot=False):
         self.one_shot = one_shot
         self.train_recipe = train_recipe

-        if self.one_shot:
-            self._apply_one_shot()
-
-    def state_dict(self):
-        manager = (ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager)
-                   if self.checkpoint_manager and self.enabled else self.manager)
+        self.apply_checkpoint_structure(train_mode, epoch, one_shot)

-        return {
-            'recipe': str(manager) if self.enabled else None,
-        }
+    def state_dict(self, final_epoch):
+        if self.enabled or self.checkpoint_manager:
+            compose_recipes = self.checkpoint_manager and self.enabled and final_epoch
+            return {
+                'checkpoint_recipe': str(ScheduledModifierManager.compose_staged(self.checkpoint_manager, self.manager))
+                    if compose_recipes else str(self.checkpoint_manager),
+                'train_recipe': str(self.manager) if not final_epoch else None
+            }
+        else:
+            return {
+                'checkpoint_recipe': None,
+                'train_recipe': None
+            }

-    def apply_checkpoint_structure(self):
+    def apply_checkpoint_structure(self, train_mode, epoch, one_shot=False):
         if self.checkpoint_manager:
+            # if checkpoint recipe has a QAT modifier and this is a transfer learning
+            # run then remove the QAT modifier from the manager
+            if train_mode:
+                qat_idx = next((
+                    idx for idx, mod in enumerate(self.checkpoint_manager.modifiers)
+                    if isinstance(mod, QuantizationModifier)), -1
+                )
+                if qat_idx >= 0:
+                    _ = self.checkpoint_manager.modifiers.pop(qat_idx)
+
             self.checkpoint_manager.apply_structure(self.model, math.inf)

+        if train_mode and epoch > 0 and self.enabled:
+            self.manager.apply_structure(self.model, epoch)
+        elif one_shot:
+            if self.enabled:
+                self.manager.apply(self.model)
+                _LOGGER.info(f"Applied recipe {self.train_recipe} in one-shot manner")
+            else:
+                _LOGGER.info(f"Training recipe for one-shot application not recognized by the manager. Got recipe: "
+                             f"{self.train_recipe}"
+                )
+
     def initialize(
         self,
         start_epoch,
@@ -144,9 +180,9 @@ def check_lr_override(self, scheduler, rank):
     def check_epoch_override(self, epochs, rank):
         # Override num epochs if recipe explicitly modifies epoch range
         if self.enabled and self.manager.epoch_modifiers and self.manager.max_epochs:
+            epochs = self.manager.max_epochs or epochs  # override num_epochs
             if rank in [0,-1]:
                 self.logger.info(f'Overriding number of epochs from SparseML manager to {epochs}')
-            epochs = self.manager.max_epochs + self.start_epoch or epochs  # override num_epochs

         return epochs

@@ -195,15 +231,6 @@ def dataloader():
             imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
             yield [imgs], {}, targets
         return dataloader
-
-    def _apply_one_shot(self):
-        if self.manager is not None:
-            self.manager.apply(self.model)
-            _LOGGER.info(f"Applied recipe {self.train_recipe} in one-shot manner")
-        else:
-            _LOGGER.info(f"Training recipe for one-shot application not recognized by the manager. Got recipe: "
-                         f"{self.train_recipe}"
-            )

     def save_sample_inputs_outputs(
         self,
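
The heart of the fix is the new state_dict branching. Below is a hedged, dependency-free stub that mirrors its decision table so the rules can be checked without SparseML installed; plain strings stand in for the real ScheduledModifierManager objects, and composition is faked with '+':

# Dependency-free stub mirroring the decision table of the new
# SparseMLWrapper.state_dict; strings stand in for the real
# ScheduledModifierManager objects and compose_staged is faked with '+'.
def state_dict(enabled, checkpoint_manager, manager, final_epoch):
    if enabled or checkpoint_manager:
        compose_recipes = bool(checkpoint_manager and enabled and final_epoch)
        return {
            # like the original, this yields the string 'None' when no
            # checkpoint manager exists but a train recipe is active
            'checkpoint_recipe': f"{checkpoint_manager}+{manager}"
                if compose_recipes else str(checkpoint_manager),
            'train_recipe': str(manager) if not final_epoch else None,
        }
    return {'checkpoint_recipe': None, 'train_recipe': None}

# Mid-training save: recipes stay separate so the run can resume.
print(state_dict(True, "pruned_quant.yaml", "transfer.yaml", final_epoch=False))
# Final save: staged recipes are composed into one deployable recipe.
print(state_dict(True, "pruned_quant.yaml", "transfer.yaml", final_epoch=True))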
