Commit 2925853

Support native bert models with transformers class
1 parent b38bb3e commit 2925853

File tree

1 file changed: +41 -25 lines


tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py

Lines changed: 41 additions & 25 deletions
@@ -28,7 +28,7 @@
 MODULE_REGEX = r'(?:\w+)(?:(?:.\w+)*)'
 DEVICE_REGEX = r'(?P<device>cpu$|cuda)?'
 CHECKPOINT_URL_REGEX = r'^https?://.*\.pth(\?.*)?(#.*)?$'
-
+SCALAR_INPUTS = ('input_ids', 'input_mask', 'segment_ids', 'attention_mask', 'token_type_ids')
 
 class PyTorchLauncher(Launcher):
     __provider__ = 'pytorch'
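
For context: `SCALAR_INPUTS` names the integer token inputs that BERT-style models consume; `fit_to_input` (below) defaults their precision to `np.int64`, since embedding lookups require integer indices. A minimal sketch of why the integer dtype matters (the vocabulary size and token ids here are made up):

```python
import numpy as np
import torch

token_ids = np.array([[101, 2023, 2003, 102]])  # hypothetical token ids for one sentence
embedding = torch.nn.Embedding(num_embeddings=30522, embedding_dim=8)

# Integer indices work; float32 ids would raise a dtype error at lookup time.
out = embedding(torch.from_numpy(token_ids.astype(np.int64)))
print(out.shape)  # torch.Size([1, 4, 8])
```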
@@ -99,26 +99,29 @@ def __init__(self, config_entry: dict, *args, **kwargs):
         module_kwargs = config_entry.get("module_kwargs", {})
         self.device = self.get_value_from_config('device')
         self.cuda = 'cuda' in self.device
+
         checkpoint = config_entry.get('checkpoint')
+        if checkpoint is None:
+            checkpoint = config_entry.get('checkpoint_url')
+
+        python_path = config_entry.get("python_path")
+
         if self.tranformers_class:
             self.module = self.load_tranformers_module(
-                config_entry['module']
+                config_entry['module'], python_path
             )
         else:
-            if checkpoint is None:
-                checkpoint = config_entry.get('checkpoint_url')
 
             self.module = self.load_module(
                 config_entry['module'],
                 module_args,
                 module_kwargs,
                 checkpoint,
                 config_entry.get('state_key'),
-                config_entry.get("python_path"),
+                python_path,
                 config_entry.get("init_method")
             )
 
-
         self._batch = self.get_value_from_config('batch')
         # torch modules does not have input information
         self._generate_inputs()
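
With this change, the `checkpoint_url` fallback and the `python_path` lookup happen once, before the branch, so the transformers path and the plain-module path see the same values. A sketch of the `config_entry` keys involved (all values are hypothetical; only the key names come from the code above):

```python
# Hypothetical launcher configuration; only the key names are taken from the code.
config_entry = {
    'framework': 'pytorch',
    'module': 'bert-base-uncased',   # transformers model id, or a module path
    'device': 'cpu',
    'checkpoint': None,              # when absent, 'checkpoint_url' is used instead
    'checkpoint_url': 'https://example.com/weights.pth',
    'python_path': None,             # optional; may also point at a local pretrained dir
}

checkpoint = config_entry.get('checkpoint')
if checkpoint is None:
    checkpoint = config_entry.get('checkpoint_url')  # same fallback as in __init__
```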
@@ -175,14 +178,15 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st
 
         return self.prepare_module(module, model_cls)
 
-    def load_tranformers_module(self, pretrained_name):
-        import transformers  # pylint: disable=C0415
-        model_class = getattr(transformers, self.tranformers_class)
-        module = model_class.from_pretrained(pretrained_name)
+    def load_tranformers_module(self, pretrained_name, python_path):
+        with append_to_path(python_path):
+            import transformers  # pylint: disable=C0415
+            model_class = getattr(transformers, self.tranformers_class)
+            pretrained_model = python_path if python_path else pretrained_name
+            module = model_class.from_pretrained(pretrained_model)
 
         return self.prepare_module(module, model_class)
 
-
     def prepare_module(self, module, model_class):
         module.to('cuda' if self.cuda else 'cpu')
         module.eval()
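
`load_tranformers_module` resolves the model class by name from the `transformers` namespace and, when `python_path` is set, prefers it over the pretrained name, which lets a local directory stand in for a hub id. A standalone sketch of the same resolution (`AutoModel` and the paths are placeholder choices, not taken from the commit):

```python
import transformers

tranformers_class = 'AutoModel'        # class name as it would appear in the config
pretrained_name = 'bert-base-uncased'  # hypothetical hub id
python_path = None                     # or e.g. '/models/my-finetuned-bert'

model_class = getattr(transformers, tranformers_class)  # same getattr lookup as above
pretrained_model = python_path if python_path else pretrained_name
module = model_class.from_pretrained(pretrained_model).eval()
```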
@@ -191,16 +195,23 @@ def prepare_module(self, module, model_class):
         if hasattr(model_class, 'compile'):
             module.compile()
         module = self._torch.compile(module, **self.compile_kwargs)
+
         return module
 
 
     def _convert_to_tensor(self, value, precision):
         if isinstance(value, self._torch.Tensor):
             return value
-        return self._torch.from_numpy(value.astype(np.float32 if not precision else precision)).to(self.device)
+        if precision is None:
+            precision = np.float32
+
+        return self._torch.from_numpy(value.astype(precision)).to(self.device)
 
     def fit_to_input(self, data, layer_name, layout, precision, template=None):
 
+        if precision is None and layer_name in SCALAR_INPUTS:
+            precision = np.int64
+
         if layer_name == 'input' and isinstance(data[0], dict):
             tensor_dict = {}
             for key, val in data[0].items():
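
`_convert_to_tensor` now receives whatever precision `fit_to_input` resolved: `np.int64` for the scalar token inputs, `np.float32` as the fallback. A simplified standalone rendition of this behaviour (device pinned to CPU for the example):

```python
import numpy as np
import torch

def convert_to_tensor(value, precision, device='cpu'):
    # Simplified copy of the launcher logic, for illustration only.
    if isinstance(value, torch.Tensor):
        return value
    if precision is None:
        precision = np.float32  # float default; SCALAR_INPUTS get np.int64 upstream
    return torch.from_numpy(value.astype(precision)).to(device)

print(convert_to_tensor(np.ones((2, 3)), None).dtype)      # torch.float32
print(convert_to_tensor(np.ones((2, 3)), np.int64).dtype)  # torch.int64
```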
@@ -214,12 +225,14 @@ def fit_to_input(self, data, layer_name, layout, precision, template=None):
 
             return tensor_dict
 
-        if layout is not None:
+        data_shape = np.shape(data)
+
+        if layout is not None and len(data_shape) == len(layout):
             data = np.transpose(data, layout)
+        else:
+            data = np.array(data)
 
-        tensor = self._torch.from_numpy(data.astype(np.float32 if not precision else precision))
-        tensor = tensor.to(self.device)
-        return tensor
+        return self._convert_to_tensor(data, precision)
 
     def _convert_to_numpy(self, input_dict):
         numpy_dict = {}
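
The added rank check guards `np.transpose`, which raises when the permutation lists a different number of axes than the data has; mismatched inputs are now simply converted to an array with their axis order intact. A small illustration with made-up shapes:

```python
import numpy as np

data = [[1, 2, 3], [4, 5, 6]]  # rank 2 after conversion
layout = (0, 2, 1)             # a rank-3 permutation, e.g. for [N, C, T] inputs

if layout is not None and len(np.shape(data)) == len(layout):
    data = np.transpose(data, layout)
else:
    data = np.array(data)      # rank mismatch: keep the original axis order

print(data.shape)  # (2, 3)
```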
@@ -230,25 +243,28 @@ def _convert_to_numpy(self, input_dict):
                 numpy_dict[key] = value
         return numpy_dict
 
+
+    def forward(self, outputs):
+        if hasattr(outputs, 'logits') and 'logits' in self.output_names:
+            return {'logits': outputs.logits}
+        elif hasattr(outputs, 'last_hidden_state') and 'last_hidden_state' in self.output_names:
+            return {'last_hidden_state': outputs.last_hidden_state}
+        return list(outputs)
+
     def predict(self, inputs, metadata=None, **kwargs):
         results = []
         with self._torch.no_grad():
             for batch_input in inputs:
-                if metadata[0].get('input_is_dict_type'):
+                if metadata[0].get('input_is_dict_type') or \
+                        (isinstance(batch_input, dict) and 'input' in batch_input):
                     outputs = self.module(batch_input['input'])
                 else:
-                    output = self.module(*batch_input.values())
-
-                    if 'logits' in self.output_names:
-                        result_dict = { 'logits': output.logits.detach().cpu().numpy() }
-                        results.append(result_dict)
-                        continue
-                    outputs = list(output)
+                    outputs = self.module(**batch_input)
 
                 for meta_ in metadata:
                     meta_['input_shape'] = {key: list(data.shape) for key, data in batch_input.items()}
 
-                if metadata[0].get('output_is_dict_type'):
+                if metadata[0].get('output_is_dict_type') or isinstance(outputs, dict):
                     result_dict = self._convert_to_numpy(outputs)
                 else:
                     result_dict = {
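
Both the new `forward` helper and the widened `isinstance` checks lean on transformers output objects being dict-like dataclasses with named fields such as `logits`. A hedged sketch (`SequenceClassifierOutput` is one concrete such type, chosen here only for illustration):

```python
import torch
from transformers.modeling_outputs import SequenceClassifierOutput

outputs = SequenceClassifierOutput(logits=torch.zeros(1, 2))  # stand-in model output
output_names = ['logits']

# Same attribute-based dispatch as the forward() helper above
if hasattr(outputs, 'logits') and 'logits' in output_names:
    result = {'logits': outputs.logits}

print(result['logits'].shape)     # torch.Size([1, 2])
print(isinstance(outputs, dict))  # True: ModelOutput is an OrderedDict subclass
```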
