Skip to content

Commit ab9c8da

Browse files
committed
add ipex tests
1 parent a310706 commit ab9c8da

File tree

3 files changed

+55
-12
lines changed

3 files changed

+55
-12
lines changed

optimum/intel/utils/import_utils.py

+9
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,15 @@
184184
_sentence_transformers_available = False
185185

186186

187+
_langchain_available = importlib.util.find_spec("langchain") is not None
188+
_langchain_version = "N/A"
189+
if _langchain_available:
190+
try:
191+
_langchain_version = importlib.metadata.version("langchain")
192+
except importlib.metadata.PackageNotFoundError:
193+
_langchain_available = False
194+
195+
187196
def is_transformers_available():
188197
return _transformers_available
189198

tests/ipex/test_modeling.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import numpy as np
2222
import requests
2323
import torch
24+
from typing import Generator
2425
from parameterized import parameterized
2526
from PIL import Image
2627
from transformers import (
@@ -34,6 +35,7 @@
3435
pipeline,
3536
set_seed,
3637
)
38+
from packaging import version
3739
from optimum.intel import (
3840
IPEXModel,
3941
IPEXModelForAudioClassification,
@@ -47,7 +49,12 @@
4749
IPEXSentenceTransformer,
4850
)
4951
from optimum.utils.testing_utils import grid_parameters, require_sentence_transformers
50-
from optimum.intel.utils.import_utils import is_sentence_transformers_available, is_torch_version
52+
from optimum.intel.utils.import_utils import (
53+
is_sentence_transformers_available,
54+
is_torch_version,
55+
_langchain_available,
56+
_langchain_version,
57+
)
5158

5259
if is_sentence_transformers_available():
5360
from sentence_transformers import SentenceTransformer
@@ -707,3 +714,35 @@ def test_sentence_transformers_save_and_infer(self, model_arch):
707714
model = IPEXSentenceTransformer(tmpdirname, model_kwargs={"subfolder": "ipex"})
708715
sentences = ["This is an example sentence", "Each sentence is converted"]
709716
model.encode(sentences)
717+
718+
719+
class IPEXLangchainTest(unittest.TestCase):
720+
SUPPORTED_ARCHITECTURES = ("gpt2",)
721+
722+
@unittest.skipIf(
723+
not _langchain_available or version.parse(_langchain_version) <= version.parse("0.3.30"),
724+
reason="Unsupported langchain",
725+
)
726+
@parameterized.expand(SUPPORTED_ARCHITECTURES)
727+
def test_huggingface_pipeline_streaming(self, model_arch):
728+
from langchain_huggingface import HuggingFacePipeline
729+
730+
model_id = MODEL_NAMES[model_arch]
731+
732+
hf_pipe = HuggingFacePipeline.from_model_id(
733+
model_id=model_id,
734+
task="text-generation",
735+
pipeline_kwargs={"max_new_tokens": 10},
736+
backend="ipex",
737+
)
738+
739+
generator = hf_pipe.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
740+
741+
self.assertIsInstance(generator, Generator)
742+
743+
stream_results_string = ""
744+
for chunk in generator:
745+
self.assertIsInstance(chunk, str)
746+
stream_results_string = chunk
747+
748+
self.assertTrue(len(stream_results_string.strip()) > 1)

tests/openvino/test_modeling.py

+6-11
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import copy
1616
import gc
17-
import importlib
1817
import os
1918
import platform
2019
import tempfile
@@ -109,7 +108,12 @@
109108
_print_compiled_model_properties,
110109
)
111110
from optimum.intel.pipelines import pipeline as optimum_pipeline
112-
from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version
111+
from optimum.intel.utils.import_utils import (
112+
_langchain_available,
113+
_langchain_version,
114+
is_openvino_version,
115+
is_transformers_version,
116+
)
113117
from optimum.intel.utils.modeling_utils import _find_files_matching_pattern
114118
from optimum.utils import (
115119
DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
@@ -132,15 +136,6 @@
132136
F32_CONFIG = {"INFERENCE_PRECISION_HINT": "f32"}
133137

134138

135-
_langchain_available = importlib.util.find_spec("langchain") is not None
136-
_langchain_version = "N/A"
137-
if _langchain_available:
138-
try:
139-
_langchain_version = importlib.metadata.version("langchain")
140-
except importlib.metadata.PackageNotFoundError:
141-
_langchain_available = False
142-
143-
144139
class Timer(object):
145140
def __enter__(self):
146141
self.elapsed = time.perf_counter()

0 commit comments

Comments
 (0)