Skip to content
This repository was archived by the owner on Aug 28, 2023. It is now read-only.

Commit 8aec16c

Browse files
authored
[81218] [NLP] Progress tracking (#17)
1 parent 9b096b9 commit 8aec16c

File tree

11 files changed

+177
-12
lines changed

11 files changed

+177
-12
lines changed

automation/bom/image_BOM.txt

+1
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ wb/error/parse_env_error.py
153153
wb/error/request_error.py
154154
wb/error/sanitize_parameter_error.py
155155
wb/error/ssh_client_error.py
156+
wb/error/transformers_onnx_conversion_error_map.json
156157
wb/extensions_factories/__init__.py
157158
wb/extensions_factories/celery.py
158159
wb/extensions_factories/database.py

config/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
JUPYTER_CELL_TEMPLATES_FOLDER = os.path.join(ROOT_FOLDER, 'wb', 'main', 'jupyter_notebooks', 'cell_templates')
5151
UPLOAD_FOLDER_MODELS = os.path.join(ESSENTIAL_DATA_FOLDER, 'models')
5252
CONSOLE_TOOL_WRAPPER_FOLDER = os.path.join(ROOT_FOLDER, 'wb', 'main', 'console_tool_wrapper')
53+
TRANSFORMERS_ONNX_ERROR_MAP_JSON = Path(ROOT_FOLDER) / 'wb' / 'error' / 'transformers_onnx_conversion_error_map.json'
5354
VOC_IMAGES_FOLDER = 'JPEGImages'
5455
VOC_ANNOTATIONS_FOLDER = 'Annotations'
5556
VOC_IMAGESETS_FOLDER = 'ImageSets'

wb/error/code_registry.py

+5
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class CodeRegistry:
4646
'DEPLOYMENT_MANAGER_ERROR': 4008,
4747
'DATUMARO_ERROR': 4009,
4848
'RESHAPE_MODEL_ERROR': 4010,
49+
'TRANSFORMERS_ONNX_ERROR': 4020,
4950

5051
# DATABASE ERRORS
5152
'DATABASE_ERROR': 5001,
@@ -160,3 +161,7 @@ def get_dev_cloud_remote_job_error_code(cls):
160161
@classmethod
161162
def get_reshape_model_error_code(cls):
162163
return cls.CODES['RESHAPE_MODEL_ERROR']
164+
165+
@classmethod
166+
def get_transformers_onnx_error_code(cls):
167+
return cls.CODES['TRANSFORMERS_ONNX_ERROR']

wb/error/job_error.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
See the License for the specific language governing permissions and
1515
limitations under the License.
1616
"""
17-
from config.constants import CELERY_RETRY_COUNTDOWN, CELERY_RETRY_MAX_RETRY
17+
import json
18+
19+
from config.constants import CELERY_RETRY_COUNTDOWN, CELERY_RETRY_MAX_RETRY, TRANSFORMERS_ONNX_ERROR_MAP_JSON
1820
from wb.error.general_error import GeneralError
1921
from wb.error.code_registry import CodeRegistry
2022
from wb.main.enumerates import RemoteSetupStatusMessagesEnum, RemoteSetupStatusCodeEnum
@@ -103,3 +105,21 @@ def __init__(self, error_type, job_id):
103105
RemoteSetupStatusMessagesEnum.PIP_VERSION_ERROR.value:
104106
RemoteSetupStatusCodeEnum.PIP_VERSION_ERROR.value,
105107
}
108+
109+
110+
class TransformersONNXConversionError(JobGeneralError):
111+
code = CodeRegistry.get_transformers_onnx_error_code()
112+
113+
with open(TRANSFORMERS_ONNX_ERROR_MAP_JSON) as f:
114+
message_map = json.load(f)
115+
116+
def __init__(self, message: str, job_id: int):
117+
message = self.replace_error_message(message)
118+
super().__init__(message, job_id)
119+
120+
def replace_error_message(self, message: str) -> str:
121+
for substring, replacement_string in self.message_map.items():
122+
if substring in message:
123+
return replacement_string
124+
125+
return f'Unexpected error: {message}'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
"RuntimeError: 0INTERNAL ASSERT FAILED": "PyTorch JiT trace error",
3+
"onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument": "Wrong tokenizer type in the repository. To convert the model you could create a repository with the same model and right tokenizer type and use it instead.",
4+
"Connection error, and we cannot find the requested files in the cached path": "Connection error. Check the internet connection.",
5+
"TypeError: not a string": "Cannot initialize tokenizer from the repository.",
6+
"Error(s) in loading state_dict": "Cannot initialize the the model from the repository. Try to connect the repository creator of the repository.",
7+
"sequence item 0: expected str instance": "[unk] and [pad] tokens for tokenizer are not set.",
8+
"expected str, bytes or os.PathLike object": "Not enough files for tokenizer initialization in the repository",
9+
"path should be string, bytes, os.PathLike or integer": "Not enough files for tokenizer initialization in the repository",
10+
"The state dictionary of the model you are training to load is corrupted": "Cannot initialize the model form the repository. It may not have been saved properly.",
11+
"No such file or directory (os error 2)": "Cannot initialize tokenizer from the repository.",
12+
"Can't load tokenizer for": "Cannot initialize tokenizer from the repository.",
13+
"Exporting model exceed maximum protobuf size of 2GB": "Cannot convert a large model to ONNX.",
14+
"Connection error": "Connection error, check the internet connection.",
15+
"Model and config inputs doesn't match": "Cannot convert model to ONNX - model and config inputs doesn't match. The repository might contain wrong tokenizer type.",
16+
"Wrong index found for [MASK]": "Cannot initialize tokenizer from the model repository.",
17+
"JSONDecodeError": "Cannot initialize tokenizer from the model repository - the json file is corrupted.",
18+
"PreValidation Error": "Cannot initialize tokenizer or model from the repository."
19+
}

wb/main/console_tool_wrapper/huggingface_model_downloader/tool.py

+91-5
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,22 @@
1414
See the License for the specific language governing permissions and
1515
limitations under the License.
1616
"""
17+
18+
import re
1719
from pathlib import Path
1820

1921
from wb.main.console_tool_wrapper.console_parameter_validator import ConsoleParametersTypes
2022
from wb.main.console_tool_wrapper.python_console_tool import PythonModuleTool
21-
from wb.main.jobs.tools_runner.console_output_parser import ConsoleToolOutputParser
23+
from wb.main.jobs.interfaces.job_state import JobStateSubject
24+
from wb.main.jobs.tools_runner.console_output_parser import ConsoleToolOutputParser, skip_empty_line_decorator
25+
26+
27+
DOWNLOAD_PROGRESS_STRING_START = 'Downloading:'
28+
CONVERSION_START = 'Using framework PyTorch'
29+
MODEL_SAVED = 'All good, model saved'
30+
TOLERANCE_CHECK_FAILED = "Outputs values doesn't match between reference model and ONNX exported model"
31+
VALIDATING_ONNX_MODEL = 'Validating ONNX model'
32+
NOT_ALL_WEIGHTS_USED = 'Some weights of the model checkpoint'
2233

2334

2435
class HuggingfaceModelDownloaderTool(PythonModuleTool):
@@ -39,9 +50,84 @@ def __init__(self, python_exec: Path, model_id: str, onnx_model_path: Path):
3950

4051

4152
class HuggingfaceModelDownloaderParser(ConsoleToolOutputParser):
42-
def __init__(self):
43-
super().__init__()
53+
def __init__(self, job_state_subject: JobStateSubject):
54+
super().__init__(job_state_subject=job_state_subject)
55+
self.current_pct = 0
56+
self.current_step = 0
57+
58+
self.download_steps = 5
59+
self.pct_per_step = round(60 / self.download_steps)
60+
61+
self.current = re.compile(r'(\d+\.?\d*)[Mk]?/\d+\.?\d*')
62+
self.total = re.compile(r'\d+\.?\d*[Mk]?/(\d+\.?\d*)')
63+
64+
self.downloaded = False
65+
self.downloaded_pct = 60
66+
self.converted = False
67+
self.converted_pct = 80
4468

69+
self.error = False
70+
71+
@skip_empty_line_decorator
4572
def parse(self, string: str):
46-
# todo: implement progress reporting
47-
print(string)
73+
if self.error:
74+
return
75+
76+
string = string.strip()
77+
78+
# skip tensorflow error message
79+
if 'error' in string.lower() and 'tensorflow' not in string.lower():
80+
self.error = True
81+
self._job_state_subject.update_state(progress=100)
82+
return
83+
84+
if not self.downloaded:
85+
if string.startswith(DOWNLOAD_PROGRESS_STRING_START):
86+
self.parse_download_stage(string)
87+
elif string.startswith(CONVERSION_START):
88+
self.current_pct = self.downloaded_pct
89+
self.downloaded = True
90+
self.parse_convert_stage(string)
91+
elif not self.converted:
92+
self.parse_convert_stage(string)
93+
else:
94+
self.parse_validation_stage(string)
95+
96+
self._job_state_subject.update_state(progress=self.current_pct)
97+
98+
def parse_download_stage(self, string: str):
99+
current_size_match = self.current.search(string)
100+
total_size_match = self.total.search(string)
101+
102+
if not current_size_match or not total_size_match:
103+
return
104+
105+
current_size = float(current_size_match.group(1))
106+
total_size = float(total_size_match.group(1))
107+
108+
ratio = 0 if current_size > total_size else current_size / total_size
109+
110+
self.current_pct = max(
111+
self.current_step * self.pct_per_step + round(ratio * self.pct_per_step),
112+
self.current_pct
113+
)
114+
115+
self.current_step += (current_size == total_size)
116+
117+
def parse_convert_stage(self, string: str) -> None:
118+
self.current_pct = min(self.current_pct + 1, 100)
119+
120+
if NOT_ALL_WEIGHTS_USED in string:
121+
self.error = True
122+
elif VALIDATING_ONNX_MODEL in string:
123+
self.converted = True
124+
self.current_pct = self.converted_pct
125+
self.parse_validation_stage(string)
126+
127+
def parse_validation_stage(self, string: str) -> None:
128+
self.current_pct = min(self.current_pct + 4, 100)
129+
130+
if TOLERANCE_CHECK_FAILED in string:
131+
self.error = True
132+
elif MODEL_SAVED in string:
133+
self.current_pct = 100

wb/main/console_tool_wrapper/model_downloader/console_output_parser.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from wb.main.jobs.tools_runner.console_output_parser import ConsoleToolOutputParser, skip_empty_line_decorator
2323

2424

25-
class DownloadingFile():
25+
class DownloadingFile:
2626
def __init__(self, name: str, size: float):
2727
self.name = name
2828
self.size = size

wb/main/console_tool_wrapper/model_optimizer/console_output_parser.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,15 @@ class ModelOptimizerParser(ConsoleToolOutputParser):
2525
def __init__(self, job_state_subject: ModelOptimizerJobStateSubject):
2626
super().__init__(job_state_subject=job_state_subject)
2727
self._progress_pattern = re.compile(r'.*Progress: \[.*]\s*(?P<progress>\d+(.\d+)?)%\sdone$')
28+
self.update_every_pct = 5
2829

2930
def parse(self, string: str):
3031
progress_match = self._progress_pattern.search(string)
3132
if progress_match:
33+
if self._job_state_subject.job_progress is None:
34+
self._job_state_subject.update_state(progress=0)
35+
3236
percent = float(progress_match.group('progress'))
33-
self._job_state_subject.update_state(progress=percent)
37+
38+
if percent - self._job_state_subject.job_progress > self.update_every_pct:
39+
self._job_state_subject.update_state(progress=percent)

wb/main/huggingface_api/huggingface_api.py

+11
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@ def json(self):
8989
}
9090

