Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit ac809f8

Browse files
committedMar 7, 2024·
update testing
1 parent 1b8d76a commit ac809f8

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed
 

‎optimum/intel/ipex/modeling_base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def forward(
193193
inputs["token_type_ids"] = token_type_ids
194194

195195
outputs = self._call_model(**inputs)
196-
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
196+
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(last_hidden_state=outputs[0])
197197

198198
def eval(self):
199199
self.model.eval()
@@ -282,7 +282,7 @@ def forward(
282282
inputs["attention_mask"] = attention_mask
283283

284284
outputs = self._call_model(**inputs)
285-
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(last_hidden_state=outputs[0])
285+
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
286286

287287

288288
class IPEXModelForQuestionAnswering(IPEXModel):

‎tests/ipex/test_modeling.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,10 @@ def test_compare_to_transformers(self, model_arch):
126126
with torch.no_grad():
127127
transformers_outputs = transformers_model(**tokens)
128128
outputs = ipex_model(**tokens)
129-
self.assertTrue(
130-
torch.allclose(outputs["last_hidden_state"], transformers_outputs["last_hidden_state"], atol=1e-4)
131-
)
129+
# Compare tensor outputs
130+
for output_name in {"logits", "last_hidden_state"}:
131+
if output_name in transformers_outputs:
132+
self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-4))
132133

133134
@parameterized.expand(SUPPORTED_ARCHITECTURES)
134135
def test_pipeline(self, model_arch):

0 commit comments

Comments
 (0)
Please sign in to comment.