
Commit 29b2ac9

mvafin, eaidova, nikita-savelyevv, and echarlaix authored Dec 23, 2024
Support AWQ models (#1049)
* Support AWQ models
* Add tests
* Add dependencies
* Fix tests
* Enable AWQ export only if OpenVINO supports it
* Fix style (#2)
* Disable AWQ and GPTQ install for old torch (#3)
* Fix style
* Disable auto-gptq and autoawq install for old transformers testing
* Separate common quantized-model patching from GPTQ patching (#4)
* Disable Windows install (#5)
* Disable AWQ on Windows
* Skip logits check for quantized models (#6)
* Fix test after rebase
* Fix testing condition for OpenVINO 2024.6 and unpatch in case of failure
* Fix qwen2-vl tests (#1084)
* Skip private model loading test for external contributors (#1082)
* Fix reshaping unet if timestep is a 0d tensor (#1083)
* Disable KV cache compression for FP VLM (#1080)
* Add necessary packages in test_openvino_full
* Fix code style after rebase (#7)

---------

Co-authored-by: eaidova <ekaterina.aidova@intel.com>
Co-authored-by: Nikita Savelyev <nikita.savelyev@intel.com>
Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent ea6fa42 commit 29b2ac9
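
For context, a minimal usage sketch of what this commit enables: exporting an AWQ-quantized checkpoint to OpenVINO through the regular optimum-intel loading path. It is illustrative only and not part of the commit; it assumes OpenVINO >= 2024.6 and the autoawq package are installed, and the model id is the tiny AWQ fixture referenced in tests/openvino/utils_tests.py.

# Illustrative sketch, not part of this commit.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "TitanML/tiny-mixtral-AWQ-4bit"  # tiny AWQ test checkpoint used by the test suite
ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)  # AWQ export requires OpenVINO >= 2024.6
ov_model.save_pretrained("tiny-mixtral-awq-ov")

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, my name is", return_tensors="pt")
print(tokenizer.decode(ov_model.generate(**inputs, max_new_tokens=8)[0]))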

File tree

7 files changed: +265 -155 lines changed
 

‎.github/workflows/test_openvino.yml

+5
@@ -50,6 +50,11 @@ jobs:
         name: Install specific dependencies and versions required for older transformers
         run: |
           pip install transformers==${{ matrix.transformers-version }} accelerate==0.* peft==0.13.* diffusers==0.30.* transformers_stream_generator
+
+      - if: ${{ matrix.transformers-version == 'latest' && matrix.test-pattern == '*modeling*'}}
+        name: Install auto-gptq, autoawq
+        run: |
+          pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu
 
       - if: ${{ matrix.test-pattern == '*modeling*' }}
         name: Uninstall NNCF

‎.github/workflows/test_openvino_full.yml

+5
@@ -78,6 +78,11 @@ jobs:
         if: ${{ matrix.transformers-version != 'latest' }}
         run: pip install transformers==${{ matrix.transformers-version }}
 
+      - if: ${{ matrix.transformers-version == 'latest' && matrix.os != 'windows-2019' }}
+        name: Install auto-gptq, autoawq
+        run: |
+          pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu
+
       - name: Pip freeze
         run: pip freeze

‎.github/workflows/test_openvino_slow.yml

+5
@@ -49,6 +49,11 @@ jobs:
         name: Install specific dependencies and versions required for older transformers
         run: pip install transformers==${{ matrix.transformers-version }} accelerate==0.* peft==0.13.*, diffusers==0.30.* transformers_stream_generator
 
+      - if: ${{ matrix.transformers-version == 'latest' && matrix.os != 'windows-2019' }}
+        name: Install auto-gptq, autoawq
+        run: |
+          pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu
+
       - name: Pip freeze
         run: pip freeze

‎optimum/exporters/openvino/__main__.py

+150-139
@@ -232,6 +232,7 @@ def main_export(
         )
 
     do_gptq_patching = False
+    do_quant_patching = False
     custom_architecture = False
     patch_16bit = False
     loading_kwargs = model_loading_kwargs or {}
@@ -247,7 +248,11 @@
             trust_remote_code=trust_remote_code,
         )
         quantization_config = getattr(config, "quantization_config", None)
-        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
+        supported_quant_methods = ["gptq"]
+        if is_openvino_version(">=", "2024.6.0"):
+            supported_quant_methods.append("awq")
+        do_quant_patching = quantization_config and quantization_config["quant_method"] in supported_quant_methods
+        do_gptq_patching = do_quant_patching and quantization_config["quant_method"] == "gptq"
         model_type = config.model_type.replace("_", "-")
         if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
             custom_architecture = True
@@ -296,7 +301,6 @@
         if (
             dtype is None
             and framework == "pt"
-            and not do_gptq_patching
             and (
                 task.startswith("text-generation")
                 or getattr(config, "model_type", None) in MULTI_MODAL_TEXT_GENERATION_MODELS
@@ -315,28 +319,28 @@
                 patch_16bit = True
                 loading_kwargs["torch_dtype"] = dtype
         # Patch the modules to export of GPTQ models w/o GPU
-        if do_gptq_patching:
-            torch.set_default_dtype(torch.float32)
+        if do_quant_patching:
             orig_cuda_check = torch.cuda.is_available
             torch.cuda.is_available = lambda: True
 
-            from optimum.gptq import GPTQQuantizer
+            if do_gptq_patching:
+                from optimum.gptq import GPTQQuantizer
 
-            orig_post_init_model = GPTQQuantizer.post_init_model
+                orig_post_init_model = GPTQQuantizer.post_init_model
 
-            def post_init_model(self, model):
-                from auto_gptq import exllama_set_max_input_length
+                def post_init_model(self, model):
+                    from auto_gptq import exllama_set_max_input_length
 
-                class StoreAttr(object):
-                    pass
+                    class StoreAttr(object):
+                        pass
 
-                model.quantize_config = StoreAttr()
-                model.quantize_config.desc_act = self.desc_act
-                if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
-                    model = exllama_set_max_input_length(model, self.max_input_length)
-                return model
+                    model.quantize_config = StoreAttr()
+                    model.quantize_config.desc_act = self.desc_act
+                    if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+                        model = exllama_set_max_input_length(model, self.max_input_length)
+                    return model
 
-            GPTQQuantizer.post_init_model = post_init_model
+                GPTQQuantizer.post_init_model = post_init_model
     elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
         dtype = deduce_diffusers_dtype(
             model_name_or_path,
@@ -351,143 +355,150 @@ class StoreAttr(object):
             loading_kwargs["torch_dtype"] = dtype
             patch_16bit = True
 
-    if library_name == "open_clip":
-        model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
-    else:
-        model = TasksManager.get_model_from_task(
-            task,
-            model_name_or_path,
-            subfolder=subfolder,
-            revision=revision,
-            cache_dir=cache_dir,
-            token=token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            framework=framework,
-            device=device,
-            library_name=library_name,
-            **loading_kwargs,
-        )
+    try:
+        if library_name == "open_clip":
+            model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+        else:
+            model = TasksManager.get_model_from_task(
+                task,
+                model_name_or_path,
+                subfolder=subfolder,
+                revision=revision,
+                cache_dir=cache_dir,
+                token=token,
+                local_files_only=local_files_only,
+                force_download=force_download,
+                trust_remote_code=trust_remote_code,
+                framework=framework,
+                device=device,
+                library_name=library_name,
+                **loading_kwargs,
+            )
 
-    needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None
+        needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None
 
-    if needs_pad_token_id:
-        if pad_token_id is not None:
-            model.config.pad_token_id = pad_token_id
-        else:
-            tok = AutoTokenizer.from_pretrained(model_name_or_path)
-            pad_token_id = getattr(tok, "pad_token_id", None)
-            if pad_token_id is None:
-                raise ValueError(
-                    "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
-                )
-            model.config.pad_token_id = pad_token_id
+        if needs_pad_token_id:
+            if pad_token_id is not None:
+                model.config.pad_token_id = pad_token_id
+            else:
+                tok = AutoTokenizer.from_pretrained(model_name_or_path)
+                pad_token_id = getattr(tok, "pad_token_id", None)
+                if pad_token_id is None:
+                    raise ValueError(
+                        "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
+                    )
                model.config.pad_token_id = pad_token_id
 
-    if hasattr(model.config, "export_model_type"):
-        model_type = model.config.export_model_type.replace("_", "-")
-    else:
-        model_type = model.config.model_type.replace("_", "-")
-
-    if (
-        not custom_architecture
-        and library_name != "diffusers"
-        and task + "-with-past"
-        in TasksManager.get_supported_tasks_for_model_type(model_type, exporter="openvino", library_name=library_name)
-    ):
-        # Make -with-past the default if --task was not explicitely specified
-        if original_task == "auto":
-            task = task + "-with-past"
-        else:
-            logger.info(
-                f"The task `{task}` was manually specified, and past key values will not be reused in the decoding."
-                f" if needed, please pass `--task {task}-with-past` to export using the past key values."
-            )
+        if hasattr(model.config, "export_model_type"):
+            model_type = model.config.export_model_type.replace("_", "-")
+        else:
+            model_type = model.config.model_type.replace("_", "-")
+
+        if (
+            not custom_architecture
+            and library_name != "diffusers"
+            and task + "-with-past"
+            in TasksManager.get_supported_tasks_for_model_type(
+                model_type, exporter="openvino", library_name=library_name
+            )
+        ):
+            # Make -with-past the default if --task was not explicitely specified
+            if original_task == "auto":
+                task = task + "-with-past"
+            else:
+                logger.info(
+                    f"The task `{task}` was manually specified, and past key values will not be reused in the decoding."
+                    f" if needed, please pass `--task {task}-with-past` to export using the past key values."
                )
 
-    if original_task == "auto":
-        synonyms_for_task = sorted(TasksManager.synonyms_for_task(task))
-        if synonyms_for_task:
-            synonyms_for_task = ", ".join(synonyms_for_task)
-            possible_synonyms = f" (possible synonyms are: {synonyms_for_task})"
-        else:
-            possible_synonyms = ""
-        logger.info(f"Automatic task detection to {task}{possible_synonyms}.")
+        if original_task == "auto":
+            synonyms_for_task = sorted(TasksManager.synonyms_for_task(task))
+            if synonyms_for_task:
+                synonyms_for_task = ", ".join(synonyms_for_task)
+                possible_synonyms = f" (possible synonyms are: {synonyms_for_task})"
+            else:
+                possible_synonyms = ""
+            logger.info(f"Automatic task detection to {task}{possible_synonyms}.")
 
-    preprocessors = maybe_load_preprocessors(
-        model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
-    )
+        preprocessors = maybe_load_preprocessors(
+            model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
+        )
 
-    submodel_paths = export_from_model(
-        model=model,
-        output=output,
-        task=task,
-        ov_config=ov_config,
-        stateful=stateful,
-        model_kwargs=model_kwargs,
-        custom_export_configs=custom_export_configs,
-        fn_get_submodels=fn_get_submodels,
-        preprocessors=preprocessors,
-        device=device,
-        trust_remote_code=trust_remote_code,
-        patch_16bit_model=patch_16bit,
-        **kwargs_shapes,
-    )
+        submodel_paths = export_from_model(
+            model=model,
+            output=output,
+            task=task,
+            ov_config=ov_config,
+            stateful=stateful,
+            model_kwargs=model_kwargs,
+            custom_export_configs=custom_export_configs,
+            fn_get_submodels=fn_get_submodels,
+            preprocessors=preprocessors,
+            device=device,
+            trust_remote_code=trust_remote_code,
+            patch_16bit_model=patch_16bit,
+            **kwargs_shapes,
+        )
 
-    if convert_tokenizer:
-        maybe_convert_tokenizers(library_name, output, model, preprocessors, task=task)
-
-    clear_class_registry()
-    del model
-    gc.collect()
-
-    for submodel_path in submodel_paths:
-        submodel_path = Path(output) / submodel_path
-        submodel = core.read_model(submodel_path)
-
-        quantization_config = None
-        if ov_config is None:
-            num_parameters = 0
-            for op in submodel.get_ops():
-                if op.get_type_name() == "Constant" and op.get_element_type() in [Type.f16, Type.f32, Type.bf16]:
-                    num_parameters += reduce(operator.mul, op.shape, 1)
-                del op
-                if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
-                    if is_nncf_available():
-                        quantization_config = {"bits": 8, "sym": False}
-                        logger.info("The model weights will be quantized to int8_asym.")
-                    else:
-                        logger.warning(
-                            "The model will be converted with no weights quantization. Quantization of the weights to int8 "
-                            "requires nncf. Please install it with `pip install nncf`"
-                        )
-                    break
-        else:
-            quantization_config = ov_config.quantization_config
-        if quantization_config is None:
-            del submodel
-            gc.collect()
-            continue
-
-        if not is_nncf_available():
-            raise ImportError("Quantization of the weights requires nncf, please install it with `pip install nncf`")
-
-        from optimum.intel.openvino.quantization import _weight_only_quantization
-
-        _weight_only_quantization(submodel, quantization_config)
-        compressed_submodel_path = submodel_path.parent / f"{submodel_path.stem}_compressed.xml"
-        save_model(submodel, compressed_submodel_path, compress_to_fp16=False)
-        del submodel
-        gc.collect()
-
-        submodel_path.unlink()
-        submodel_path.with_suffix(".bin").unlink()
-        compressed_submodel_path.rename(submodel_path)
-        compressed_submodel_path.with_suffix(".bin").rename(submodel_path.with_suffix(".bin"))
-
-    # Unpatch modules after GPTQ export
-    if do_gptq_patching:
-        torch.cuda.is_available = orig_cuda_check
-        GPTQQuantizer.post_init_model = orig_post_init_model
+        if convert_tokenizer:
+            maybe_convert_tokenizers(library_name, output, model, preprocessors, task=task)
+
+        clear_class_registry()
+        del model
+        gc.collect()
+
+        for submodel_path in submodel_paths:
+            submodel_path = Path(output) / submodel_path
+            submodel = core.read_model(submodel_path)
+
+            quantization_config = None
+            if ov_config is None:
+                num_parameters = 0
+                for op in submodel.get_ops():
+                    if op.get_type_name() == "Constant" and op.get_element_type() in [Type.f16, Type.f32, Type.bf16]:
+                        num_parameters += reduce(operator.mul, op.shape, 1)
+                    del op
+                    if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
+                        if is_nncf_available():
+                            quantization_config = {"bits": 8, "sym": False}
+                            logger.info("The model weights will be quantized to int8_asym.")
+                        else:
+                            logger.warning(
+                                "The model will be converted with no weights quantization. Quantization of the weights to int8 "
+                                "requires nncf. Please install it with `pip install nncf`"
+                            )
+                        break
+            else:
+                quantization_config = ov_config.quantization_config
+            if quantization_config is None:
+                del submodel
+                gc.collect()
+                continue
+
+            if not is_nncf_available():
+                raise ImportError(
+                    "Quantization of the weights requires nncf, please install it with `pip install nncf`"
+                )
+
+            from optimum.intel.openvino.quantization import _weight_only_quantization
+
+            _weight_only_quantization(submodel, quantization_config)
+            compressed_submodel_path = submodel_path.parent / f"{submodel_path.stem}_compressed.xml"
+            save_model(submodel, compressed_submodel_path, compress_to_fp16=False)
+            del submodel
+            gc.collect()
+
+            submodel_path.unlink()
+            submodel_path.with_suffix(".bin").unlink()
+            compressed_submodel_path.rename(submodel_path)
+            compressed_submodel_path.with_suffix(".bin").rename(submodel_path.with_suffix(".bin"))
+
+    finally:
+        # Unpatch modules after quantized model export
+        if do_quant_patching:
+            torch.cuda.is_available = orig_cuda_check
+            if do_gptq_patching:
+                GPTQQuantizer.post_init_model = orig_post_init_model
 
 
 def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):
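
The heart of this change is a patch-and-restore pattern: before a GPTQ or AWQ checkpoint is loaded for export, torch.cuda.is_available is overridden to return True so that auto-gptq/autoawq accept a CPU-only environment, and the export now runs inside try/finally so the override (and the GPTQ post_init_model patch) is undone even when the export fails. A condensed, self-contained sketch of that pattern follows; the helper name is hypothetical and not part of the commit.

# Illustrative sketch of the patch/unpatch pattern used in main_export (helper name is hypothetical).
import torch

def run_with_fake_cuda(load_fn):
    orig_cuda_check = torch.cuda.is_available
    torch.cuda.is_available = lambda: True  # let auto-gptq/autoawq initialize without a GPU
    try:
        return load_fn()
    finally:
        torch.cuda.is_available = orig_cuda_check  # always restore, even if loading fails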

‎optimum/exporters/openvino/convert.py

+5-1
@@ -456,7 +456,11 @@ def ts_patched_forward(*args, **kwargs):
                from openvino.frontend.pytorch.patch_model import unpatch_model
 
                unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
-                model.to(torch.float32)
+                for m in model.modules():
+                    if any(p.dtype in [torch.float16, torch.bfloat16] for p in m.parameters(False)) or any(
+                        b.dtype in [torch.float16, torch.bfloat16] for b in m.buffers(False)
+                    ):
+                        m.float()
 
            return export_pytorch_via_onnx(
                model,
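
Rather than blanket-casting the whole model with model.to(torch.float32), the exporter now upcasts only the modules that actually own float16/bfloat16 parameters or buffers and leaves everything else untouched. A standalone sketch of the same idea, with an illustrative helper name that is not part of the commit:

# Illustrative sketch of the selective upcast above (helper name is hypothetical).
import torch
from torch import nn

def upcast_half_precision_modules(model: nn.Module) -> None:
    half_dtypes = (torch.float16, torch.bfloat16)
    for module in model.modules():
        owns_half = any(p.dtype in half_dtypes for p in module.parameters(recurse=False)) or any(
            b.dtype in half_dtypes for b in module.buffers(recurse=False)
        )
        if owns_half:
            module.float()  # cast this module's floating-point tensors to float32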

‎tests/openvino/test_modeling.py

+37-14
@@ -15,6 +15,7 @@
 import copy
 import gc
 import os
+import platform
 import tempfile
 import time
 import unittest
@@ -62,7 +63,7 @@
 )
 from transformers.onnx.utils import get_preprocessor
 from transformers.testing_utils import slow
-from utils_tests import MODEL_NAMES, TEST_IMAGE_URL
+from utils_tests import MODEL_NAMES, TEST_IMAGE_URL, mock_torch_cuda_is_available, patch_awq_for_inference
 
 from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
 from optimum.intel import (
@@ -872,7 +873,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "gpt_neo",
         "gpt_neox",
         "llama",
-        # "llama_gptq",
         "marian",
         "minicpm",
         "mistral",
@@ -917,6 +917,14 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "minicpm3",
     )
 
+    # gptq and awq install disabled for windows test environment
+    if platform.system() != "Windows":
+        SUPPORTED_ARCHITECTURES += ("opt_gptq",)
+
+    # autoawq install disabled for windows test environment
+    if is_openvino_version(">=", "2024.6.0") and platform.system() != "Windows":
+        SUPPORTED_ARCHITECTURES += ("mixtral_awq",)
+
     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = (
         "chatglm",
@@ -949,9 +957,6 @@ def test_compare_to_transformers(self, model_arch):
         if is_openvino_version("<", "2024.1"):
             not_stateful.extend(["llama", "gemma", "gpt_bigcode"])
 
-        if "gptq" in model_arch:
-            self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM")
-
         set_seed(SEED)
 
         model_kwargs = {}
@@ -978,20 +983,27 @@ def test_compare_to_transformers(self, model_arch):
         if is_stateful:
             self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)
 
+        if "awq" in model_arch or "gptq" in model_arch:
+            # infer in FP32
+            model_kwargs["torch_dtype"] = torch.float32
+
         set_seed(SEED)
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        with mock_torch_cuda_is_available("awq" in model_arch or "gptq" in model_arch):
+            transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
         if model_arch in ["qwen", "arctic", "glm4"]:
             transformers_model.to(torch.float32)
 
         with torch.no_grad():
-            transformers_outputs = transformers_model(**tokens)
+            with patch_awq_for_inference("awq" in model_arch):
+                transformers_outputs = transformers_model(**tokens)
 
         # Compare tensor outputs
         atol = 1e-3 if model_arch == "minicpm" else 1e-4
-        self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=atol))
+        # quantized models have different logits value range
+        if "awq" not in model_arch and "gptq" not in model_arch:
+            self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=atol))
 
         # Qwen tokenizer does not support padding
-
         if model_arch in ["qwen"]:
             return
@@ -1025,7 +1037,12 @@ def test_compare_to_transformers(self, model_arch):
             from transformers.cache_utils import DynamicCache
 
             additional_inputs = {"past_key_values": DynamicCache()}
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config, **additional_inputs)
+        with patch_awq_for_inference("awq" in model_arch):
+            transformers_outputs = transformers_model.generate(
+                **tokens, generation_config=gen_config, **additional_inputs
+            )
+        print(f"ov_outputs: {ov_outputs}")
+        print(f"transformers_outputs: {transformers_outputs}")
         self.assertTrue(
             torch.allclose(ov_outputs, transformers_outputs),
             "OV output {ov_outputs}\nTransformers output {transformers_output}",
@@ -1261,8 +1278,13 @@ def test_beam_search(self, model_arch):
         ov_model_stateless = OVModelForCausalLM.from_pretrained(
             model_id, export=True, use_cache=True, stateful=False, **model_kwargs
         )
+        if "awq" in model_arch or "gptq" in model_arch:
+            # infer in FP32
+            model_kwargs["torch_dtype"] = torch.float32
+
         set_seed(SEED)
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        with mock_torch_cuda_is_available("awq" in model_arch or "gptq" in model_arch):
+            transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
 
         if model_arch == "arctic":
             transformers_model.to(torch.float32)
@@ -1288,9 +1310,10 @@ def test_beam_search(self, model_arch):
 
         if model_arch == "gemma2":
             additional_inputs = {"past_key_values": DynamicCache()}
-        transformers_outputs = transformers_model.generate(
-            **tokens, generation_config=gen_config, **additional_inputs
-        )
+        with patch_awq_for_inference("awq" in model_arch):
+            transformers_outputs = transformers_model.generate(
+                **tokens, generation_config=gen_config, **additional_inputs
+            )
         set_seed(SEED)
         ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
         self.assertTrue(

‎tests/openvino/utils_tests.py

+58-1
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import unittest
+from contextlib import contextmanager
 from typing import Dict, List, Union
 
 import numpy as np
@@ -81,12 +82,12 @@
     "longt5": "hf-internal-testing/tiny-random-longt5",
     "llama": "HuggingFaceM4/tiny-random-LlamaForCausalLM",
     "llama_awq": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
-    "llama_gptq": "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ",
     "llava": "katuni4ka/tiny-random-llava",
     "llava_next": "katuni4ka/tiny-random-llava-next",
     "m2m_100": "hf-internal-testing/tiny-random-m2m_100",
     "opt": "hf-internal-testing/tiny-random-OPTModel",
     "opt125m": "facebook/opt-125m",
+    "opt_gptq": "ybelkada/opt-125m-gptq-4bit",
     "marian": "sshleifer/tiny-marian-en-de",
     "mbart": "hf-internal-testing/tiny-random-mbart",
     "minicpm": "katuni4ka/tiny-random-minicpm",
@@ -95,6 +96,7 @@
     "mistral": "echarlaix/tiny-random-mistral",
     "mistral-nemo": "katuni4ka/tiny-random-mistral-nemo",
     "mixtral": "TitanML/tiny-mixtral",
+    "mixtral_awq": "TitanML/tiny-mixtral-AWQ-4bit",
     "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
     "mobilenet_v1": "google/mobilenet_v1_0.75_192",
     "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
@@ -226,6 +228,61 @@ def get_num_quantized_nodes(model):
     return num_fake_quantize, num_weight_nodes
 
 
+@contextmanager
+def mock_torch_cuda_is_available(to_patch):
+    original_is_available = torch.cuda.is_available
+    if to_patch:
+        torch.cuda.is_available = lambda: True
+    try:
+        yield
+    finally:
+        if to_patch:
+            torch.cuda.is_available = original_is_available
+
+
+@contextmanager
+def patch_awq_for_inference(to_patch):
+    orig_gemm_forward = None
+    if to_patch:
+        # patch GEMM module to allow inference without CUDA GPU
+        from awq.modules.linear.gemm import WQLinearMMFunction
+        from awq.utils.packing_utils import dequantize_gemm
+
+        def new_forward(
+            ctx,
+            x,
+            qweight,
+            qzeros,
+            scales,
+            w_bit=4,
+            group_size=128,
+            bias=None,
+            out_features=0,
+        ):
+            ctx.out_features = out_features
+
+            out_shape = x.shape[:-1] + (out_features,)
+            x = x.to(torch.float16)
+
+            out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
+            out = torch.matmul(x, out)
+
+            out = out + bias if bias is not None else out
+            out = out.reshape(out_shape)
+
+            if len(out.shape) == 2:
+                out = out.unsqueeze(0)
+            return out
+
+        orig_gemm_forward = WQLinearMMFunction.forward
+        WQLinearMMFunction.forward = new_forward
+    try:
+        yield
+    finally:
+        if orig_gemm_forward is not None:
+            WQLinearMMFunction.forward = orig_gemm_forward
+
+
 def compare_num_quantized_nodes_per_model(
     test_case: unittest.TestCase,
     models: List[Union[ov.Model, OVBaseModel]],
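
The two helpers above are meant to be combined when the tests build the reference transformers model for a quantized checkpoint on a CPU-only runner: mock_torch_cuda_is_available lets autoawq/auto-gptq load the checkpoint, and patch_awq_for_inference swaps the AWQ GEMM forward for a dequantize-and-matmul fallback. An illustrative usage sketch, mirroring test_modeling.py with a hypothetical prompt:

# Illustrative usage of the helpers defined above; not part of the commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils_tests import MODEL_NAMES, mock_torch_cuda_is_available, patch_awq_for_inference

model_id = MODEL_NAMES["mixtral_awq"]
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer("This is a sample input", return_tensors="pt")

with mock_torch_cuda_is_available(True):  # autoawq/auto-gptq check for CUDA at load time
    reference_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

with torch.no_grad(), patch_awq_for_inference(True):  # CPU-friendly dequantize-and-matmul GEMM
    reference_logits = reference_model(**tokens).logits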
