From 1639e9d96c4662382561270da52d787ba98079d1 Mon Sep 17 00:00:00 2001
From: Andrei Anufriev <andrey.anufriev@intel.com>
Date: Wed, 19 Mar 2025 17:19:38 +0100
Subject: [PATCH] Added possibility to generate on GPU.

---
 tools/who_what_benchmark/whowhatbench/text_evaluator.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tools/who_what_benchmark/whowhatbench/text_evaluator.py b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
index 73fb1d1928..11b4da028e 100644
--- a/tools/who_what_benchmark/whowhatbench/text_evaluator.py
+++ b/tools/who_what_benchmark/whowhatbench/text_evaluator.py
@@ -187,16 +187,20 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
 
     def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
         def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False):
+            device = "cpu"
+            if hasattr(model, "device"):
+                device = model.device
+
             if use_chat_template:
                 message = [{"role": "user", "content": prompt}]
-                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
+                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(device)
                 tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
                 if crop_question:
                     tokens = tokens[:, inputs.shape[-1]:]
                 res = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
                 return res
             else:
-                inputs = self.tokenizer(prompt, return_tensors="pt")
+                inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
                 tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
                 if crop_question:
                     tokens = tokens[:, inputs["input_ids"].shape[-1] :]