 import urllib
 import re
 from collections import OrderedDict
-
 import numpy as np
 from ..config import PathField, StringField, DictField, NumberField, ListField, BoolField
+from ..utils import UnsupportedPackage
 from .launcher import Launcher
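+# transformers is optional: if it is missing, UnsupportedPackage defers the
+# ImportError until a config actually requests a transformers model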
+try:
+    import transformers
+except ImportError as transformers_error:
+    transformers = UnsupportedPackage('transformers', transformers_error.msg)

+CLASS_REGEX = r'(?:\w+)'
 MODULE_REGEX = r'(?:\w+)(?:(?:\.\w+)*)'
 DEVICE_REGEX = r'(?P<device>cpu$|cuda)?'
 CHECKPOINT_URL_REGEX = r'^https?://.*\.pth(\?.*)?(#.*)?$'
-
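+# token-style inputs that should stay integer (int64) rather than default to float32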
+SCALAR_INPUTS = ('input_ids', 'input_mask', 'segment_ids', 'attention_mask', 'token_type_ids')

 class PyTorchLauncher(Launcher):
     __provider__ = 'pytorch'
@@ -67,6 +72,9 @@ def parameters(cls):
             'torch_compile_kwargs': DictField(
                 key_type=str, validate_values=False, optional=True, default={},
                 description="dictionary of keyword arguments passed to torch.compile"
+            ),
+            'transformers_class': StringField(
+                optional=True, regex=CLASS_REGEX, description='Transformers class name to load pre-trained module.'
             )
         })
         return parameters
@@ -84,29 +92,40 @@ def __init__(self, config_entry: dict, *args, **kwargs):
         self.validate_config(config_entry)
         self.use_torch_compile = config_entry.get('use_torch_compile', False)
         self.compile_kwargs = config_entry.get('torch_compile_kwargs', {})
+        self.transformers_class = config_entry.get('transformers_class', None)
         backend = self.compile_kwargs.get('backend', None)
         if self.use_torch_compile and backend == 'openvino':
             try:
-                import openvino.torch  # pylint: disable=C0415, W0611
+                importlib.import_module('openvino.torch')  # pylint: disable=C0415, W0611
             except ImportError as import_error:
                 raise ValueError("torch.compile is supported from OpenVINO 2023.1\n{}".format(
                     import_error.msg)) from import_error
         module_args = config_entry.get("module_args", ())
         module_kwargs = config_entry.get("module_kwargs", {})
         self.device = self.get_value_from_config('device')
         self.cuda = 'cuda' in self.device
+
         checkpoint = config_entry.get('checkpoint')
         if checkpoint is None:
             checkpoint = config_entry.get('checkpoint_url')
-        self.module = self.load_module(
-            config_entry['module'],
-            module_args,
-            module_kwargs,
-            checkpoint,
-            config_entry.get('state_key'),
-            config_entry.get("python_path"),
-            config_entry.get("init_method")
-        )
+
+        python_path = config_entry.get("python_path")
+
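+        # a configured transformers_class takes precedence over a plain torch module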
+        if self.transformers_class:
+            self.module = self.load_transformers_module(
+                config_entry['module'], python_path
+            )
+        else:
+            self.module = self.load_module(
+                config_entry['module'],
+                module_args,
+                module_kwargs,
+                checkpoint,
+                config_entry.get('state_key'),
+                python_path,
+                config_entry.get("init_method")
+            )

         self._batch = self.get_value_from_config('batch')
         # torch modules do not have input information
@@ -142,7 +161,7 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st
         model_cls = module_parts[-1]
         model_path = ".".join(module_parts[:-1])
         with append_to_path(python_path):
-            model_cls = importlib.import_module(model_path).__getattribute__(model_cls)
+            model_cls = getattr(importlib.import_module(model_path), model_cls)
             module = model_cls(*module_args, **module_kwargs)
             if init_method is not None:
                 if hasattr(model_cls, init_method):
@@ -153,31 +172,53 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st

             if checkpoint:
                 if isinstance(checkpoint, str) and re.match(CHECKPOINT_URL_REGEX, checkpoint):
-                    checkpoint = urllib.request.urlretrieve(checkpoint)[0]
+                    checkpoint = urllib.request.urlretrieve(checkpoint)[0]  # nosec B310 # disable urllib-urlopen check
                 checkpoint = self._torch.load(
                     checkpoint, map_location=None if self.cuda else self._torch.device('cpu')
                 )
                 state = checkpoint if not state_key else checkpoint[state_key]
                 if all(key.startswith('module.') for key in state):
                     module = self._torch.nn.DataParallel(module)
                 module.load_state_dict(state, strict=False)
-            module.to('cuda' if self.cuda else 'cpu')
-            module.eval()

-            if self.use_torch_compile:
-                if hasattr(model_cls, 'compile'):
-                    module.compile()
-                module = self._torch.compile(module, **self.compile_kwargs)
+            return self.prepare_module(module, model_cls)
+
+    def load_transformers_module(self, pretrained_name, python_path):
+        with append_to_path(python_path):
+            if isinstance(transformers, UnsupportedPackage):
+                transformers.raise_error(self.__class__.__name__)
+
+            model_class = getattr(transformers, self.transformers_class)
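+            # python_path, when set, doubles as a local directory holding the pretrained weights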
+            pretrained_model = python_path if python_path else pretrained_name
+            module = model_class.from_pretrained(pretrained_model)
+
+        return self.prepare_module(module, model_class)
+
+    def prepare_module(self, module, model_class):
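+        # device placement, eval mode and optional torch.compile, shared by both loaders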
+        module.to('cuda' if self.cuda else 'cpu')
+        module.eval()
+
+        if self.use_torch_compile:
+            if hasattr(model_class, 'compile'):
+                module.compile()
+            module = self._torch.compile(module, **self.compile_kwargs)
+
+        return module

-            return module

     def _convert_to_tensor(self, value, precision):
         if isinstance(value, self._torch.Tensor):
             return value
-        return self._torch.from_numpy(value.astype(np.float32 if not precision else precision)).to(self.device)
+        if precision is None:
+            precision = np.float32
+
+        return self._torch.from_numpy(value.astype(precision)).to(self.device)

     def fit_to_input(self, data, layer_name, layout, precision, template=None):

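+        # token ids must stay integer; other inputs fall back to float32 in _convert_to_tensor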
+        if precision is None and layer_name in SCALAR_INPUTS:
+            precision = np.int64
+
         if layer_name == 'input' and isinstance(data[0], dict):
             tensor_dict = {}
             for key, val in data[0].items():
@@ -191,11 +232,14 @@ def fit_to_input(self, data, layer_name, layout, precision, template=None):

             return tensor_dict

-        if layout is not None:
+        data_shape = np.shape(data)
+
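+        # transpose only when the requested layout rank matches the data rank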
+        if layout is not None and len(data_shape) == len(layout):
             data = np.transpose(data, layout)
-        tensor = self._torch.from_numpy(data.astype(np.float32 if not precision else precision))
-        tensor = tensor.to(self.device)
-        return tensor
+        else:
+            data = np.array(data)
+
+        return self._convert_to_tensor(data, precision)

     def _convert_to_numpy(self, input_dict):
         numpy_dict = {}
@@ -206,18 +250,27 @@ def _convert_to_numpy(self, input_dict):
             numpy_dict[key] = value
         return numpy_dict

+    def forward(self, outputs):
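+        # unwrap transformers ModelOutput-style objects into plain name -> tensor dicts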
+        if hasattr(outputs, 'logits') and 'logits' in self.output_names:
+            return {'logits': outputs.logits}
+        if hasattr(outputs, 'last_hidden_state') and 'last_hidden_state' in self.output_names:
+            return {'last_hidden_state': outputs.last_hidden_state}
+        return list(outputs)
+
     def predict(self, inputs, metadata=None, **kwargs):
         results = []
         with self._torch.no_grad():
             for batch_input in inputs:
-                if metadata[0].get('input_is_dict_type'):
+                if metadata[0].get('input_is_dict_type') or (isinstance(batch_input, dict) and 'input' in batch_input):
                     outputs = self.module(batch_input['input'])
                 else:
-                    outputs = list(self.module(*batch_input.values()))
+                    outputs = self.module(**batch_input)
+
                 for meta_ in metadata:
                     meta_['input_shape'] = {key: list(data.shape) for key, data in batch_input.items()}

-                if metadata[0].get('output_is_dict_type'):
+                if metadata[0].get('output_is_dict_type') or isinstance(outputs, dict):
                     result_dict = self._convert_to_numpy(outputs)
                 else:
                     result_dict = {