
Commit 76e614a

Added possibility to generate base text on GPU for text evaluation. (#1945)
Convert inputs to the device type of the model in the text evaluator.
1 parent 5d986b7 commit 76e614a

File tree

1 file changed: +13 −18 lines


tools/who_what_benchmark/whowhatbench/text_evaluator.py

Lines changed: 13 additions & 18 deletions
@@ -189,29 +189,24 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
     def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
         def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False):
             is_awq = getattr(model, "is_awq", None) is not None
+            device = "cpu"
+            if hasattr(model, "device"):
+                device = model.device
 
             if use_chat_template:
                 message = [{"role": "user", "content": prompt}]
-                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
-                if is_awq:
-                    with patch_awq_for_inference(is_awq):
-                        tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
-                else:
-                    tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
-                if crop_question:
-                    tokens = tokens[:, inputs.shape[-1]:]
-                res = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
-                return res
+                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(device)
             else:
-                inputs = self.tokenizer(prompt, return_tensors="pt")
-                if is_awq:
-                    with patch_awq_for_inference(is_awq):
-                        tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
-                else:
+                inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
+
+            if is_awq:
+                with patch_awq_for_inference(is_awq):
                     tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
-                if crop_question:
-                    tokens = tokens[:, inputs["input_ids"].shape[-1] :]
-                return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
+            else:
+                tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
+            if crop_question:
+                tokens = tokens[:, inputs["input_ids"].shape[-1] :]
+            return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
 
         gen_answer_fn = gen_answer_fn or default_gen_answer

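The whole change comes down to one pattern: detect which device the model lives on, move the tokenized inputs there, and only then call generate, so the base (ground-truth) text can be produced on GPU. The sketch below illustrates that pattern with plain Hugging Face transformers outside the evaluator; the checkpoint name and prompt are placeholders chosen for illustration, not taken from the repository.

# Minimal sketch of the device-alignment pattern this commit introduces.
# Assumptions: a plain transformers model/tokenizer; "gpt2" is a placeholder checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # illustrative checkpoint, not part of the commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Same idea as the patch: default to CPU, follow the model's device when it exposes one.
device = model.device if hasattr(model, "device") else "cpu"

inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(device)
tokens = model.generate(**inputs, do_sample=False, max_new_tokens=32)
# Mirror the crop_question branch: strip the prompt tokens before decoding.
answer = tokenizer.batch_decode(tokens[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
print(answer)

Note also that the chat branch now passes return_dict=True to apply_chat_template, so it returns a dict-like encoding (input_ids plus attention_mask). That is what lets both branches share a single model.generate(**inputs, ...) call in place of the two near-duplicate code paths that were removed.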