From 1639e9d96c4662382561270da52d787ba98079d1 Mon Sep 17 00:00:00 2001
From: Andrei Anufriev <andrey.anufriev@intel.com>
Date: Wed, 19 Mar 2025 17:19:38 +0100
Subject: [PATCH] Added possibility to generate on GPU.

---
 tools/who_what_benchmark/whowhatbench/text_evaluator.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tools/who_what_benchmark/whowhatbench/text_evaluator.py b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
index 73fb1d1928..11b4da028e 100644
--- a/tools/who_what_benchmark/whowhatbench/text_evaluator.py
+++ b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
@@ -187,16 +187,20 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
 
     def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
         def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False):
+            device = "cpu"
+            if hasattr(model, "device"):
+                device = model.device
+
             if use_chat_template:
                 message = [{"role": "user", "content": prompt}]
-                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
+                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(device)
                 tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
                 if crop_question:
                     tokens = tokens[:, inputs.shape[-1]:]
                 res = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
                 return res
             else:
-                inputs = self.tokenizer(prompt, return_tensors="pt")
+                inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
                 tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
                 if crop_question:
                     tokens = tokens[:, inputs["input_ids"].shape[-1] :]
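
Note for reviewers: below is a minimal standalone sketch (not part of the patch) of the device-alignment pattern the diff applies. Tokenizers produce CPU tensors, so when the model has been moved to a GPU, the inputs must follow it before model.generate() is called, or PyTorch raises a device-mismatch error. The model id "gpt2" and the prompt are placeholders, not taken from the patch.

# Sketch of the device-alignment pattern from the diff above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
if torch.cuda.is_available():
    model = model.to("cuda")

# Same fallback as the patch: use the model's device when it exposes one,
# otherwise stay on CPU.
device = model.device if hasattr(model, "device") else "cpu"

# Move the tokenized inputs onto the model's device before generation.
inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(device)
tokens = model.generate(**inputs, do_sample=False, max_new_tokens=32)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))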