Skip to content

Commit 554e6fe

Browse files
authored
[PT FE] Retry hf hub model load up to 3 times if http error (openvinotoolkit#25648)
### Details: - *item1* - *...* ### Tickets: - *CVS-143832*
1 parent cdf342e commit 554e6fe

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

tests/model_hub_tests/models_hub_common/utils.py

+18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (C) 2018-2024 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import functools
45
import itertools
56
import os
67
import shutil
@@ -145,3 +146,20 @@ def call_with_timer(timer_label: str, func, args):
145146

146147
def print_stat(s: str, value: float):
147148
print(s.format(round_num(value)))
149+
150+
151+
def retry(max_retries=3, exceptions=(Exception,), delay=None):
152+
def retry_decorator(func):
153+
@functools.wraps(func)
154+
def wrapper(*args, **kwargs):
155+
for attempt in range(max_retries):
156+
try:
157+
return func(*args, **kwargs)
158+
except exceptions as e:
159+
print(f"Attempt {attempt + 1} of {max_retries} failed: {e}")
160+
if attempt < max_retries - 1 and delay is not None:
161+
time.sleep(delay)
162+
else:
163+
raise e
164+
return wrapper
165+
return retry_decorator

tests/model_hub_tests/pytorch/test_hf_transformers.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import pytest
77
import torch
88
from huggingface_hub import model_info
9+
from huggingface_hub.utils import HfHubHTTPError
910
from models_hub_common.constants import hf_hub_cache_dir
10-
from models_hub_common.utils import cleanup_dir
11+
from models_hub_common.utils import cleanup_dir, retry
1112
import transformers
1213
from transformers import AutoConfig, AutoModel, AutoProcessor, AutoTokenizer, AutoFeatureExtractor, AutoModelForTextToWaveform, \
1314
CLIPFeatureExtractor, XCLIPVisionModel, T5Tokenizer, VisionEncoderDecoderModel, ViTImageProcessor, BlipProcessor, BlipForConditionalGeneration, \
@@ -101,6 +102,7 @@ def setup_class(self):
101102
self.image = Image.open(requests.get(url, stream=True).raw)
102103
self.cuda_available, self.gptq_postinit = None, None
103104

105+
@retry(3, exceptions=(HfHubHTTPError,), delay=1)
104106
def load_model(self, name, type):
105107
name_suffix = ''
106108
if name.find(':') != -1:

0 commit comments

Comments
 (0)