@@ -189,29 +189,24 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
     def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
         def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False):
             is_awq = getattr(model, "is_awq", None) is not None
+            device = "cpu"
+            if hasattr(model, "device"):
+                device = model.device
 
             if use_chat_template:
                 message = [{"role": "user", "content": prompt}]
-                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
-                if is_awq:
-                    with patch_awq_for_inference(is_awq):
-                        tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
-                else:
-                    tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
-                if crop_question:
-                    tokens = tokens[:, inputs.shape[-1]:]
-                res = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
-                return res
+                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(device)
             else:
-                inputs = self.tokenizer(prompt, return_tensors="pt")
-                if is_awq:
-                    with patch_awq_for_inference(is_awq):
-                        tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
-                else:
+                inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
+
+            if is_awq:
+                with patch_awq_for_inference(is_awq):
                     tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
-                if crop_question:
-                    tokens = tokens[:, inputs["input_ids"].shape[-1]:]
-                return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
+            else:
+                tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
+            if crop_question:
+                tokens = tokens[:, inputs["input_ids"].shape[-1]:]
+            return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
 
         gen_answer_fn = gen_answer_fn or default_gen_answer
 
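The net effect of this hunk: both input paths now produce a dict-like `BatchEncoding` (the chat path via `return_dict=True`) that is moved to the model's device, so a single AWQ-aware `generate` call and a single decode path serve both cases. Below is a minimal sketch of that pattern, assuming a Hugging Face `transformers` causal LM outside this repository; the checkpoint id is a placeholder and `patch_awq_for_inference` (a helper from this codebase) is omitted:

```python
# Sketch only, not code from this commit: illustrates the unified
# device-placement and generate/decode pattern the new version uses.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "sshleifer/tiny-gpt2"  # placeholder checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Mirror the commit: fall back to CPU when the model exposes no device.
device = model.device if hasattr(model, "device") else "cpu"

prompt = "What is quantization?"
if tokenizer.chat_template is not None:
    # return_dict=True makes this branch yield a BatchEncoding, matching
    # the plain-tokenizer branch, so .to(device) and **inputs work uniformly.
    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)
else:
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

tokens = model.generate(**inputs, do_sample=False, max_new_tokens=16)
# Crop the prompt so only newly generated tokens are decoded.
tokens = tokens[:, inputs["input_ids"].shape[-1]:]
print(tokenizer.batch_decode(tokens, skip_special_tokens=True)[0])
```

A side benefit of unpacking with `**inputs` is that the attention mask is forwarded to `generate`; the earlier chat branch dropped it by passing the bare `input_ids` tensor positionally.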