Skip to content

Commit baa36a8

Browse files
committed
Merge branch 'main' into add-eval
2 parents 30072e3 + 72b0630 commit baa36a8

20 files changed

+686
-101
lines changed

.github/workflows/test_openvino.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
python -m pip install --upgrade pip
3333
# install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU
3434
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
35-
pip install .[openvino,openvino-tokenizers,nncf,tests,diffusers]
35+
pip install .[openvino,openvino-tokenizers,tests,diffusers] onnxruntime
3636
- name: Test with Pytest
3737
run: |
3838
pytest tests/openvino/ --ignore test_modeling_basic

docs/source/optimization_ov.mdx

+11-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,17 @@ from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig
8282

8383
model = OVModelForCausalLM.from_pretrained(
8484
model_id,
85-
export=True,
85+
quantization_config=OVWeightQuantizationConfig(bits=4),
86+
)
87+
```
88+
89+
You can tune quantization parameters to achieve a better performance accuracy trade-off as follows:
90+
91+
```python
92+
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig
93+
94+
model = OVModelForCausalLM.from_pretrained(
95+
model_id,
8696
quantization_config=OVWeightQuantizationConfig(bits=4, sym=False, ratio=0.8, dataset="ptb"),
8797
)
8898
```

optimum/commands/export/openvino.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,12 @@ def run(self):
157157
)
158158
self.args.weight_format = "int8"
159159

160-
weight_format = self.args.weight_format or "fp32"
161-
162-
ov_config = None
163-
if weight_format in {"fp16", "fp32"}:
164-
ov_config = OVConfig(dtype=weight_format)
160+
if self.args.weight_format is None:
161+
ov_config = None
162+
elif self.args.weight_format in {"fp16", "fp32"}:
163+
ov_config = OVConfig(dtype=self.args.weight_format)
165164
else:
166-
is_int8 = weight_format == "int8"
165+
is_int8 = self.args.weight_format == "int8"
167166

168167
# For int4 quantization if not parameter is provided, then use the default config if exist
169168
if (
@@ -182,12 +181,12 @@ def run(self):
182181
"group_size": -1 if is_int8 else self.args.group_size,
183182
}
184183

185-
if weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
184+
if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
186185
logger.warning(
187-
f"--weight-format {weight_format} is deprecated, possible choices are fp32, fp16, int8, int4"
186+
f"--weight-format {self.args.weight_format} is deprecated, possible choices are fp32, fp16, int8, int4"
188187
)
189-
quantization_config["sym"] = "asym" not in weight_format
190-
quantization_config["group_size"] = 128 if "128" in weight_format else 64
188+
quantization_config["sym"] = "asym" not in self.args.weight_format
189+
quantization_config["group_size"] = 128 if "128" in self.args.weight_format else 64
191190
ov_config = OVConfig(quantization_config=quantization_config)
192191

193192
# TODO : add input shapes

optimum/exporters/ipex/__init__.py

Whitespace-only changes.
+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from transformers.models.llama.modeling_llama import (
16+
LlamaAttention,
17+
LlamaDecoderLayer,
18+
LlamaForCausalLM,
19+
LlamaModel,
20+
LlamaRMSNorm,
21+
)
22+
23+
from optimum.intel.utils.import_utils import is_ipex_version
24+
25+
from .modeling_utils import (
26+
_IPEXLlamaDecoderLayerRef,
27+
_llama_attn_forward,
28+
_llama_layer_norm_forward,
29+
_llama_model_forward,
30+
)
31+
32+
33+
_IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",)
34+
_IPEX_EXPORTED_TASK = ("text-generation",)
35+
36+
37+
def convert_func(m, func_name, new_function):
38+
bound_method = new_function.__get__(m, m.__class__)
39+
setattr(m, func_name, bound_method)
40+
41+
42+
def convert_functions(m, target_m, new_function_name, new_function):
43+
for _, sub_m in m.named_children():
44+
if isinstance(sub_m, target_m):
45+
convert_func(sub_m, new_function_name, new_function)
46+
convert_functions(sub_m, target_m, new_function_name, new_function)
47+
48+
49+
def convert_class(m, target_m, new_class, config, distributed=False):
50+
for name, sub_m in m.named_children():
51+
if isinstance(sub_m, target_m):
52+
new_m = new_class(sub_m, config, distributed)
53+
setattr(m, name, new_m)
54+
convert_class(sub_m, target_m, new_class, config, distributed)
55+
56+
57+
def patch_op(m, target_m, new_op_name, new_op):
58+
for name, sub_m in m.named_children():
59+
if isinstance(sub_m, target_m):
60+
setattr(sub_m, new_op_name, new_op)
61+
patch_op(sub_m, target_m, new_op_name, new_op)
62+
63+
64+
def _patch_llama_model(model):
65+
if is_ipex_version("<", "2.5.0"):
66+
raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache")
67+
68+
from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding
69+
70+
ipex_rope = RotaryEmbedding(
71+
model.config.max_position_embeddings,
72+
model.config.hidden_size // model.config.num_attention_heads,
73+
model.config.rope_theta,
74+
model.config.architectures[0],
75+
)
76+
ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
77+
patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
78+
patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)
79+
80+
convert_functions(model, LlamaModel, "forward", _llama_model_forward)
81+
convert_functions(model, LlamaAttention, "forward", _llama_attn_forward)
82+
convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)
83+
84+
convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
85+
return model
86+
87+
88+
def _patch_model(model):
89+
if isinstance(model, LlamaForCausalLM):
90+
model = _patch_llama_model(model)
91+
return model

0 commit comments

Comments
 (0)