9191

92+
contains_decoder = {
93+
model_type for model_type, tasks in FeaturesManager._SUPPORTED_MODEL_TYPE.items()
94+
if any("with-past" in task for task in tasks)
95+
}
96+
97+
9298
def _validate_hf_model(model: ModelInfo) -> ValidationResult:
9399
if not model.config:
94100
return ValidationResult(disabled=True, message='Model has no config')
@@ -102,6 +108,11 @@ def _validate_hf_model(model: ModelInfo) -> ValidationResult:
102108
disabled=True,
103109
message=f'Sequence classification feature is not supported for model type {model_type}'
104110
)
111+
if model_type in contains_decoder:
112+
return ValidationResult(
113+
disabled=True,
114+
message=f'The model type {model_type} contains transformer decoder and is not supported by DL Workbench'
115+
)
105116
return ValidationResult(disabled=False)
106117

107118

wb/main/jobs/models/import_huggingface_model_job.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
from pathlib import Path
1919

2020
from sqlalchemy.orm import Session
21+
from transformers import AutoTokenizer, PretrainedConfig
2122

22-
from wb.error.job_error import ModelOptimizerError
23+
from wb.error.job_error import TransformersONNXConversionError
2324
from wb.extensions_factories.database import get_db_session_for_celery
2425
from wb.main.console_tool_wrapper.huggingface_model_downloader.tool import (HuggingfaceModelDownloaderTool,
2526
HuggingfaceModelDownloaderParser)
@@ -52,6 +53,8 @@ def run(self):
5253

