From 0d148418fd480ae25c636728e812deb14628e9b5 Mon Sep 17 00:00:00 2001
From: Nikolay
Date: Fri, 21 Mar 2025 17:17:57 +0100
Subject: [PATCH] Added possibility to generate base text on GPU for text
 evaluation

---
 .../who_what_benchmark/whowhatbench/text_evaluator.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/tools/who_what_benchmark/whowhatbench/text_evaluator.py b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
index 73fb1d1928..666ecc07c8 100644
--- a/tools/who_what_benchmark/whowhatbench/text_evaluator.py
+++ b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
@@ -187,16 +187,22 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
 
     def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
         def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False):
+            device = "cpu"
+            if hasattr(model, "device"):
+                device = model.device
+
             if use_chat_template:
                 message = [{"role": "user", "content": prompt}]
-                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
+                inputs = tokenizer.apply_chat_template(
+                    message, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+                ).to(device)
                 tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
                 if crop_question:
                     tokens = tokens[:, inputs.shape[-1]:]
                 res = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
                 return res
             else:
-                inputs = self.tokenizer(prompt, return_tensors="pt")
+                inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
                 tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
                 if crop_question:
                     tokens = tokens[:, inputs["input_ids"].shape[-1] :]
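
Note for reviewers: below is a minimal standalone sketch of the device-placement pattern the patch applies, useful for verifying the behavior outside the evaluator. The model id, prompt, and max_new_tokens value are arbitrary placeholders, not anything from the patch; any Hugging Face causal LM would do. The point is that tokenized inputs must sit on the same device as the model before model.generate runs, which is what the new .to(device) calls guarantee.

    # Minimal sketch of the device-placement pattern this patch applies
    # (placeholder model id and prompt, chosen only for illustration).
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "gpt2"  # assumption: any causal LM works here
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    if torch.cuda.is_available():
        model = model.to("cuda")

    # Same fallback as the patch: a model that exposes no .device
    # attribute is treated as living on the CPU.
    device = model.device if hasattr(model, "device") else "cpu"

    # Moving the BatchEncoding to the model's device avoids the
    # device-mismatch error model.generate would otherwise raise
    # when the model has been placed on the GPU.
    inputs = tokenizer("Hello, world", return_tensors="pt").to(device)
    tokens = model.generate(**inputs, do_sample=False, max_new_tokens=16)
    print(tokenizer.decode(tokens[0], skip_special_tokens=True))

The chat-template branch works the same way: apply_chat_template(..., return_tensors="pt") returns a tensor, so chaining .to(device) onto it moves the prompt ids alongside the model before generation.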