@@ -187,16 +187,20 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
187
187
188
188
def _generate_data (self , model , gen_answer_fn = None , generation_config = None ):
189
189
def default_gen_answer (model , tokenizer , prompt , max_new_tokens , crop_question , use_chat_template = False ):
190
+ device = "cpu"
191
+ if hasattr (model , "device" ):
192
+ device = model .device
193
+
190
194
if use_chat_template :
191
195
message = [{"role" : "user" , "content" : prompt }]
192
- inputs = tokenizer .apply_chat_template (message , tokenize = True , add_generation_prompt = True , return_tensors = "pt" )
196
+ inputs = tokenizer .apply_chat_template (message , tokenize = True , add_generation_prompt = True , return_tensors = "pt" ). to ( device )
193
197
tokens = model .generate (inputs , do_sample = False , max_new_tokens = max_new_tokens )
194
198
if crop_question :
195
199
tokens = tokens [:, inputs .shape [- 1 ]:]
196
200
res = self .tokenizer .decode (tokens [0 ], skip_special_tokens = True )
197
201
return res
198
202
else :
199
- inputs = self .tokenizer (prompt , return_tensors = "pt" )
203
+ inputs = self .tokenizer (prompt , return_tensors = "pt" ). to ( device )
200
204
tokens = model .generate (** inputs , do_sample = False , max_new_tokens = max_new_tokens )
201
205
if crop_question :
202
206
tokens = tokens [:, inputs ["input_ids" ].shape [- 1 ] :]
0 commit comments