@@ -21,6 +21,7 @@
 import numpy as np
 import requests
 import torch
+from typing import Generator
 from parameterized import parameterized
 from PIL import Image
 from transformers import (
@@ -34,6 +35,7 @@
     pipeline,
     set_seed,
 )
+from packaging import version
 from optimum.intel import (
     IPEXModel,
     IPEXModelForAudioClassification,
@@ -47,7 +49,12 @@
     IPEXSentenceTransformer,
 )
 from optimum.utils.testing_utils import grid_parameters, require_sentence_transformers
-from optimum.intel.utils.import_utils import is_sentence_transformers_available, is_torch_version
+from optimum.intel.utils.import_utils import (
+    is_sentence_transformers_available,
+    is_torch_version,
+    _langchain_available,
+    _langchain_version,
+)
 
 if is_sentence_transformers_available():
     from sentence_transformers import SentenceTransformer
@@ -707,3 +714,35 @@ def test_sentence_transformers_save_and_infer(self, model_arch):
             model = IPEXSentenceTransformer(tmpdirname, model_kwargs={"subfolder": "ipex"})
             sentences = ["This is an example sentence", "Each sentence is converted"]
             model.encode(sentences)
+
+
+class IPEXLangchainTest(unittest.TestCase):
+    SUPPORTED_ARCHITECTURES = ("gpt2",)
+
+    @unittest.skipIf(
+        not _langchain_available or version.parse(_langchain_version) <= version.parse("0.3.30"),
+        reason="Unsupported langchain",
+    )
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_huggingface_pipeline_streaming(self, model_arch):
+        from langchain_huggingface import HuggingFacePipeline
+
+        model_id = MODEL_NAMES[model_arch]
+
+        hf_pipe = HuggingFacePipeline.from_model_id(
+            model_id=model_id,
+            task="text-generation",
+            pipeline_kwargs={"max_new_tokens": 10},
+            backend="ipex",
+        )
+
+        generator = hf_pipe.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
+
+        self.assertIsInstance(generator, Generator)
+
+        stream_results_string = ""
+        for chunk in generator:
+            self.assertIsInstance(chunk, str)
+            stream_results_string += chunk
+
+        self.assertTrue(len(stream_results_string.strip()) > 1)
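
Outside the unittest harness, the streaming path this test covers looks roughly like the sketch below. It is a minimal illustration, not part of the commit: it assumes langchain_huggingface and optimum-intel with IPEX support are installed, and uses the literal model id "gpt2" in place of the suite's MODEL_NAMES lookup.

# Minimal sketch (assumptions above): stream a completion through the IPEX backend.
from langchain_huggingface import HuggingFacePipeline

# backend="ipex" makes from_model_id load the model via optimum-intel's IPEX classes.
hf_pipe = HuggingFacePipeline.from_model_id(
    model_id="gpt2",  # stand-in for MODEL_NAMES["gpt2"] from the test suite
    task="text-generation",
    pipeline_kwargs={"max_new_tokens": 10},
    backend="ipex",
)

# stream() yields the completion as a sequence of text chunks; stop=["."]
# cuts generation at the first period. Concatenating the chunks gives the
# same text a blocking invoke() call would return.
completion = ""
for chunk in hf_pipe.stream("Q: How do you say 'hello' in German? A:", stop=["."]):
    completion += chunk
print(completion)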