
Commit 1639e9d

Added the possibility to generate on GPU.
1 parent 19744f5 commit 1639e9d

File tree

1 file changed: +6 -2 lines changed

tools/who_what_benchmark/whowhatbench/text_evaluator.py

+6 -2

@@ -187,16 +187,20 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
 
     def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
         def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False):
+            device = "cpu"
+            if hasattr(model, "device"):
+                device = model.device
+
             if use_chat_template:
                 message = [{"role": "user", "content": prompt}]
-                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
+                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(device)
                 tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
                 if crop_question:
                     tokens = tokens[:, inputs.shape[-1]:]
                 res = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
                 return res
             else:
-                inputs = self.tokenizer(prompt, return_tensors="pt")
+                inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
                 tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
                 if crop_question:
                     tokens = tokens[:, inputs["input_ids"].shape[-1] :]

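Note on the change: the patch makes default_gen_answer move the tokenized inputs to the model's device before calling generate, so the benchmark also works when the model has been placed on a GPU, while the CPU path stays unchanged thanks to the hasattr guard. Below is a minimal standalone sketch of the same pattern, assuming a Hugging Face transformers causal LM; the checkpoint name and prompt are illustrative placeholders, not taken from this commit.

# Minimal sketch of the device-handling pattern from this commit,
# assuming a Hugging Face transformers causal LM. The checkpoint
# and prompt are placeholders for illustration only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Same guard as the patch: fall back to CPU if the model
# does not expose a .device attribute.
device = model.device if hasattr(model, "device") else "cpu"

# Inputs must live on the same device as the model's weights,
# otherwise model.generate raises a device-mismatch error on GPU.
inputs = tokenizer("Hello, world", return_tensors="pt").to(device)
tokens = model.generate(**inputs, do_sample=False, max_new_tokens=16)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))

Reading model.device rather than hard-coding "cuda" is presumably why the patch guards with hasattr: it keeps the function correct for CPU-only models and for backends that do not expose a device attribute.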