Skip to content

Commit ce823fc

Browse files
authored
Pytorch launcher add support for transformers loaded models (#4011)
* Add support for native pytroch models with transformers class * Update pytorch launcher readme * Import transformers with try except * Remove double space pytorch_launcher.py * Add some fixes pytorch_launcher.py
1 parent 2378e96 commit ce823fc

File tree

2 files changed

+83
-29
lines changed

2 files changed

+83
-29
lines changed

tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py

Lines changed: 82 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,20 @@
2020
import urllib
2121
import re
2222
from collections import OrderedDict
23-
2423
import numpy as np
2524
from ..config import PathField, StringField, DictField, NumberField, ListField, BoolField
25+
from ..utils import UnsupportedPackage
2626
from .launcher import Launcher
27+
try:
28+
import transformers
29+
except ImportError as transformers_error:
30+
transformers = UnsupportedPackage('transformers', transformers_error.msg)
2731

32+
CLASS_REGEX = r'(?:\w+)'
2833
MODULE_REGEX = r'(?:\w+)(?:(?:.\w+)*)'
2934
DEVICE_REGEX = r'(?P<device>cpu$|cuda)?'
3035
CHECKPOINT_URL_REGEX = r'^https?://.*\.pth(\?.*)?(#.*)?$'
31-
36+
SCALAR_INPUTS = ('input_ids', 'input_mask', 'segment_ids', 'attention_mask', 'token_type_ids')
3237

3338
class PyTorchLauncher(Launcher):
3439
__provider__ = 'pytorch'
@@ -67,6 +72,9 @@ def parameters(cls):
6772
'torch_compile_kwargs': DictField(
6873
key_type=str, validate_values=False, optional=True, default={},
6974
description="dictionary of keyword arguments passed to torch.compile"
75+
),
76+
'transformers_class': StringField(
77+
optional=True, regex=CLASS_REGEX, description='Transformers class name to load pre-trained module.'
7078
)
7179
})
7280
return parameters
@@ -84,29 +92,40 @@ def __init__(self, config_entry: dict, *args, **kwargs):
8492
self.validate_config(config_entry)
8593
self.use_torch_compile = config_entry.get('use_torch_compile', False)
8694
self.compile_kwargs = config_entry.get('torch_compile_kwargs', {})
95+
self.tranformers_class = config_entry.get('transformers_class', None)
8796
backend = self.compile_kwargs.get('backend', None)
8897
if self.use_torch_compile and backend == 'openvino':
8998
try:
90-
import openvino.torch # pylint: disable=C0415, W0611
99+
importlib.import_module('openvino.torch') # pylint: disable=C0415, W0611
91100
except ImportError as import_error:
92101
raise ValueError("torch.compile is supported from OpenVINO 2023.1\n{}".format(
93102
import_error.msg)) from import_error
94103
module_args = config_entry.get("module_args", ())
95104
module_kwargs = config_entry.get("module_kwargs", {})
96105
self.device = self.get_value_from_config('device')
97106
self.cuda = 'cuda' in self.device
107+
98108
checkpoint = config_entry.get('checkpoint')
99109
if checkpoint is None:
100110
checkpoint = config_entry.get('checkpoint_url')
101-
self.module = self.load_module(
102-
config_entry['module'],
103-
module_args,
104-
module_kwargs,
105-
checkpoint,
106-
config_entry.get('state_key'),
107-
config_entry.get("python_path"),
108-
config_entry.get("init_method")
109-
)
111+
112+
python_path = config_entry.get("python_path")
113+
114+
if self.tranformers_class:
115+
self.module = self.load_tranformers_module(
116+
config_entry['module'], python_path
117+
)
118+
else:
119+
120+
self.module = self.load_module(
121+
config_entry['module'],
122+
module_args,
123+
module_kwargs,
124+
checkpoint,
125+
config_entry.get('state_key'),
126+
python_path,
127+
config_entry.get("init_method")
128+
)
110129

111130
self._batch = self.get_value_from_config('batch')
112131
# torch modules does not have input information
@@ -142,7 +161,7 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st
142161
model_cls = module_parts[-1]
143162
model_path = ".".join(module_parts[:-1])
144163
with append_to_path(python_path):
145-
model_cls = importlib.import_module(model_path).__getattribute__(model_cls)
164+
model_cls = getattr(importlib.import_module(model_path), model_cls)
146165
module = model_cls(*module_args, **module_kwargs)
147166
if init_method is not None:
148167
if hasattr(model_cls, init_method):
@@ -153,31 +172,53 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st
153172

154173
if checkpoint:
155174
if isinstance(checkpoint, str) and re.match(CHECKPOINT_URL_REGEX, checkpoint):
156-
checkpoint = urllib.request.urlretrieve(checkpoint)[0]
175+
checkpoint = urllib.request.urlretrieve(checkpoint)[0] # nosec B310 # disable urllib-urlopen check
157176
checkpoint = self._torch.load(
158177
checkpoint, map_location=None if self.cuda else self._torch.device('cpu')
159178
)
160179
state = checkpoint if not state_key else checkpoint[state_key]
161180
if all(key.startswith('module.') for key in state):
162181
module = self._torch.nn.DataParallel(module)
163182
module.load_state_dict(state, strict=False)
164-
module.to('cuda' if self.cuda else 'cpu')
165-
module.eval()
166183

