Skip to content

Commit dd5cdce

Browse files
committed
add embeddings
1 parent ebc2296 commit dd5cdce

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

tests/ipex/test_modeling.py

+19
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,25 @@ def test_sentence_transformers_save_and_infer(self, model_arch):
715715
sentences = ["This is an example sentence", "Each sentence is converted"]
716716
model.encode(sentences)
717717

718+
@parameterized.expand(SUPPORTED_ARCHITECTURES)
719+
@require_sentence_transformers
720+
@unittest.skipIf(
721+
not _langchain_hf_available or version.parse(_langchain_hf_version) <= version.parse("0.1.2"),
722+
reason="Unsupported langchain version",
723+
)
724+
def test_langchain(self, model_arch):
725+
from langchain_huggingface import HuggingFaceEmbeddings
726+
727+
model_id = MODEL_NAMES[model_arch]
728+
model_kwargs = {"device": "cpu", "backend": "ipex"}
729+
730+
embedding = HuggingFaceEmbeddings(
731+
model_name=model_id,
732+
model_kwargs=model_kwargs,
733+
)
734+
output = embedding.embed_query("foo bar")
735+
self.assertTrue(len(output) > 0)
736+
718737

719738
class IPEXLangchainTest(unittest.TestCase):
720739
SUPPORTED_ARCHITECTURES = ("gpt2",)

tests/openvino/test_modeling.py

+18
Original file line numberDiff line numberDiff line change
@@ -2803,6 +2803,24 @@ def test_sentence_transformers_save_and_infer(self, model_arch):
28032803
model.encode(sentences)
28042804
gc.collect()
28052805

2806+
@parameterized.expand(SUPPORTED_ARCHITECTURES)
2807+
@unittest.skipIf(
2808+
not _langchain_hf_available or version.parse(_langchain_hf_version) < version.parse("0.1.2"),
2809+
reason="Unsupported langchain version",
2810+
)
2811+
def test_langchain(self, model_arch):
2812+
from langchain_huggingface import HuggingFaceEmbeddings
2813+
2814+
model_id = MODEL_NAMES[model_arch]
2815+
model_kwargs = {"device": "cpu", "backend": "openvino"}
2816+
2817+
embedding = HuggingFaceEmbeddings(
2818+
model_name=model_id,
2819+
model_kwargs=model_kwargs,
2820+
)
2821+
output = embedding.embed_query("foo bar")
2822+
self.assertTrue(len(output) > 0)
2823+
28062824

28072825
class OVLangchainTest(unittest.TestCase):
28082826
SUPPORTED_ARCHITECTURES = ("gpt2",)

0 commit comments

Comments
 (0)