
Commit 7000397

Added GPTQ/AWQ support with HF Transformers (#1933)
1 parent 19744f5 commit 7000397

File tree

5 files changed: +121 -10 lines

tools/who_what_benchmark/requirements.txt (+2)

@@ -8,3 +8,5 @@ numpy>=1.23.5
 tqdm>=4.66.1
 diffusers
 datasets<3.2.0
+auto-gptq; sys_platform != "darwin"
+autoawq<0.2.8; sys_platform != "darwin"
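The two new requirements use PEP 508 environment markers, so auto-gptq and autoawq are skipped on macOS installs, mirroring the sys.platform != 'darwin' guard added in the tests below. A minimal sketch of how such a marker evaluates, using the packaging library (illustrative only, not part of this commit):

# Sketch, not part of the commit: evaluate the environment marker that keeps
# auto-gptq/autoawq off macOS installs.
from packaging.markers import Marker

marker = Marker('sys_platform != "darwin"')
# True on Linux/Windows (packages are installed), False on macOS (both are skipped).
print(marker.evaluate())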

tools/who_what_benchmark/tests/test_cli_text.py (+21 -3)

@@ -6,6 +6,7 @@
 import pytest
 import logging
 import json
+import sys

 from transformers import AutoTokenizer
 from optimum.intel.openvino import OVModelForCausalLM, OVWeightQuantizationConfig

@@ -27,6 +28,9 @@ def run_wwb(args):
 base_model_path = os.path.join(tmp_dir, "opt125m")
 target_model_path = os.path.join(tmp_dir, "opt125m_int8")

+gptq_model_id = "ybelkada/opt-125m-gptq-4bit"
+awq_model_id = "TitanML/tiny-mixtral-AWQ-4bit"
+

 def setup_module():
     from optimum.exporters.openvino.convert import export_tokenizer

@@ -181,7 +185,21 @@ def test_text_language():
     assert "马克" in data["prompts"].values[0]


-def test_text_hf_model():
+hf_model_scope = [
+    (model_id),
+]
+if sys.platform != 'darwin':
+    hf_model_scope += [
+        (gptq_model_id),
+        (awq_model_id),
+    ]
+
+
+@pytest.mark.parametrize(
+    ("model_id"),
+    hf_model_scope,
+)
+def test_text_hf_model(model_id):
     with tempfile.TemporaryDirectory() as temp_dir:
         temp_file_name = os.path.join(temp_dir, "gt.csv")
         result = run_wwb(

@@ -191,15 +209,15 @@ def test_text_hf_model():
                 "--gt-data",
                 temp_file_name,
                 "--num-samples",
-                "2",
+                "1",
                 "--device",
                 "CPU",
                 "--hf",
             ]
         )
         assert result.returncode == 0
         data = pd.read_csv(temp_file_name)
-        assert len(data["prompts"].values) == 2
+        assert len(data["prompts"].values) == 1


 def test_text_genai_model():
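The parametrized test above drives the same path as a plain CLI run. A rough standalone equivalent for the GPTQ case, assuming who_what_benchmark is installed and exposes its wwb console entry point (illustration only, not part of the commit):

# Sketch, not part of the commit: reproduce what test_text_hf_model does for the
# GPTQ checkpoint by invoking the who_what_benchmark CLI directly.
import os
import subprocess
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    gt_csv = os.path.join(tmp, "gt.csv")
    result = subprocess.run(
        [
            "wwb",                      # assumed console entry point of who_what_benchmark
            "--base-model", "ybelkada/opt-125m-gptq-4bit",
            "--gt-data", gt_csv,
            "--num-samples", "1",
            "--device", "CPU",
            "--hf",                     # load the model through HF Transformers
        ],
        capture_output=True,
        text=True,
    )
    print(result.returncode)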

tools/who_what_benchmark/whowhatbench/model_loaders.py (+29 -5)

@@ -1,9 +1,12 @@
 import logging
 import json
+import torch

 from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, AutoModelForVision2Seq, AutoTokenizer
 from diffusers import DiffusionPipeline, AutoPipelineForImage2Image, AutoPipelineForInpainting

+from .utils import mock_torch_cuda_is_available
+

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

@@ -82,18 +85,39 @@ def load_text_llamacpp_pipeline(model_dir):
     return model


+def load_text_hf_pipeline(model_id, device):
+    model_kwargs = {}
+
+    if not torch.cuda.is_available() or device.lower() == "cpu":
+        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+        is_gptq = False
+        is_awq = False
+        if getattr(config, "quantization_config", None):
+            is_gptq = config.quantization_config["quant_method"] == "gptq"
+            is_awq = config.quantization_config["quant_method"] == "awq"
+        if is_gptq or is_awq:
+            # infer in FP32
+            model_kwargs["torch_dtype"] = torch.float32
+        with mock_torch_cuda_is_available(is_gptq or is_awq):
+            model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, device_map="cpu", **model_kwargs)
+        if is_awq:
+            model.is_awq = is_awq
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, trust_remote_code=True, device_map=device.lower(), **model_kwargs
+        )
+    model.eval()
+    return model
+
+
 def load_text_model(
     model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False, use_llamacpp=False, **kwargs,
 ):
     if use_hf:
         logger.info("Using HF Transformers API")
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id, trust_remote_code=True, device_map=device.lower()
-        )
-        model.eval()
+        model = load_text_hf_pipeline(model_id, device)
     elif use_genai:
         model = load_text_genai_pipeline(model_id, device, ov_config, **kwargs)
-
     elif use_llamacpp:
         logger.info("Using llama.cpp API")
         model = load_text_llamacpp_pipeline(model_id)
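A minimal usage sketch of the new HF loading path (the GPTQ model id comes from the test file above; not part of the commit):

# Sketch, not part of the commit: load a GPTQ checkpoint on CPU through the HF path.
from whowhatbench.model_loaders import load_text_model

# With use_hf=True and device="CPU", load_text_hf_pipeline mocks torch.cuda.is_available
# while from_pretrained runs, so auto-gptq/autoawq accept a CPU-only host, and it forces FP32.
model = load_text_model("ybelkada/opt-125m-gptq-4bit", device="CPU", use_hf=True)
print(type(model).__name__, getattr(model, "is_awq", None))  # is_awq stays unset for GPTQ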

tools/who_what_benchmark/whowhatbench/text_evaluator.py (+13 -2)

@@ -6,6 +6,7 @@

 from .registry import register_evaluator, BaseEvaluator
 from .whowhat_metrics import TextDivergency, TextSimilarity
+from .utils import patch_awq_for_inference

 default_data = {
     "en": {

@@ -187,17 +188,27 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):

     def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
         def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False):
+            is_awq = getattr(model, "is_awq", None) is not None
+
             if use_chat_template:
                 message = [{"role": "user", "content": prompt}]
                 inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
-                tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                if is_awq:
+                    with patch_awq_for_inference(is_awq):
+                        tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                else:
+                    tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
                 if crop_question:
                     tokens = tokens[:, inputs.shape[-1]:]
                 res = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
                 return res
             else:
                 inputs = self.tokenizer(prompt, return_tensors="pt")
-                tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                if is_awq:
+                    with patch_awq_for_inference(is_awq):
+                        tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                else:
+                    tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
                 if crop_question:
                     tokens = tokens[:, inputs["input_ids"].shape[-1]:]
                 return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
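Outside the evaluator, the same guarded generate() call can be reproduced directly. A sketch assuming the AWQ test checkpoint from above and a CPU-only host (not part of the commit):

# Sketch, not part of the commit: run generate() for an AWQ model with the CPU GEMM patch.
from transformers import AutoTokenizer
from whowhatbench.model_loaders import load_text_model
from whowhatbench.utils import patch_awq_for_inference

model_id = "TitanML/tiny-mixtral-AWQ-4bit"  # AWQ id used by the tests above
model = load_text_model(model_id, device="CPU", use_hf=True)  # sets model.is_awq
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Who is Mark Twain?", return_tensors="pt")
# patch_awq_for_inference(True) swaps WQLinearMMFunction.forward for a dequantize+matmul
# fallback for the duration of the call, then restores the original forward.
with patch_awq_for_inference(getattr(model, "is_awq", False)):
    tokens = model.generate(**inputs, do_sample=False, max_new_tokens=16)
print(tokenizer.batch_decode(tokens, skip_special_tokens=True)[0])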

tools/who_what_benchmark/whowhatbench/utils.py (+56)

@@ -1,5 +1,6 @@
 from typing import Union, Tuple, List, Optional
 import torch
+from contextlib import contextmanager

 from diffusers.utils import torch_utils


@@ -22,3 +23,58 @@ def new_randn_tensor(

 def patch_diffusers():
     torch_utils.randn_tensor = new_randn_tensor
+
+
+@contextmanager
+def mock_torch_cuda_is_available(to_patch):
+    original_is_available = torch.cuda.is_available
+    if to_patch:
+        torch.cuda.is_available = lambda: True
+    try:
+        yield
+    finally:
+        if to_patch:
+            torch.cuda.is_available = original_is_available
+
+
+@contextmanager
+def patch_awq_for_inference(to_patch):
+    orig_gemm_forward = None
+    if to_patch:
+        # patch GEMM module to allow inference without CUDA GPU
+        from awq.modules.linear.gemm import WQLinearMMFunction
+        from awq.utils.packing_utils import dequantize_gemm
+
+        def new_forward(
+            ctx,
+            x,
+            qweight,
+            qzeros,
+            scales,
+            w_bit=4,
+            group_size=128,
+            bias=None,
+            out_features=0,
+        ):
+            ctx.out_features = out_features
+
+            out_shape = x.shape[:-1] + (out_features,)
+            x = x.to(torch.float16)
+
+            out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
+            out = torch.matmul(x, out)
+
+            out = out + bias if bias is not None else out
+            out = out.reshape(out_shape)
+
+            if len(out.shape) == 2:
+                out = out.unsqueeze(0)
+            return out
+
+        orig_gemm_forward = WQLinearMMFunction.forward
+        WQLinearMMFunction.forward = new_forward
+    try:
+        yield
+    finally:
+        if orig_gemm_forward is not None:
+            WQLinearMMFunction.forward = orig_gemm_forward
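A quick illustration of the first context manager on a CPU-only host (not part of the commit):

# Sketch: mock_torch_cuda_is_available temporarily reports CUDA as available,
# which is enough for auto-gptq/autoawq load-time checks, and restores the
# original function on exit.
import torch
from whowhatbench.utils import mock_torch_cuda_is_available

print(torch.cuda.is_available())        # e.g. False on a CPU-only machine
with mock_torch_cuda_is_available(True):
    print(torch.cuda.is_available())    # True while the quantized model is constructed
print(torch.cuda.is_available())        # original behavior restored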
