diff --git a/models/public/regnetx-3.2gf/model.py b/models/public/regnetx-3.2gf/model.py index a2f446a2000..d43dfa8395b 100644 --- a/models/public/regnetx-3.2gf/model.py +++ b/models/public/regnetx-3.2gf/model.py @@ -12,11 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pycls.core.checkpoint +import torch +import pycls.core.config import pycls.models.model_zoo +from pycls.core.checkpoint import unwrap_model def regnet(config_path, weights_path): pycls.core.config.cfg.merge_from_file(config_path) model = pycls.models.model_zoo.RegNet() - pycls.core.checkpoint.load_checkpoint(weights_path, model) + checkpoint = torch.load(weights_path, map_location="cpu", weights_only=False) + test_err = checkpoint.get("test_err", 100) + ema_err = checkpoint.get("ema_err", 100) + ema_state = "ema_state" if "ema_state" in checkpoint else "model_state" + best_state = "model_state" if test_err <= ema_err else ema_state + unwrap_model(model).load_state_dict(checkpoint[best_state]) return model diff --git a/tools/model_tools/src/omz_tools/internal_scripts/pytorch_to_onnx.py b/tools/model_tools/src/omz_tools/internal_scripts/pytorch_to_onnx.py index e0625d42f48..775c48310b9 100644 --- a/tools/model_tools/src/omz_tools/internal_scripts/pytorch_to_onnx.py +++ b/tools/model_tools/src/omz_tools/internal_scripts/pytorch_to_onnx.py @@ -144,7 +144,7 @@ def load_model(model_name, weights, model_paths, module_name, model_params): try: if weights: - model.load_state_dict(torch.load(weights, map_location='cpu')) + model.load_state_dict(torch.load(weights, map_location='cpu', weights_only=False)) except RuntimeError as err: print('ERROR: Weights from {} cannot be loaded for model {}! Check matching between model and weights'.format( weights, model_name))