167-
if self.use_torch_compile:
168-
if hasattr(model_cls, 'compile'):
169-
module.compile()
170-
module = self._torch.compile(module, **self.compile_kwargs)
184+
return self.prepare_module(module, model_cls)
185+
186+
def load_tranformers_module(self, pretrained_name, python_path):
187+
with append_to_path(python_path):
188+
if isinstance(transformers, UnsupportedPackage):
189+
transformers.raise_error(self.__class__.__name__)
190+
191+
model_class = getattr(transformers, self.tranformers_class)
192+
pretrained_model = python_path if python_path else pretrained_name
193+
module = model_class.from_pretrained(pretrained_model)
194+
195+
return self.prepare_module(module, model_class)
196+
197+
def prepare_module(self, module, model_class):
198+
module.to('cuda' if self.cuda else 'cpu')
199+
module.eval()
200+
201+
if self.use_torch_compile:
202+
if hasattr(model_class, 'compile'):
203+
module.compile()
204+
module = self._torch.compile(module, **self.compile_kwargs)
205+
206+
return module
171207

172-
return module
173208

174209
def _convert_to_tensor(self, value, precision):
175210
if isinstance(value, self._torch.Tensor):
176211
return value
177-
return self._torch.from_numpy(value.astype(np.float32 if not precision else precision)).to(self.device)
212+
if precision is None:
213+
precision = np.float32
214+
215+
return self._torch.from_numpy(value.astype(precision)).to(self.device)
178216

179217
def fit_to_input(self, data, layer_name, layout, precision, template=None):
180218

219+
if precision is None and layer_name in SCALAR_INPUTS:
220+
precision = np.int64
221+
181222
if layer_name == 'input' and isinstance(data[0], dict):
182223
tensor_dict = {}
183224
for key, val in data[0].items():
@@ -191,11 +232,14 @@ def fit_to_input(self, data, layer_name, layout, precision, template=None):
191232

192233
return tensor_dict
193234

194-
if layout is not None:
235+
data_shape = np.shape(data)
236+
237+
if layout is not None and len(data_shape) == len(layout):
195238
data = np.transpose(data, layout)
196-
tensor = self._torch.from_numpy(data.astype(np.float32 if not precision else precision))
197-
tensor = tensor.to(self.device)
198-
return tensor
239+
else:
240+
data = np.array(data)
241+
242+
return self._convert_to_tensor(data, precision)
199243

200244
def _convert_to_numpy(self, input_dict):
201245
numpy_dict = {}
@@ -206,18 +250,27 @@ def _convert_to_numpy(self, input_dict):
206250
numpy_dict[key] = value
207251
return numpy_dict
208252

253+
254+
def forward(self, outputs):
255+
if hasattr(outputs, 'logits') and 'logits' in self.output_names:
256+
return {'logits': outputs.logits}
257+
if hasattr(outputs, 'last_hidden_state') and 'last_hidden_state' in self.output_names:
258+
return {'last_hidden_state': outputs.last_hidden_state}
259+
return list(outputs)
260+
209261
def predict(self, inputs, metadata=None, **kwargs):
210262
results = []
211263
with self._torch.no_grad():
212264
for batch_input in inputs:
213-
if metadata[0].get('input_is_dict_type'):
265+
if metadata[0].get('input_is_dict_type') or (isinstance(batch_input, dict) and 'input' in batch_input):
214266
outputs = self.module(batch_input['input'])
215267
else:
216-
outputs = list(self.module(*batch_input.values()))
268+
outputs = self.module(**batch_input)
269+
217270
for meta_ in metadata:
218271
meta_['input_shape'] = {key: list(data.shape) for key, data in batch_input.items()}
219272

220-
if metadata[0].get('output_is_dict_type'):
273+
if metadata[0].get('output_is_dict_type') or isinstance(outputs, dict):
221274
result_dict = self._convert_to_numpy(outputs)
222275
else:
223276
result_dict = {

tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher_readme.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ For enabling PyTorch launcher you need to add `framework: pytorch` in launchers
1717
* `batch` - batch size for running model (Optional, default 1).
1818
* `use_torch_compile` - boolean, use torch.compile to optimize the module code (Optional, default `False`)
1919
* `torch_compile_kwargs` - dictionary of keyword arguments to pass to torch.compile (Optional, default `{}`)
20+
* `transformers_class` - transformers class name to load pre-trained model with `module` name. (Optional).
2021

2122
In turn if you model has several inputs you need to specify them in config, using specific parameter: `inputs`.
2223
Each input description should has following info:

0 commit comments

Comments
 (0)