
Commit 87b36db

Add IPEX model for question answering (#534)
* add IPEX model for QA task
* add fix
1 parent 805e737 · commit 87b36db

8 files changed (+72, -4 lines)

optimum/intel/__init__.py (+2)

@@ -47,6 +47,7 @@
     "IPEXModelForSequenceClassification",
     "IPEXModelForMaskedLM",
     "IPEXModelForTokenClassification",
+    "IPEXModelForQuestionAnswering",
 ]

@@ -160,6 +161,7 @@
 from .ipex import (
     IPEXModelForCausalLM,
     IPEXModelForMaskedLM,
+    IPEXModelForQuestionAnswering,
     IPEXModelForSequenceClassification,
     IPEXModelForTokenClassification,
     inference_mode,

optimum/intel/ipex/__init__.py (+1)

@@ -1,6 +1,7 @@
 from optimum.intel.ipex.modeling_base import (
     IPEXModelForCausalLM,
     IPEXModelForMaskedLM,
+    IPEXModelForQuestionAnswering,
     IPEXModelForSequenceClassification,
     IPEXModelForTokenClassification,
 )

optimum/intel/ipex/inference.py (+1)

@@ -35,6 +35,7 @@
     IPEXMPTForCausalLM,
     IPEXOPTForCausalLM,
     IPEXGPTBigCodeForCausalLM,
+    IPEXModelForQuestionAnswering,
 )
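With the class importable here, the inference_mode context manager can serve question-answering pipelines as well. A minimal sketch of the intended flow, assuming inference_mode wraps a transformers pipeline directly (as the test in tests/ipex/test_inference.py suggests); the checkpoint name is illustrative and optimization options are omitted:

from transformers import pipeline

from optimum.intel import inference_mode

# A plain transformers QA pipeline; any question-answering checkpoint works.
pipe = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")

# inference_mode swaps in the IPEX-optimized model for the duration of the block.
with inference_mode(pipe) as ipex_pipe:
    result = ipex_pipe(
        question="Where was HuggingFace founded ?",
        context="HuggingFace was founded in Paris.",
    )

print(result["answer"])  # expected: "Paris"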

optimum/intel/ipex/modeling_base.py (+13, -1)

@@ -27,6 +27,7 @@
     AutoModel,
     AutoModelForCausalLM,
     AutoModelForMaskedLM,
+    AutoModelForQuestionAnswering,
     AutoModelForSequenceClassification,
     AutoModelForTokenClassification,
     GenerationConfig,
@@ -171,7 +172,7 @@ def _save_pretrained(self, save_directory: Union[str, Path]):

     def forward(self, *args, **kwargs):
         outputs = self.model(*args, **kwargs)
-        return ModelOutput(logits=outputs["logits"] if isinstance(outputs, dict) else outputs[0])
+        return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])

     def eval(self):
         self.model.eval()
@@ -205,6 +206,17 @@ class IPEXModelForTokenClassification(IPEXModel):
     export_feature = "token-classification"


+class IPEXModelForQuestionAnswering(IPEXModel):
+    auto_model_class = AutoModelForQuestionAnswering
+    export_feature = "question-answering"
+
+    def forward(self, *args, **kwargs):
+        outputs = self.model(*args, **kwargs)
+        start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
+        end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
+        return ModelOutput(start_logits=start_logits, end_logits=end_logits)
+
+
 class IPEXModelForCausalLM(IPEXModel, GenerationMixin):
     auto_model_class = AutoModelForCausalLM
     export_feature = "text-generation"

optimum/intel/ipex/utils.py (+1, -1)

@@ -17,5 +17,5 @@
     "text-generation": "IPEXModelForCausalLM",
     "text-classification": "IPEXModelForSequenceClassification",
     "token-classification": "IPEXModelForTokenClassification",
-    # "question-answering": "IPEXModelForQuestionAnswering",
+    "question-answering": "IPEXModelForQuestionAnswering",
 }
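Uncommenting this entry registers the class in the task mapping, so the "question-answering" task now resolves to IPEXModelForQuestionAnswering. The model can then be dropped directly into a transformers pipeline, mirroring the test_pipeline test added below (checkpoint name illustrative):

from transformers import AutoTokenizer, pipeline

from optimum.intel import IPEXModelForQuestionAnswering

model_id = "distilbert-base-cased-distilled-squad"  # illustrative checkpoint
model = IPEXModelForQuestionAnswering.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

pipe = pipeline("question-answering", model=model, tokenizer=tokenizer)
outputs = pipe(question="What's my name?", context="My Name is Sasha and I live in Lyon.")
print(outputs["answer"], outputs["score"])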

optimum/intel/utils/dummy_ipex_objects.py (+11)

@@ -64,3 +64,14 @@ def __init__(self, *args, **kwargs):
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["ipex"])
+
+
+class IPEXModelForQuestionAnswering(metaclass=DummyObject):
+    _backends = ["ipex"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["ipex"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["ipex"])

tests/ipex/test_inference.py (+1, -1)

@@ -78,7 +78,7 @@ def test_question_answering_pipeline_inference(self, model_arch):
         outputs_ipex = ipex_pipe(
             question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
         )
-        # self.assertTrue(isinstance(ipex_pipe.model._optimized.model, torch.jit.RecursiveScriptModule))
+        self.assertTrue(isinstance(ipex_pipe.model._optimized.model, torch.jit.RecursiveScriptModule))
         self.assertEqual(outputs["start"], outputs_ipex["start"])
         self.assertEqual(outputs["end"], outputs_ipex["end"])

tests/ipex/test_modeling.py (+42, -1)

@@ -20,6 +20,7 @@
 from parameterized import parameterized
 from transformers import (
     AutoModelForCausalLM,
+    AutoModelForQuestionAnswering,
     AutoModelForSequenceClassification,
     AutoTokenizer,
     PretrainedConfig,
@@ -28,7 +29,7 @@
 )

 from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS
-from optimum.intel import IPEXModelForCausalLM, IPEXModelForSequenceClassification
+from optimum.intel import IPEXModelForCausalLM, IPEXModelForQuestionAnswering, IPEXModelForSequenceClassification


 SEED = 42
@@ -118,6 +119,46 @@ def test_pipeline(self, model_arch):
         self.assertIsInstance(outputs[0]["label"], str)


+class IPEXModelForQuestionAnsweringTest(unittest.TestCase):
+    SUPPORTED_ARCHITECTURES = (
+        "bert",
+        "distilbert",
+        "roberta",
+    )
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_compare_to_transformers(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        set_seed(SEED)
+        ipex_model = IPEXModelForQuestionAnswering.from_pretrained(model_id, export=True)
+        self.assertIsInstance(ipex_model.config, PretrainedConfig)
+        transformers_model = AutoModelForQuestionAnswering.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = "This is a sample input"
+        tokens = tokenizer(inputs, return_tensors="pt")
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**tokens)
+        outputs = ipex_model(**tokens)
+        self.assertIn("start_logits", outputs)
+        self.assertIn("end_logits", outputs)
+        # Compare tensor outputs
+        self.assertTrue(torch.allclose(outputs.start_logits, transformers_outputs.start_logits, atol=1e-4))
+        self.assertTrue(torch.allclose(outputs.end_logits, transformers_outputs.end_logits, atol=1e-4))
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_pipeline(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        model = IPEXModelForQuestionAnswering.from_pretrained(model_id, export=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        pipe = pipeline("question-answering", model=model, tokenizer=tokenizer)
+        question = "What's my name?"
+        context = "My Name is Sasha and I live in Lyon."
+        outputs = pipe(question, context)
+        self.assertEqual(pipe.device, model.device)
+        self.assertGreaterEqual(outputs["score"], 0.0)
+        self.assertIsInstance(outputs["answer"], str)
+
+
 class IPEXModelForCausalLMTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (
         "bart",
