Refactor and Add Tests #2 (Open)

Wants to merge 76 commits into base: openvino_tokenizers

Commits (76)
8340e1d
Fix model dtype (#502)
jiqing-feng Jan 8, 2024
77f9756
Add ipex inference llama test (#503)
jiqing-feng Jan 8, 2024
03e1fa6
Disable marian test until openvino next release (#504)
echarlaix Jan 8, 2024
c64025d
Add INC modeling position_ids generation (#456)
jiqing-feng Jan 8, 2024
aa5b71b
Add f32 precision for compare-with-transformers tests (#508)
helena-intel Jan 10, 2024
ba2487b
Fix typo inside InferRequestWrapper
nikita-savelyevv Jan 11, 2024
23f4f5d
Merge pull request #511 from nikita-savelyevv/fix-infer-request-wrapp…
AlexKoff88 Jan 12, 2024
3f7551e
Add try for get_property (#510)
wgzintel Jan 12, 2024
545ad5a
Fix error with optimum-cli export openvino --help
helena-intel Jan 13, 2024
3c196c3
Merge pull request #514 from huggingface/helena/optimum-cli-fix
AlexKoff88 Jan 15, 2024
133aa7d
Bump min torch version (#515)
echarlaix Jan 15, 2024
7f236c2
Add OPenVINO stateful model support (#493)
eaidova Jan 16, 2024
2f2a764
Add openvino-nightly to automated tests (#506)
helena-intel Jan 16, 2024
e22a2ac
Fix loading Timm models with ov_config (#517)
helena-intel Jan 16, 2024
76ce9de
Use f32 inference for some OpenVINO stable diffusion/training tests (…
helena-intel Jan 17, 2024
94bc226
Convert tokenizers with openvino_tokenizers
slyalin Jan 5, 2024
6bb395f
Update optimum/exporters/openvino/__main__.py
slyalin Jan 5, 2024
7d16ec7
Refactor and Add Tests
apaniukov Jan 9, 2024
f0933ad
Fix t5 Test
apaniukov Jan 10, 2024
24cc616
Add Warning
apaniukov Jan 10, 2024
49337b0
Return Tests
apaniukov Jan 10, 2024
7709043
Move export_tokenizer to convert.py
apaniukov Jan 10, 2024
dbd609b
Avoid Double Tokenizer Save
apaniukov Jan 12, 2024
7e24f10
Fix Style
apaniukov Jan 12, 2024
2cf460d
Refactor After Review
apaniukov Jan 18, 2024
57782d1
Skip Tokenizers Tests If No Package Installed
apaniukov Jan 18, 2024
db668da
Merge branch 'main' into openvino_tokenizers
apaniukov Jan 18, 2024
e00c5fc
Return the original forward after exporting the PyTorch model to Open…
alexsu52 Jan 18, 2024
cbb5fde
Fix Inference docs (#525)
ngaloppo Jan 19, 2024
e7cd70f
Style Fix
apaniukov Jan 19, 2024
40cf117
Fix OV Tokenizers Check
apaniukov Jan 19, 2024
4652ae4
Add warning message when using transformers < 4.35 (#524)
helena-intel Jan 19, 2024
901f48a
Fix Tests
apaniukov Jan 19, 2024
ff1e382
Fix compatibility with transformers (#527)
echarlaix Jan 19, 2024
69140bf
Fix quantization tests for opeenvino-nightly (#523)
eaidova Jan 22, 2024
af2e986
split create pkv to a function (#521)
jiqing-feng Jan 22, 2024
ac640ed
Add test for _print_compiled_model_properties (#528)
helena-intel Jan 23, 2024
672c022
Enable automatic CACHE_DIR for GPU inference only (#520)
helena-intel Jan 23, 2024
1a76bd4
Add Missing return
apaniukov Jan 23, 2024
f2b2237
Turn off tokenizer message if not installed
apaniukov Jan 23, 2024
5e9c1b7
Update OpenVINO documentation about weight compression (#529)
AlexKoff88 Jan 24, 2024
3066ade
Merge branch 'main' into openvino-tokenizers
apaniukov Jan 24, 2024
d96ebfa
Fix ov device (#530)
echarlaix Jan 25, 2024
e0c1143
Fix expected quantization matmul test (#531)
echarlaix Jan 25, 2024
71610dd
Dev version
echarlaix Jan 25, 2024
a622f4d
Fix OVCasualLM model inference without generate (#532)
eaidova Jan 26, 2024
805e737
Add IPEX models (#516)
echarlaix Jan 26, 2024
87b36db
Add IPEX model for question answering (#534)
echarlaix Jan 26, 2024
6bf5fbc
Expose InferRequestWrapper class so it can be imported from elsewhere…
nikita-savelyevv Jan 29, 2024
9a2e271
Merge branch 'main' into openvino-tokenizers
apaniukov Jan 29, 2024
7ee347e
Move tokenizers to OV dependencies
apaniukov Jan 29, 2024
6e79be1
Add IPEX models for audio and image classification tasks (#536)
echarlaix Jan 29, 2024
20df723
relax requirements to have registered normalized config for usage con…
eaidova Jan 30, 2024
8d2ec41
Check OV Compatibility
apaniukov Jan 30, 2024
1b5c3cb
IPEX decoder model fix (#539)
echarlaix Jan 30, 2024
3b627f4
Enable loading of torchscript model with INC and add warning (#540)
echarlaix Jan 30, 2024
32a7274
Bump OV Version
apaniukov Jan 30, 2024
a251422
Fix torch version for ipex tests (#545)
echarlaix Jan 31, 2024
398450d
Refactor IPEX CausalLM for better model architecture scale (#544)
ofirzaf Jan 31, 2024
8ee487d
Automatic `torch.autocast` for IPEXModel (#542)
ofirzaf Jan 31, 2024
788e458
Add an initial warmup step to `IPEXModel`s (#543)
ofirzaf Jan 31, 2024
0ca9447
Fix format (#546)
echarlaix Jan 31, 2024
552de65
Dev version
echarlaix Jan 31, 2024
8c029e0
Move OpenVINO Tokenizers To Optional Dependencies
apaniukov Feb 1, 2024
7ea3656
Fix OV pre-commit test
daniil-lyakhov Feb 1, 2024
24f40bf
CUSTOMIZED_QUANTIZATION_CONFIG is updated
daniil-lyakhov Feb 2, 2024
5120f75
Merge pull request #548 from daniil-lyakhov/dl/fix_ov_precommit
AlexKoff88 Feb 5, 2024
0f45751
Update README (#549)
echarlaix Feb 5, 2024
ad99b98
Add bloom ipex inference test (#551)
echarlaix Feb 6, 2024
24a1e30
Remove pytorch v2.1.2 constraint for tests since ipex v2.2.0 release…
echarlaix Feb 6, 2024
e40e627
Fix openvino export model from ONNX (#554)
echarlaix Feb 7, 2024
09b067f
Add --convert-tokenizer Option to CLI
apaniukov Feb 8, 2024
a7b766e
Add load_in_4bit option for OVModelForCausalLM (#538)
AlexKoff88 Feb 8, 2024
1c14957
Skip automodel compression weights tests for nncf==2.8.0 (#535)
alexsu52 Feb 8, 2024
f3b8ce8
Merge branch 'main' into openvino_tokenizers
apaniukov Feb 8, 2024
3c27fbd
Fix SD Tokenizer
apaniukov Feb 8, 2024
10 changes: 2 additions & 8 deletions .github/workflows/test_openvino.yml
@@ -17,7 +17,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.8, 3.9]
python-version: [3.8, 3.11]
os: [ubuntu-latest]

runs-on: ${{ matrix.os }}
@@ -32,13 +32,7 @@ jobs:
python -m pip install --upgrade pip
# install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[openvino,nncf,tests,diffusers]
pip install .[openvino,openvino-tokenizers,nncf,tests,diffusers]
- name: Test with Pytest
run: |
pytest tests/openvino/ --ignore test_modeling_basic
- name: Test openvino-nightly import
run: |
pip uninstall -y openvino
pip install openvino-nightly
python -c "from optimum.intel import OVModelForCausalLM; OVModelForCausalLM.from_pretrained('hf-internal-testing/tiny-random-gpt2', export=True, compile=False)"

9 changes: 6 additions & 3 deletions README.md
@@ -6,6 +6,8 @@

🤗 Optimum Intel is the interface between the 🤗 Transformers and Diffusers libraries and the different tools and libraries provided by Intel to accelerate end-to-end pipelines on Intel architectures.

[Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/#introduction) is an open-source library which provides optimizations for both eager mode and graph mode; however, compared to eager mode, graph mode in PyTorch* normally yields better performance from optimization techniques such as operation fusion.

Intel [Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) is an open-source library enabling the usage of the most popular compression techniques such as quantization, pruning and knowledge distillation. It supports automatic accuracy-driven tuning strategies so that users can easily generate quantized models. Users can apply static, dynamic and quantization-aware training approaches while specifying an expected accuracy criterion. It also supports different weight pruning techniques, enabling the creation of pruned models that meet a predefined sparsity target.

[OpenVINO](https://docs.openvino.ai/latest/index.html) is an open-source toolkit that enables high performance inference capabilities for Intel CPUs, GPUs, and special DL inference accelerators ([see](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html) the full list of supported devices). It is supplied with a set of tools to optimize your models with compression techniques such as quantization, pruning and knowledge distillation. Optimum Intel provides a simple interface to optimize your Transformers and Diffusers models, convert them to the OpenVINO Intermediate Representation (IR) format and run inference using OpenVINO Runtime.
@@ -19,6 +21,7 @@ To install the latest release of 🤗 Optimum Intel with the corresponding requi
|:-----------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------|
| [Intel Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) | `pip install --upgrade-strategy eager "optimum[neural-compressor]"` |
| [OpenVINO](https://docs.openvino.ai/latest/index.html) | `pip install --upgrade-strategy eager "optimum[openvino,nncf]"` |
| [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/#introduction) | `pip install --upgrade-strategy eager "optimum[ipex]"` |

The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.

@@ -37,7 +40,7 @@ or to install from source including dependencies:
python -m pip install "optimum-intel[extras]"@git+https://github.com/huggingface/optimum-intel.git
```

where `extras` can be one or more of `neural-compressor`, `openvino`, `nncf`.
where `extras` can be one or more of `ipex`, `neural-compressor`, `openvino`, `nncf`.

# Quick tour

@@ -75,10 +78,10 @@ It is possible to export your model to the [OpenVINO](https://docs.openvino.ai/2
optimum-cli export openvino --model gpt2 ov_model
```

If you add `--int8`, the model linear and embedding weights will be quantized to INT8, the activations will be kept in floating point precision.
You can also apply 8-bit weight-only quantization when exporting your model: the model's linear and embedding weights will be quantized to INT8, while the activations will be kept in floating point precision.

```plain
optimum-cli export openvino --model gpt2 --int8 ov_model
optimum-cli export openvino --model gpt2 --weight-format int8 ov_model
```

To apply quantization on both weights and activations, you can find more information in the [documentation](https://huggingface.co/docs/optimum/main/en/intel/optimization_ov).
40 changes: 27 additions & 13 deletions docs/source/inference.mdx
@@ -50,19 +50,19 @@ optimum-cli export openvino --model local_path --task text-generation-with-past
Once the model is exported, you can load the OpenVINO model using :

```python
from optimum.intel import AutoModelForCausalLM
from optimum.intel import OVModelForCausalLM

model_id = "helenai/gpt2-ov"
model = AutoModelForCausalLM.from_pretrained(model_id)
model_id = "ov_model"
model = OVModelForCausalLM.from_pretrained(model_id)
```

You can also load your PyTorch checkpoint and convert it to the OpenVINO format on-the-fly, by setting `export=True` when loading your model.

```python
from optimum.intel import AutoModelForCausalLM
from optimum.intel import OVModelForCausalLM

model_id = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_id, export=True)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
model.save_pretrained("ov_model")
```

@@ -94,15 +94,15 @@ model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)
```

### Weight only quantization
### Weight-only quantization

You can also apply INT8 quantization on your models weights when exporting your model with the CLI:
You can also apply 8-bit or 4-bit weight quantization when exporting your model with the CLI:

```bash
optimum-cli export openvino --model gpt2 --int8 ov_model
optimum-cli export openvino --model gpt2 --weight-format int8 ov_model
```

This will results in the exported model linear and embedding layers to be quantized to INT8, the activations will be kept in floating point precision.
This will result in the exported model's linear and embedding layers being quantized to INT8 or INT4, while the activations are kept in floating point precision. This type of optimization reduces the memory footprint and latency of LLMs.

This can also be done when loading your model by setting the `load_in_8bit` argument when calling the `from_pretrained()` method.

@@ -112,6 +112,21 @@ from optimum.intel import OVModelForCausalLM
model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
```

> **NOTE:** `load_in_8bit` is enabled by default for models larger than 1 billion parameters.
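
To keep full-precision weights for such a model, the default can be overridden explicitly; a minimal sketch, following the same pattern as the snippet above:

```python
from optimum.intel import OVModelForCausalLM

# Explicitly opt out of the default 8-bit weight compression applied to large models
model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=False)
```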

There are also alternative compression options for a different performance-accuracy trade-off:

| Option | Description |
|---------------------------------------------------------------------|-------------------|
| `fp16` | Float16 weights |
| `int8` | INT8 weights |
| `int4_sym_g128`, `int4_asym_g128`, `int4_sym_g64`, `int4_asym_g64`* | INT4 weights |

*`sym` and `asym` stand for symmetric and asymmetric quantization, while `g128` and `g64` denote a group size of 128 and 64, respectively.

The `--ratio` CLI parameter controls the ratio between 4-bit and 8-bit quantized layers and can also change the performance-accuracy trade-off for the optimized model. It is valid only for the INT4 quantization options.
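
For example, the following command is a sketch of a 4-bit export (the model ID and output directory are illustrative): it uses symmetric INT4 quantization with a group size of 128 and keeps roughly 20% of the layers in INT8:

```bash
optimum-cli export openvino --model gpt2 --weight-format int4_sym_g128 --ratio 0.8 ov_model
```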


To apply quantization on both weights and activations, you can use the `OVQuantizer`, more information in the [documentation](https://huggingface.co/docs/optimum/main/en/intel/optimization_ov#optimization).

### Static shape
@@ -186,11 +201,10 @@ It is possible to pass an `ov_config` parameter to `from_pretrained()` with cust
model = OVModelForSequenceClassification.from_pretrained(model_id, ov_config={"INFERENCE_PRECISION_HINT":"f32"})
```

Optimum Intel leverages OpenVINO's model caching to speed up model compiling. By default a `model_cache` directory is created in the model's directory in the [Hugging Face Hub cache](https://huggingface.co/docs/huggingface_hub/main/en/guides/manage-cache). To override this, use the ov_config parameter and set `CACHE_DIR` to a different value. To disable model caching, set `CACHE_DIR` to an empty string.

Optimum Intel leverages OpenVINO's model caching to speed up model compiling on GPU. By default a `model_cache` directory is created in the model's directory in the [Hugging Face Hub cache](https://huggingface.co/docs/huggingface_hub/main/en/guides/manage-cache). To override this, use the ov_config parameter and set `CACHE_DIR` to a different value. To disable model caching on GPU, set `CACHE_DIR` to an empty string.

```python
model = OVModelForSequenceClassification.from_pretrained(model_id, ov_config={"CACHE_DIR":""})
model = OVModelForSequenceClassification.from_pretrained(model_id, device="GPU", ov_config={"PERFORMANCE_HINT": "LATENCY", "CACHE_DIR":""})
```

### Sequence-to-sequence models
@@ -258,7 +272,7 @@ prompt = "sailing ship in storm by Rembrandt"
images = pipeline(prompt).images
```

To load your PyTorch model and convert it to OpenVINO on-the-fly, you can set `export=True`.
To load your PyTorch model and convert it to OpenVINO on the fly, you can set `export=True`.

```python
model_id = "runwayml/stable-diffusion-v1-5"
30 changes: 30 additions & 0 deletions docs/source/optimization_ov.mdx
@@ -62,6 +62,36 @@ tokenizer.save_pretrained(save_dir)

The `quantize()` method applies post-training static quantization and exports the resulting quantized model to the OpenVINO Intermediate Representation (IR). The resulting graph is represented with two files: an XML file describing the network topology and a binary file describing the weights. The resulting model can be run on any target Intel device.

## Weight-only quantization

You can optimize the performance of text-generation LLMs by quantizing weights to various precisions that provide different performance-accuracy trade-offs.

```python
from optimum.intel import OVModelForCausalLM

model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
```

> **NOTE:** `load_in_8bit` is enabled by default for models larger than 1 billion parameters.

For 4-bit weight quantization, we recommend using the NNCF API as shown below:
```python
from optimum.intel import OVModelForCausalLM
import nncf

model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=False)
model.model = nncf.compress_weights(
model.model,
mode=nncf.CompressWeightsMode.INT4_SYM,
ratio=0.8,
group_size=128,
)
model.save_pretrained("compressed_model")
```

For more details, please refer to the corresponding NNCF [documentation](https://github.com/openvinotoolkit/nncf/blob/develop/docs/compression_algorithms/CompressWeights.md).


## Training-time optimization

Apart from optimizing a model after training like post-training quantization above, `optimum.openvino` also provides optimization methods during training, namely Quantization-Aware Training (QAT) and Joint Pruning, Quantization and Distillation (JPQD).
4 changes: 2 additions & 2 deletions docs/source/reference_inc.mdx
@@ -43,8 +43,8 @@ specific language governing permissions and limitations under the License.

## INCModelForCausalLM

[[autodoc]] neural_compressor.modeling_decoder.INCModelForCausalLM
[[autodoc]] neural_compressor.modeling_base.INCModelForCausalLM

## INCModelForSeq2SeqLM

[[autodoc]] neural_compressor.modeling_base.INCModelForSeq2SeqLM
[[autodoc]] neural_compressor.modeling_base.INCModelForSeq2SeqLM
22 changes: 20 additions & 2 deletions optimum/commands/export/openvino.py
@@ -89,9 +89,25 @@ def parse_args_openvino(parser: "ArgumentParser"):
default=0.8,
help=(
"Compression ratio between primary and backup precision. In the case of INT4, NNCF evaluates layer sensitivity and keeps the most impactful layers in INT8"
"precision (by default 20% in INT8). This helps to achieve better accuracy after weight quantization."
"precision (by default 20%% in INT8). This helps to achieve better accuracy after weight compression."
),
)
optional_group.add_argument(
"--disable-stateful",
action="store_true",
help=(
"Disable stateful converted models, stateless models will be generated instead. Stateful models are produced by default when this key is not used. "
"In stateful models all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. "
"If --disable-stateful option is used, it may result in sub-optimal inference performance. "
"Use it when you intentionally want to use a stateless model, for example, to be compatible with existing "
"OpenVINO native inference code that expects kv-cache inputs and outputs in the model."
),
)
optional_group.add_argument(
"--convert-tokenizer",
action="store_true",
help="Add converted tokenizer and detokenizer with OpenVINO Tokenizers",
)


class OVExportCommand(BaseOptimumCLICommand):
@@ -138,6 +154,8 @@ def run(self):
trust_remote_code=self.args.trust_remote_code,
pad_token_id=self.args.pad_token_id,
compression_option=self.args.weight_format,
compression_ratio=self.args.ratio
compression_ratio=self.args.ratio,
stateful=not self.args.disable_stateful,
convert_tokenizer=self.args.convert_tokenizer,
# **input_shapes,
)
1 change: 1 addition & 0 deletions optimum/exporters/openvino/__init__.py
@@ -1,5 +1,6 @@
from .__main__ import main_export
from .convert import export, export_models, export_pytorch_via_onnx
from .stateful import ensure_stateful_is_available, patch_stateful


__all__ = ["main_export", "export", "export_models"]
44 changes: 37 additions & 7 deletions optimum/exporters/openvino/__main__.py
@@ -18,16 +18,22 @@
from typing import Any, Callable, Dict, Optional, Union

from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, PreTrainedTokenizerBase

from optimum.exporters import TasksManager
from optimum.exporters.onnx import __main__ as optimum_main
from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors

from ...intel.utils.import_utils import is_nncf_available, is_optimum_version, is_transformers_version
from .convert import export_models
from ...intel.utils.import_utils import (
is_nncf_available,
is_openvino_tokenizers_available,
is_optimum_version,
is_transformers_version,
)
from .convert import export_models, export_tokenizer
from .stateful import ensure_export_task_support_stateful


if is_optimum_version(">=", "1.16.0"):
@@ -40,7 +46,6 @@
]

OV_XML_FILE_NAME = "openvino_model.xml"

_MAX_UNCOMPRESSED_SIZE = 1e9

logger = logging.getLogger(__name__)
@@ -65,6 +70,8 @@ def main_export(
fn_get_submodels: Optional[Callable] = None,
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
stateful: bool = True,
convert_tokenizer: bool = False,
**kwargs_shapes,
):
"""
@@ -124,6 +131,8 @@
`int4_sym_g64` - INT4 symmetric weights w/ group size 64, "int4_asym_g64" - as previous but asymmetric w/ zero-point, `f32` - means no compression.
compression_ratio (`Optional[float]`, defaults to `None`):
Compression ratio between primary and backup precision (only relevant to INT4).
stateful (`bool`, defaults to `True`):
Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. Applicable only for decoder models.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.

@@ -277,6 +286,9 @@ class StoreAttr(object):
possible_synonyms = ""
logger.info(f"Automatic task detection to {task}{possible_synonyms}.")

task_support_stateful = ensure_export_task_support_stateful(task)
stateful = stateful and task_support_stateful

preprocessors = maybe_load_preprocessors(
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
)
@@ -311,13 +323,17 @@ class StoreAttr(object):
and getattr(model.config, "pad_token_id", None) is None
and task in ["text-classification"]
)

tokenizer = next(
(preprocessor for preprocessor in preprocessors if isinstance(preprocessor, PreTrainedTokenizerBase)), None
)

if needs_pad_token_id:
if pad_token_id is not None:
model.config.pad_token_id = pad_token_id
else:
elif tokenizer is not None:
try:
tok = AutoTokenizer.from_pretrained(model_name_or_path)
model.config.pad_token_id = tok.pad_token_id
model.config.pad_token_id = tokenizer.pad_token_id
except Exception:
raise ValueError(
"Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
@@ -329,6 +345,15 @@
generation_config.save_pretrained(output)
maybe_save_preprocessors(model_name_or_path, output)

if convert_tokenizer and tokenizer is not None and is_openvino_tokenizers_available():
try:
export_tokenizer(tokenizer, output)
except Exception as exception:
logger.warning(
"Could not load tokenizer using specified model ID or path. OpenVINO tokenizer/detokenizer "
f"models won't be generated. Exception: {exception}"
)

if model.config.is_encoder_decoder and task.startswith("text-generation"):
raise ValueError(
f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report"
@@ -358,10 +383,14 @@ class StoreAttr(object):
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is not None:
tokenizer.save_pretrained(output.joinpath("tokenizer"))
if convert_tokenizer and is_openvino_tokenizers_available():
export_tokenizer(tokenizer, output)

tokenizer_2 = getattr(model, "tokenizer_2", None)
if tokenizer_2 is not None:
tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))
if convert_tokenizer and is_openvino_tokenizers_available():
export_tokenizer(tokenizer_2, output, suffix="_2")

model.save_config(output)

@@ -373,6 +402,7 @@
device=device,
compression_option=compression_option,
compression_ratio=compression_ratio,
stateful=stateful,
model_kwargs=model_kwargs,
)
