Skip to content

Commit cb5208c

Browse files
authored
Add $ to the end of filename regex patterns (#931)
And add an unrelated test for a model with modeling files in a subfolder and configuration files in the root. This passes on main, but failed on 1.19.0
1 parent 41d93a1 commit cb5208c

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

optimum/intel/openvino/modeling_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def from_pretrained(
439439

440440
ov_files = _find_files_matching_pattern(
441441
model_dir,
442-
pattern=r"(.*)?openvino(.*)?\_model.xml",
442+
pattern=r"(.*)?openvino(.*)?\_model.xml$",
443443
subfolder=subfolder,
444444
use_auth_token=token,
445445
revision=revision,

optimum/intel/openvino/modeling_open_clip.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def from_pretrained(
152152

153153
ov_files = _find_files_matching_pattern(
154154
model_dir,
155-
pattern=r"(.*)?openvino(.*)?\_model\_(.*)?.xml",
155+
pattern=r"(.*)?openvino(.*)?\_model\_(.*)?.xml$",
156156
subfolder=subfolder,
157157
use_auth_token=token,
158158
revision=revision,

tests/openvino/test_modeling.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def test_infer_export_when_loading(self):
339339

340340
def test_find_files_matching_pattern(self):
341341
model_id = "echarlaix/tiny-random-PhiForCausalLM"
342-
pattern = r"(.*)?openvino(.*)?\_model.xml"
342+
pattern = r"(.*)?openvino(.*)?\_model.xml$"
343343
# hub model
344344
for revision in ("main", "ov", "itrex"):
345345
ov_files = _find_files_matching_pattern(
@@ -360,7 +360,7 @@ def test_find_files_matching_pattern(self):
360360

361361
@parameterized.expand(("stable-diffusion", "stable-diffusion-openvino"))
362362
def test_find_files_matching_pattern_sd(self, model_arch):
363-
pattern = r"(.*)?openvino(.*)?\_model.xml"
363+
pattern = r"(.*)?openvino(.*)?\_model.xml$"
364364
model_id = MODEL_NAMES[model_arch]
365365
# hub model
366366
ov_files = _find_files_matching_pattern(model_id, pattern=pattern)
@@ -374,6 +374,23 @@ def test_find_files_matching_pattern_sd(self, model_arch):
374374
ov_files = _find_files_matching_pattern(local_dir, pattern=pattern)
375375
self.assertTrue(len(ov_files) > 0 if "openvino" in model_id else len(ov_files) == 0)
376376

377+
@parameterized.expand(("", "openvino"))
378+
def test_find_files_matching_pattern_with_config_in_root(self, subfolder):
379+
# Notably, the model has a config.json file in the root directory and not in the subfolder
380+
model_id = "sentence-transformers-testing/stsb-bert-tiny-openvino"
381+
pattern = r"(.*)?openvino(.*)?\_model.xml$"
382+
# hub model
383+
ov_files = _find_files_matching_pattern(model_id, pattern=pattern, subfolder=subfolder)
384+
self.assertTrue(len(ov_files) == 1 if subfolder == "openvino" else len(ov_files) == 0)
385+
386+
# local model
387+
api = HfApi()
388+
with tempfile.TemporaryDirectory() as tmpdirname:
389+
local_dir = Path(tmpdirname) / "model"
390+
api.snapshot_download(repo_id=model_id, local_dir=local_dir)
391+
ov_files = _find_files_matching_pattern(local_dir, pattern=pattern, subfolder=subfolder)
392+
self.assertTrue(len(ov_files) == 1 if subfolder == "openvino" else len(ov_files) == 0)
393+
377394

378395
class PipelineTest(unittest.TestCase):
379396
def test_load_model_from_hub(self):

0 commit comments

Comments
 (0)