 MODULE_REGEX = r'(?:\w+)(?:(?:.\w+)*)'
 DEVICE_REGEX = r'(?P<device>cpu$|cuda)?'
 CHECKPOINT_URL_REGEX = r'^https?://.*\.pth(\?.*)?(#.*)?$'
-
+SCALAR_INPUTS = ('input_ids', 'input_mask', 'segment_ids', 'attention_mask', 'token_type_ids')
 
 class PyTorchLauncher(Launcher):
     __provider__ = 'pytorch'
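The new SCALAR_INPUTS tuple names the transformer-style inputs that must stay integer typed. A minimal standalone sketch (vocabulary size and embedding width are arbitrary here) of why casting token ids to the old float32 default breaks:

import numpy as np
import torch

SCALAR_INPUTS = ('input_ids', 'input_mask', 'segment_ids', 'attention_mask', 'token_type_ids')

ids = np.array([[101, 2023, 102]])        # a tiny batch of token ids
embedding = torch.nn.Embedding(30522, 8)  # sizes chosen only for illustration

tokens = torch.from_numpy(ids.astype(np.int64))
print(embedding(tokens).shape)            # torch.Size([1, 3, 8])

# embedding(torch.from_numpy(ids.astype(np.float32)))
# would raise a RuntimeError, because embedding indices must be integer tensors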
@@ -99,26 +99,29 @@ def __init__(self, config_entry: dict, *args, **kwargs):
         module_kwargs = config_entry.get("module_kwargs", {})
         self.device = self.get_value_from_config('device')
         self.cuda = 'cuda' in self.device
+
         checkpoint = config_entry.get('checkpoint')
+        if checkpoint is None:
+            checkpoint = config_entry.get('checkpoint_url')
+
+        python_path = config_entry.get("python_path")
+
         if self.tranformers_class:
             self.module = self.load_tranformers_module(
-                config_entry['module']
+                config_entry['module'], python_path
             )
         else:
-            if checkpoint is None:
-                checkpoint = config_entry.get('checkpoint_url')
 
             self.module = self.load_module(
                 config_entry['module'],
                 module_args,
                 module_kwargs,
                 checkpoint,
                 config_entry.get('state_key'),
-                config_entry.get("python_path"),
+                python_path,
                 config_entry.get("init_method")
             )
 
-
         self._batch = self.get_value_from_config('batch')
         # torch modules does not have input information
         self._generate_inputs()
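The checkpoint lookup now falls back from 'checkpoint' to 'checkpoint_url', and python_path is read once so both loading branches share it. A small sketch of the fallback order, with hypothetical config entries:

local_cfg = {'module': 'models.SampleNet', 'checkpoint': 'weights/samplenet.pth'}
remote_cfg = {'module': 'models.SampleNet',
              'checkpoint_url': 'https://example.com/samplenet.pth'}

for config_entry in (local_cfg, remote_cfg):
    checkpoint = config_entry.get('checkpoint')
    if checkpoint is None:                        # same order as in __init__
        checkpoint = config_entry.get('checkpoint_url')
    print(checkpoint)
# weights/samplenet.pth
# https://example.com/samplenet.pth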
@@ -175,14 +178,15 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st
 
         return self.prepare_module(module, model_cls)
 
-    def load_tranformers_module(self, pretrained_name):
-        import transformers  # pylint: disable=C0415
-        model_class = getattr(transformers, self.tranformers_class)
-        module = model_class.from_pretrained(pretrained_name)
+    def load_tranformers_module(self, pretrained_name, python_path):
+        with append_to_path(python_path):
+            import transformers  # pylint: disable=C0415
+            model_class = getattr(transformers, self.tranformers_class)
+            pretrained_model = python_path if python_path else pretrained_name
+            module = model_class.from_pretrained(pretrained_model)
 
         return self.prepare_module(module, model_class)
 
-
     def prepare_module(self, module, model_class):
         module.to('cuda' if self.cuda else 'cpu')
         module.eval()
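append_to_path (imported by the launcher elsewhere) makes a user-supplied directory importable while transformers loads; when python_path is set it is also handed to from_pretrained, which accepts a local directory as well as a hub model name. A comparable context manager, offered as an assumption rather than the project's exact implementation:

import sys
from contextlib import contextmanager

@contextmanager
def append_to_path(path):
    # Temporarily extend sys.path; a no-op when path is None or empty.
    if path:
        sys.path.append(str(path))
    try:
        yield
    finally:
        if path:
            sys.path.remove(str(path))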
@@ -191,16 +195,23 @@ def prepare_module(self, module, model_class):
         if hasattr(model_class, 'compile'):
             module.compile()
             module = self._torch.compile(module, **self.compile_kwargs)
+
         return module
 
 
     def _convert_to_tensor(self, value, precision):
         if isinstance(value, self._torch.Tensor):
             return value
-        return self._torch.from_numpy(value.astype(np.float32 if not precision else precision)).to(self.device)
+        if precision is None:
+            precision = np.float32
+
+        return self._torch.from_numpy(value.astype(precision)).to(self.device)
 
     def fit_to_input(self, data, layer_name, layout, precision, template=None):
 
+        if precision is None and layer_name in SCALAR_INPUTS:
+            precision = np.int64
+
         if layer_name == 'input' and isinstance(data[0], dict):
             tensor_dict = {}
             for key, val in data[0].items():
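Taken together, the two new branches give token-style inputs an int64 default while everything else keeps float32, and an explicit precision from the config still wins. A condensed sketch (resolve_precision is a hypothetical helper, not part of the launcher):

import numpy as np

SCALAR_INPUTS = ('input_ids', 'input_mask', 'segment_ids', 'attention_mask', 'token_type_ids')

def resolve_precision(layer_name, precision):
    if precision is None and layer_name in SCALAR_INPUTS:
        return np.int64          # token ids, masks, segment ids
    return np.float32 if precision is None else precision

print(resolve_precision('input_ids', None))      # int64
print(resolve_precision('input', None))          # float32
print(resolve_precision('input_ids', np.int32))  # explicit precision wins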
@@ -214,12 +225,14 @@ def fit_to_input(self, data, layer_name, layout, precision, template=None):
 
             return tensor_dict
 
-        if layout is not None:
+        data_shape = np.shape(data)
+
+        if layout is not None and len(data_shape) == len(layout):
             data = np.transpose(data, layout)
+        else:
+            data = np.array(data)
 
-        tensor = self._torch.from_numpy(data.astype(np.float32 if not precision else precision))
-        tensor = tensor.to(self.device)
-        return tensor
+        return self._convert_to_tensor(data, precision)
 
     def _convert_to_numpy(self, input_dict):
         numpy_dict = {}
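The rank check guards np.transpose, which raises when the layout length does not match the array's number of dimensions (for example, a 2-D token batch combined with a 4-element NCHW layout). A standalone sketch with illustrative shapes:

import numpy as np

layout = (0, 3, 1, 2)               # NHWC -> NCHW

image = np.zeros((1, 224, 224, 3))  # rank 4: layout applies
tokens = np.zeros((1, 128))         # rank 2: layout must be skipped

for data in (image, tokens):
    if layout is not None and len(np.shape(data)) == len(layout):
        data = np.transpose(data, layout)
    else:
        data = np.array(data)       # np.transpose(tokens, layout) would raise ValueError
    print(data.shape)
# (1, 3, 224, 224)
# (1, 128)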
@@ -230,25 +243,28 @@ def _convert_to_numpy(self, input_dict):
             numpy_dict[key] = value
         return numpy_dict
 
+
+    def forward(self, outputs):
+        if hasattr(outputs, 'logits') and 'logits' in self.output_names:
+            return {'logits': outputs.logits}
+        elif hasattr(outputs, 'last_hidden_state') and 'last_hidden_state' in self.output_names:
+            return {'last_hidden_state': outputs.last_hidden_state}
+        return list(outputs)
+
     def predict(self, inputs, metadata=None, **kwargs):
         results = []
         with self._torch.no_grad():
             for batch_input in inputs:
-                if metadata[0].get('input_is_dict_type'):
+                if metadata[0].get('input_is_dict_type') or \
+                        (isinstance(batch_input, dict) and 'input' in batch_input):
                     outputs = self.module(batch_input['input'])
                 else:
-                    output = self.module(*batch_input.values())
-
-                    if 'logits' in self.output_names:
-                        result_dict = {'logits': output.logits.detach().cpu().numpy()}
-                        results.append(result_dict)
-                        continue
-                    outputs = list(output)
+                    outputs = self.module(**batch_input)
 
                 for meta_ in metadata:
                     meta_['input_shape'] = {key: list(data.shape) for key, data in batch_input.items()}
 
-                if metadata[0].get('output_is_dict_type'):
+                if metadata[0].get('output_is_dict_type') or isinstance(outputs, dict):
                     result_dict = self._convert_to_numpy(outputs)
                 else:
                     result_dict = {
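predict now expands keyword inputs with self.module(**batch_input) and recognises dict-shaped inputs and outputs even when the metadata flags are missing, while the new forward helper pulls logits or last_hidden_state off transformers-style outputs by attribute. A sketch of that attribute-based extraction with a stand-in output object (SimpleOutput is hypothetical):

from dataclasses import dataclass
import torch

@dataclass
class SimpleOutput:            # stand-in for a transformers ModelOutput
    logits: torch.Tensor

output_names = ['logits']
outputs = SimpleOutput(logits=torch.zeros(1, 10))

if hasattr(outputs, 'logits') and 'logits' in output_names:
    result = {'logits': outputs.logits}

print(result['logits'].shape)  # torch.Size([1, 10])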