5354
topology: TopologiesModel = import_job.model
5455

56+
self.pre_validation(import_job.huggingface_model_id)
57+
5558
environment: EnvironmentModel = topology.environment
5659
python_executable = environment.python_executable
5760

@@ -61,14 +64,27 @@ def run(self):
6164
onnx_model_path=Path(topology.path),
6265
)
6366

64-
parser = HuggingfaceModelDownloaderParser()
67+
parser = HuggingfaceModelDownloaderParser(self._job_state_subject)
6568
runner = LocalRunner(tool, parser)
6669

6770
return_code, message = runner.run_console_tool(self)
6871

6972
if return_code:
7073
self._job_state_subject.update_state(status=StatusEnum.error, error_message='error')
71-
raise ModelOptimizerError(message, self.job_id)
74+
raise TransformersONNXConversionError(message, self.job_id)
7275

7376
self._job_state_subject.update_state(progress=100, status=StatusEnum.ready)
7477
self._job_state_subject.detach_all_observers()
78+
79+
def pre_validation(self, huggingface_model_id: str) -> None:
80+
"""
81+
Check that the tokenizer and model config can be initialized from the repository before loading a model
82+
"""
83+
try:
84+
AutoTokenizer.from_pretrained(huggingface_model_id)
85+
PretrainedConfig.from_pretrained(huggingface_model_id)
86+
except Exception:
87+
self._job_state_subject.update_state(status=StatusEnum.error, error_message='error')
88+
raise TransformersONNXConversionError(
89+
'PreValidation Error', self.job_id
90+
)

wb/main/utils/tokenizer/tokeinzer_wrapper.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(self, tokenizer_folder: Path, tokenizer_type: Optional[TokenizerTyp
4141
)
4242

4343
@classmethod
44-
def from_model(cls, tokenizer_model: TokenizerModel) -> "TokenizerWrapper":
44+
def from_model(cls, tokenizer_model: TokenizerModel) -> 'TokenizerWrapper':
4545
return cls(
4646
tokenizer_folder=Path(tokenizer_model.path),
4747
tokenizer_type=tokenizer_model.tokenizer_type,

0 commit comments

Comments
 (0)