diff --git a/tools/who_what_benchmark/whowhatbench/text_evaluator.py b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
index bac72a4e11..14009dace7 100644
--- a/tools/who_what_benchmark/whowhatbench/text_evaluator.py
+++ b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
@@ -189,29 +189,24 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
     def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
         def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False):
             is_awq = getattr(model, "is_awq", None) is not None
+            device = "cpu"
+            if hasattr(model, "device"):
+                device = model.device
             if use_chat_template:
                 message = [{"role": "user", "content": prompt}]
-                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
-                if is_awq:
-                    with patch_awq_for_inference(is_awq):
-                        tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
-                else:
-                    tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
-                if crop_question:
-                    tokens = tokens[:, inputs.shape[-1]:]
-                res = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
-                return res
+                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(device)
             else:
-                inputs = self.tokenizer(prompt, return_tensors="pt")
-                if is_awq:
-                    with patch_awq_for_inference(is_awq):
-                        tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
-                else:
+                inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
+
+            if is_awq:
+                with patch_awq_for_inference(is_awq):
                     tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
-                if crop_question:
-                    tokens = tokens[:, inputs["input_ids"].shape[-1] :]
-                return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
+            else:
+                tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
+            if crop_question:
+                tokens = tokens[:, inputs["input_ids"].shape[-1] :]
+            return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
 
         gen_answer_fn = gen_answer_fn or default_gen_answer
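
For context, a minimal sketch (not part of the patch) of the behavior the consolidated path now has: the tokenized inputs are moved to the model's device before generate, so the default answer generator no longer hits a device mismatch when the Hugging Face model is loaded on GPU, and return_dict=True lets the chat-template output be unpacked into generate(**inputs) just like the plain-tokenizer branch. The model id below is a placeholder, and the snippet assumes a plain transformers checkpoint rather than the benchmark's wrapped or AWQ-patched models.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "example-org/example-chat-model"  # hypothetical checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)

# Mirrors the hasattr(model, "device") check added in the patch.
device = getattr(model, "device", "cpu")

message = [{"role": "user", "content": "What is the capital of France?"}]
# return_dict=True yields a BatchEncoding, so one .to(device) call moves both
# input_ids and attention_mask, and the result unpacks into generate(**inputs).
inputs = tokenizer.apply_chat_template(
    message,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(device)

tokens = model.generate(**inputs, do_sample=False, max_new_tokens=32)
tokens = tokens[:, inputs["input_ids"].shape[-1]:]  # crop the prompt, as crop_question does
print(tokenizer.batch_decode(tokens, skip_special_tokens=True)[0])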