diff --git a/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py b/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py index 193327379d..bab8b4349e 100644 --- a/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py +++ b/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py @@ -20,15 +20,20 @@ import urllib import re from collections import OrderedDict - import numpy as np from ..config import PathField, StringField, DictField, NumberField, ListField, BoolField +from ..utils import UnsupportedPackage from .launcher import Launcher +try: + import transformers +except ImportError as transformers_error: + transformers = UnsupportedPackage('transformers', transformers_error.msg) +CLASS_REGEX = r'(?:\w+)' MODULE_REGEX = r'(?:\w+)(?:(?:.\w+)*)' DEVICE_REGEX = r'(?Pcpu$|cuda)?' CHECKPOINT_URL_REGEX = r'^https?://.*\.pth(\?.*)?(#.*)?$' - +SCALAR_INPUTS = ('input_ids', 'input_mask', 'segment_ids', 'attention_mask', 'token_type_ids') class PyTorchLauncher(Launcher): __provider__ = 'pytorch' @@ -67,6 +72,9 @@ def parameters(cls): 'torch_compile_kwargs': DictField( key_type=str, validate_values=False, optional=True, default={}, description="dictionary of keyword arguments passed to torch.compile" + ), + 'transformers_class': StringField( + optional=True, regex=CLASS_REGEX, description='Transformers class name to load pre-trained module.' ) }) return parameters @@ -84,10 +92,11 @@ def __init__(self, config_entry: dict, *args, **kwargs): self.validate_config(config_entry) self.use_torch_compile = config_entry.get('use_torch_compile', False) self.compile_kwargs = config_entry.get('torch_compile_kwargs', {}) + self.tranformers_class = config_entry.get('transformers_class', None) backend = self.compile_kwargs.get('backend', None) if self.use_torch_compile and backend == 'openvino': try: - import openvino.torch # pylint: disable=C0415, W0611 + importlib.import_module('openvino.torch') # pylint: disable=C0415, W0611 except ImportError as import_error: raise ValueError("torch.compile is supported from OpenVINO 2023.1\n{}".format( import_error.msg)) from import_error @@ -95,18 +104,28 @@ def __init__(self, config_entry: dict, *args, **kwargs): module_kwargs = config_entry.get("module_kwargs", {}) self.device = self.get_value_from_config('device') self.cuda = 'cuda' in self.device + checkpoint = config_entry.get('checkpoint') if checkpoint is None: checkpoint = config_entry.get('checkpoint_url') - self.module = self.load_module( - config_entry['module'], - module_args, - module_kwargs, - checkpoint, - config_entry.get('state_key'), - config_entry.get("python_path"), - config_entry.get("init_method") - ) + + python_path = config_entry.get("python_path") + + if self.tranformers_class: + self.module = self.load_tranformers_module( + config_entry['module'], python_path + ) + else: + + self.module = self.load_module( + config_entry['module'], + module_args, + module_kwargs, + checkpoint, + config_entry.get('state_key'), + python_path, + config_entry.get("init_method") + ) self._batch = self.get_value_from_config('batch') # torch modules does not have input information @@ -142,7 +161,7 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st model_cls = module_parts[-1] model_path = ".".join(module_parts[:-1]) with append_to_path(python_path): - model_cls = importlib.import_module(model_path).__getattribute__(model_cls) + model_cls = getattr(importlib.import_module(model_path), model_cls) module = model_cls(*module_args, **module_kwargs) if init_method is not None: if hasattr(model_cls, init_method): @@ -153,7 +172,7 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st if checkpoint: if isinstance(checkpoint, str) and re.match(CHECKPOINT_URL_REGEX, checkpoint): - checkpoint = urllib.request.urlretrieve(checkpoint)[0] + checkpoint = urllib.request.urlretrieve(checkpoint)[0] # nosec B310 # disable urllib-urlopen check checkpoint = self._torch.load( checkpoint, map_location=None if self.cuda else self._torch.device('cpu') ) @@ -161,23 +180,45 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st if all(key.startswith('module.') for key in state): module = self._torch.nn.DataParallel(module) module.load_state_dict(state, strict=False) - module.to('cuda' if self.cuda else 'cpu') - module.eval() - if self.use_torch_compile: - if hasattr(model_cls, 'compile'): - module.compile() - module = self._torch.compile(module, **self.compile_kwargs) + return self.prepare_module(module, model_cls) + + def load_tranformers_module(self, pretrained_name, python_path): + with append_to_path(python_path): + if isinstance(transformers, UnsupportedPackage): + transformers.raise_error(self.__class__.__name__) + + model_class = getattr(transformers, self.tranformers_class) + pretrained_model = python_path if python_path else pretrained_name + module = model_class.from_pretrained(pretrained_model) + + return self.prepare_module(module, model_class) + + def prepare_module(self, module, model_class): + module.to('cuda' if self.cuda else 'cpu') + module.eval() + + if self.use_torch_compile: + if hasattr(model_class, 'compile'): + module.compile() + module = self._torch.compile(module, **self.compile_kwargs) + + return module - return module def _convert_to_tensor(self, value, precision): if isinstance(value, self._torch.Tensor): return value - return self._torch.from_numpy(value.astype(np.float32 if not precision else precision)).to(self.device) + if precision is None: + precision = np.float32 + + return self._torch.from_numpy(value.astype(precision)).to(self.device) def fit_to_input(self, data, layer_name, layout, precision, template=None): + if precision is None and layer_name in SCALAR_INPUTS: + precision = np.int64 + if layer_name == 'input' and isinstance(data[0], dict): tensor_dict = {} for key, val in data[0].items(): @@ -191,11 +232,14 @@ def fit_to_input(self, data, layer_name, layout, precision, template=None): return tensor_dict - if layout is not None: + data_shape = np.shape(data) + + if layout is not None and len(data_shape) == len(layout): data = np.transpose(data, layout) - tensor = self._torch.from_numpy(data.astype(np.float32 if not precision else precision)) - tensor = tensor.to(self.device) - return tensor + else: + data = np.array(data) + + return self._convert_to_tensor(data, precision) def _convert_to_numpy(self, input_dict): numpy_dict = {} @@ -206,18 +250,27 @@ def _convert_to_numpy(self, input_dict): numpy_dict[key] = value return numpy_dict + + def forward(self, outputs): + if hasattr(outputs, 'logits') and 'logits' in self.output_names: + return {'logits': outputs.logits} + if hasattr(outputs, 'last_hidden_state') and 'last_hidden_state' in self.output_names: + return {'last_hidden_state': outputs.last_hidden_state} + return list(outputs) + def predict(self, inputs, metadata=None, **kwargs): results = [] with self._torch.no_grad(): for batch_input in inputs: - if metadata[0].get('input_is_dict_type'): + if metadata[0].get('input_is_dict_type') or (isinstance(batch_input, dict) and 'input' in batch_input): outputs = self.module(batch_input['input']) else: - outputs = list(self.module(*batch_input.values())) + outputs = self.module(**batch_input) + for meta_ in metadata: meta_['input_shape'] = {key: list(data.shape) for key, data in batch_input.items()} - if metadata[0].get('output_is_dict_type'): + if metadata[0].get('output_is_dict_type') or isinstance(outputs, dict): result_dict = self._convert_to_numpy(outputs) else: result_dict = { diff --git a/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher_readme.md b/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher_readme.md index acfbc5dfaf..d32f5c560c 100644 --- a/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher_readme.md +++ b/tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher_readme.md @@ -17,6 +17,7 @@ For enabling PyTorch launcher you need to add `framework: pytorch` in launchers * `batch` - batch size for running model (Optional, default 1). * `use_torch_compile` - boolean, use torch.compile to optimize the module code (Optional, default `False`) * `torch_compile_kwargs` - dictionary of keyword arguments to pass to torch.compile (Optional, default `{}`) +* `transformers_class` - transformers class name to load pre-trained model with `module` name. (Optional). In turn if you model has several inputs you need to specify them in config, using specific parameter: `inputs`. Each input description should has following info: