111 changes: 82 additions & 29 deletions tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py
@@ -20,15 +20,20 @@
import urllib
import re
from collections import OrderedDict

import numpy as np
from ..config import PathField, StringField, DictField, NumberField, ListField, BoolField
from ..utils import UnsupportedPackage
from .launcher import Launcher
try:
import transformers
except ImportError as transformers_error:
transformers = UnsupportedPackage('transformers', transformers_error.msg)

CLASS_REGEX = r'(?:\w+)'
MODULE_REGEX = r'(?:\w+)(?:(?:.\w+)*)'
DEVICE_REGEX = r'(?P<device>cpu$|cuda)?'
CHECKPOINT_URL_REGEX = r'^https?://.*\.pth(\?.*)?(#.*)?$'

SCALAR_INPUTS = ('input_ids', 'input_mask', 'segment_ids', 'attention_mask', 'token_type_ids')

class PyTorchLauncher(Launcher):
__provider__ = 'pytorch'
@@ -67,6 +72,9 @@ def parameters(cls):
'torch_compile_kwargs': DictField(
key_type=str, validate_values=False, optional=True, default={},
description="dictionary of keyword arguments passed to torch.compile"
),
'transformers_class': StringField(
optional=True, regex=CLASS_REGEX, description='Transformers class name to load pre-trained module.'
)
})
return parameters
@@ -84,29 +92,40 @@ def __init__(self, config_entry: dict, *args, **kwargs):
self.validate_config(config_entry)
self.use_torch_compile = config_entry.get('use_torch_compile', False)
self.compile_kwargs = config_entry.get('torch_compile_kwargs', {})
self.transformers_class = config_entry.get('transformers_class', None)
backend = self.compile_kwargs.get('backend', None)
if self.use_torch_compile and backend == 'openvino':
try:
import openvino.torch # pylint: disable=C0415, W0611
importlib.import_module('openvino.torch') # pylint: disable=C0415, W0611
except ImportError as import_error:
raise ValueError("torch.compile is supported from OpenVINO 2023.1\n{}".format(
import_error.msg)) from import_error
module_args = config_entry.get("module_args", ())
module_kwargs = config_entry.get("module_kwargs", {})
self.device = self.get_value_from_config('device')
self.cuda = 'cuda' in self.device

checkpoint = config_entry.get('checkpoint')
if checkpoint is None:
checkpoint = config_entry.get('checkpoint_url')
self.module = self.load_module(
config_entry['module'],
module_args,
module_kwargs,
checkpoint,
config_entry.get('state_key'),
config_entry.get("python_path"),
config_entry.get("init_method")
)

python_path = config_entry.get("python_path")

if self.transformers_class:
self.module = self.load_transformers_module(
config_entry['module'], python_path
)
else:

self.module = self.load_module(
config_entry['module'],
module_args,
module_kwargs,
checkpoint,
config_entry.get('state_key'),
python_path,
config_entry.get("init_method")
)

self._batch = self.get_value_from_config('batch')
# torch modules does not have input information
@@ -142,7 +161,7 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st
model_cls = module_parts[-1]
model_path = ".".join(module_parts[:-1])
with append_to_path(python_path):
model_cls = importlib.import_module(model_path).__getattribute__(model_cls)
model_cls = getattr(importlib.import_module(model_path), model_cls)
module = model_cls(*module_args, **module_kwargs)
if init_method is not None:
if hasattr(model_cls, init_method):
@@ -153,31 +172,53 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st

if checkpoint:
if isinstance(checkpoint, str) and re.match(CHECKPOINT_URL_REGEX, checkpoint):
checkpoint = urllib.request.urlretrieve(checkpoint)[0]
checkpoint = urllib.request.urlretrieve(checkpoint)[0] # nosec B310 # disable urllib-urlopen check
checkpoint = self._torch.load(
checkpoint, map_location=None if self.cuda else self._torch.device('cpu')
)
state = checkpoint if not state_key else checkpoint[state_key]
if all(key.startswith('module.') for key in state):
module = self._torch.nn.DataParallel(module)
module.load_state_dict(state, strict=False)
module.to('cuda' if self.cuda else 'cpu')
module.eval()

if self.use_torch_compile:
if hasattr(model_cls, 'compile'):
module.compile()
module = self._torch.compile(module, **self.compile_kwargs)
return self.prepare_module(module, model_cls)

def load_transformers_module(self, pretrained_name, python_path):
with append_to_path(python_path):
if isinstance(transformers, UnsupportedPackage):
transformers.raise_error(self.__class__.__name__)

model_class = getattr(transformers, self.transformers_class)
pretrained_model = python_path if python_path else pretrained_name
module = model_class.from_pretrained(pretrained_model)

return self.prepare_module(module, model_class)

def prepare_module(self, module, model_class):
module.to('cuda' if self.cuda else 'cpu')
module.eval()

if self.use_torch_compile:
if hasattr(model_class, 'compile'):
module.compile()
module = self._torch.compile(module, **self.compile_kwargs)

return module

return module

def _convert_to_tensor(self, value, precision):
if isinstance(value, self._torch.Tensor):
return value
return self._torch.from_numpy(value.astype(np.float32 if not precision else precision)).to(self.device)
if precision is None:
precision = np.float32

return self._torch.from_numpy(value.astype(precision)).to(self.device)

def fit_to_input(self, data, layer_name, layout, precision, template=None):

if precision is None and layer_name in SCALAR_INPUTS:
precision = np.int64

if layer_name == 'input' and isinstance(data[0], dict):
tensor_dict = {}
for key, val in data[0].items():
@@ -191,11 +232,14 @@ def fit_to_input(self, data, layer_name, layout, precision, template=None):

return tensor_dict

if layout is not None:
data_shape = np.shape(data)

if layout is not None and len(data_shape) == len(layout):
data = np.transpose(data, layout)
tensor = self._torch.from_numpy(data.astype(np.float32 if not precision else precision))
tensor = tensor.to(self.device)
return tensor
else:
data = np.array(data)

return self._convert_to_tensor(data, precision)

def _convert_to_numpy(self, input_dict):
numpy_dict = {}
@@ -206,18 +250,27 @@ def _convert_to_numpy(self, input_dict):
numpy_dict[key] = value
return numpy_dict


def forward(self, outputs):
if hasattr(outputs, 'logits') and 'logits' in self.output_names:
return {'logits': outputs.logits}
if hasattr(outputs, 'last_hidden_state') and 'last_hidden_state' in self.output_names:
return {'last_hidden_state': outputs.last_hidden_state}
return list(outputs)

def predict(self, inputs, metadata=None, **kwargs):
results = []
with self._torch.no_grad():
for batch_input in inputs:
if metadata[0].get('input_is_dict_type'):
if metadata[0].get('input_is_dict_type') or (isinstance(batch_input, dict) and 'input' in batch_input):
outputs = self.module(batch_input['input'])
else:
outputs = list(self.module(*batch_input.values()))
outputs = self.module(**batch_input)

for meta_ in metadata:
meta_['input_shape'] = {key: list(data.shape) for key, data in batch_input.items()}

if metadata[0].get('output_is_dict_type'):
if metadata[0].get('output_is_dict_type') or isinstance(outputs, dict):
result_dict = self._convert_to_numpy(outputs)
else:
result_dict = {
@@ -17,6 +17,7 @@ For enabling PyTorch launcher you need to add `framework: pytorch` in launchers
* `batch` - batch size for running model (Optional, default 1).
* `use_torch_compile` - boolean, use torch.compile to optimize the module code (Optional, default `False`)
* `torch_compile_kwargs` - dictionary of keyword arguments to pass to torch.compile (Optional, default `{}`)
* `transformers_class` - `transformers` class name used to load the pre-trained model identified by `module` (Optional); a usage sketch is shown below this list.
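
For reference, here is a minimal sketch (not the launcher code itself) of what this option amounts to: the class name is looked up on the `transformers` package and the value of `module` (or `python_path`, if set) is passed to `from_pretrained`. The class and model names below are hypothetical placeholders.

```python
# Minimal sketch of how `transformers_class` and `module` are resolved,
# assuming the `transformers` package is installed.
import transformers

class_name = "AutoModel"               # value of `transformers_class` (assumed example)
pretrained_name = "bert-base-uncased"  # value of `module` (hypothetical example)

model_class = getattr(transformers, class_name)       # e.g. transformers.AutoModel
model = model_class.from_pretrained(pretrained_name)  # load the pre-trained weights
model.eval()                                          # inference mode, as the launcher does
```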

In turn, if your model has several inputs, you need to specify them in the config using the dedicated parameter `inputs`.
Each input description should contain the following info: