Skip to content

Commit a310706

Browse files
committed
Add langchain test
1 parent 38b6e54 commit a310706

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

tests/openvino/test_modeling.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414

1515
import copy
1616
import gc
17+
import importlib
1718
import os
1819
import platform
1920
import tempfile
2021
import time
2122
import unittest
2223
from pathlib import Path
23-
from typing import Dict
24+
from typing import Dict, Generator
2425

2526
import numpy as np
2627
import open_clip
@@ -32,6 +33,7 @@
3233
from datasets import load_dataset
3334
from evaluate import evaluator
3435
from huggingface_hub import HfApi
36+
from packaging import version
3537
from parameterized import parameterized
3638
from PIL import Image
3739
from sentence_transformers import SentenceTransformer
@@ -130,6 +132,15 @@
130132
F32_CONFIG = {"INFERENCE_PRECISION_HINT": "f32"}
131133

132134

135+
_langchain_available = importlib.util.find_spec("langchain") is not None
136+
_langchain_version = "N/A"
137+
if _langchain_available:
138+
try:
139+
_langchain_version = importlib.metadata.version("langchain")
140+
except importlib.metadata.PackageNotFoundError:
141+
_langchain_available = False
142+
143+
133144
class Timer(object):
134145
def __enter__(self):
135146
self.elapsed = time.perf_counter()
@@ -2796,3 +2807,38 @@ def test_sentence_transformers_save_and_infer(self, model_arch):
27962807
sentences = ["This is an example sentence", "Each sentence is converted"]
27972808
model.encode(sentences)
27982809
gc.collect()
2810+
2811+
2812+
class OVLangchainTest(unittest.TestCase):
    """Smoke tests for the langchain-huggingface OpenVINO backend integration."""

    SUPPORTED_ARCHITECTURES = ("gpt2",)

    @unittest.skipIf(
        not _langchain_available or version.parse(_langchain_version) <= version.parse("0.3.30"),
        reason="Unsupported langchain",
    )
    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_huggingface_pipeline_streaming(self, model_arch):
        # Imported lazily: langchain_huggingface is an optional test dependency
        # and must not break collection when it is missing.
        from langchain_huggingface import HuggingFacePipeline

        model_id = MODEL_NAMES[model_arch]

        hf_pipe = HuggingFacePipeline.from_model_id(
            model_id=model_id,
            task="text-generation",
            pipeline_kwargs={"max_new_tokens": 10},
            backend="openvino",
        )

        generator = hf_pipe.stream("Q: How do you say 'hello' in German? A:'", stop=["."])

        # stream() must be lazy: it should return a generator, not a full string.
        self.assertIsInstance(generator, Generator)

        stream_results_string = ""
        for chunk in generator:
            self.assertIsInstance(chunk, str)
            # Accumulate every streamed chunk. The previous `=` kept only the
            # last chunk, so the length assertion below was checking a single
            # (possibly one-character) token instead of the whole output.
            stream_results_string += chunk

        self.assertTrue(len(stream_results_string.strip()) > 1)

        del hf_pipe
        gc.collect()

0 commit comments

Comments
 (0)