Added GPTQ/AWQ support with HF Transformers #1933

Merged · 10 commits · Mar 19, 2025
17 changes: 14 additions & 3 deletions tools/who_what_benchmark/tests/test_cli_text.py
@@ -27,6 +27,9 @@ def run_wwb(args):
base_model_path = os.path.join(tmp_dir, "opt125m")
target_model_path = os.path.join(tmp_dir, "opt125m_int8")

gptq_model_id = "ybelkada/opt-125m-gptq-4bit"
awq_model_id = "TitanML/tiny-mixtral-AWQ-4bit"


def setup_module():
    from optimum.exporters.openvino.convert import export_tokenizer
@@ -181,7 +184,15 @@ def test_text_language():
assert "马克" in data["prompts"].values[0]


-def test_text_hf_model():
+@pytest.mark.parametrize(
+    ("model_id"),
+    [
+        (model_id),
+        (gptq_model_id),
+        (awq_model_id),
+    ],
+)
+def test_text_hf_model(model_id):
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_file_name = os.path.join(temp_dir, "gt.csv")
        result = run_wwb(
@@ -191,15 +202,15 @@
"--gt-data",
temp_file_name,
"--num-samples",
"2",
"1",
"--device",
"CPU",
"--hf",
]
)
assert result.returncode == 0
data = pd.read_csv(temp_file_name)
assert len(data["prompts"].values) == 2
assert len(data["prompts"].values) == 1


def test_text_genai_model():
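For orientation, a minimal command-line sketch of what one new parametrized case exercises (a GPTQ checkpoint on CPU through the plain HF path). It assumes who_what_benchmark is installed and exposes the `wwb` entry point, and the `--base-model` flag name is an assumption, since the start of the argument list is collapsed in the diff above.

```python
# Hypothetical reproduction of the GPTQ test case outside pytest.
import os
import subprocess
import tempfile

gptq_model_id = "ybelkada/opt-125m-gptq-4bit"

with tempfile.TemporaryDirectory() as temp_dir:
    gt_file = os.path.join(temp_dir, "gt.csv")
    subprocess.run(
        [
            "wwb",
            "--base-model", gptq_model_id,  # assumed flag name
            "--gt-data", gt_file,           # reference answers are dumped here
            "--num-samples", "1",
            "--device", "CPU",
            "--hf",                         # use HF Transformers instead of OpenVINO GenAI
        ],
        check=True,
    )
```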
34 changes: 29 additions & 5 deletions tools/who_what_benchmark/whowhatbench/model_loaders.py
@@ -1,9 +1,12 @@
import logging
import json
import torch

from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, AutoModelForVision2Seq, AutoTokenizer
from diffusers import DiffusionPipeline, AutoPipelineForImage2Image, AutoPipelineForInpainting

from .utils import mock_torch_cuda_is_available


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -82,18 +85,39 @@ def load_text_llamacpp_pipeline(model_dir):
    return model


def load_text_hf_pipeline(model_id, device):
    model_kwargs = {}

    if not torch.cuda.is_available() or device.lower() == "cpu":
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
        is_gptq = False
        is_awq = False
        if getattr(config, "quantization_config", None):
            is_gptq = config.quantization_config["quant_method"] == "gptq"
            is_awq = config.quantization_config["quant_method"] == "awq"
        if is_gptq or is_awq:
            # infer in FP32
            model_kwargs["torch_dtype"] = torch.float32
        with mock_torch_cuda_is_available(is_gptq or is_awq):
            model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, device_map="cpu", **model_kwargs)
        if is_awq:
            model.is_awq = is_awq
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_id, trust_remote_code=True, device_map=device.lower(), **model_kwargs
        )
    model.eval()
    return model


def load_text_model(
    model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False, use_llamacpp=False, **kwargs,
):
    if use_hf:
        logger.info("Using HF Transformers API")
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id, trust_remote_code=True, device_map=device.lower()
-        )
-        model.eval()
+        model = load_text_hf_pipeline(model_id, device)
    elif use_genai:
        model = load_text_genai_pipeline(model_id, device, ov_config, **kwargs)

    elif use_llamacpp:
        logger.info("Using llama.cpp API")
        model = load_text_llamacpp_pipeline(model_id)
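A short usage sketch of the new code path, assuming the package is importable as `whowhatbench` and that the matching quantization backend (e.g. auto-gptq) is installed; the model id is the GPTQ checkpoint from the tests.

```python
# Minimal sketch: route a GPTQ checkpoint through the new HF loader on CPU.
from whowhatbench.model_loaders import load_text_model

model = load_text_model(
    "ybelkada/opt-125m-gptq-4bit",  # GPTQ model id used in the tests above
    device="CPU",
    use_hf=True,  # dispatches to load_text_hf_pipeline() instead of the OpenVINO paths
)
print(type(model).__name__)
```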
15 changes: 13 additions & 2 deletions tools/who_what_benchmark/whowhatbench/text_evaluator.py
@@ -6,6 +6,7 @@

from .registry import register_evaluator, BaseEvaluator
from .whowhat_metrics import TextDivergency, TextSimilarity
from .utils import patch_awq_for_inference

default_data = {
"en": {
Expand Down Expand Up @@ -187,17 +188,27 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):

    def _generate_data(self, model, gen_answer_fn=None, generation_config=None):
        def default_gen_answer(model, tokenizer, prompt, max_new_tokens, crop_question, use_chat_template=False):
            is_awq = getattr(model, "is_awq", None) is not None

            if use_chat_template:
                message = [{"role": "user", "content": prompt}]
                inputs = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
-                tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                if is_awq:
+                    with patch_awq_for_inference(is_awq):
+                        tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                else:
+                    tokens = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens)
                if crop_question:
                    tokens = tokens[:, inputs.shape[-1]:]
                res = self.tokenizer.decode(tokens[0], skip_special_tokens=True)
                return res
            else:
                inputs = self.tokenizer(prompt, return_tensors="pt")
-                tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                if is_awq:
+                    with patch_awq_for_inference(is_awq):
+                        tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
+                else:
+                    tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
                if crop_question:
                    tokens = tokens[:, inputs["input_ids"].shape[-1] :]
                return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]
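Since `patch_awq_for_inference(False)` is a no-op (it only swaps and restores `WQLinearMMFunction.forward` when its argument is truthy, see utils.py below), the `if is_awq:` branching around `generate` could also be collapsed; a sketch of the equivalent form, using the same names as in the diff:

```python
# Equivalent sketch: the context manager does nothing when is_awq is False,
# so generate can always be wrapped in it.
with patch_awq_for_inference(is_awq):
    tokens = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens)
```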
56 changes: 56 additions & 0 deletions tools/who_what_benchmark/whowhatbench/utils.py
@@ -1,5 +1,6 @@
from typing import Union, Tuple, List, Optional
import torch
from contextlib import contextmanager

from diffusers.utils import torch_utils

@@ -22,3 +23,58 @@ def new_randn_tensor(

def patch_diffusers():
    torch_utils.randn_tensor = new_randn_tensor


@contextmanager
def mock_torch_cuda_is_available(to_patch):
    original_is_available = torch.cuda.is_available
    if to_patch:
        torch.cuda.is_available = lambda: True
    try:
        yield
    finally:
        if to_patch:
            torch.cuda.is_available = original_is_available


@contextmanager
def patch_awq_for_inference(to_patch):
    orig_gemm_forward = None
    if to_patch:
        # patch GEMM module to allow inference without CUDA GPU
        from awq.modules.linear.gemm import WQLinearMMFunction
        from awq.utils.packing_utils import dequantize_gemm

        def new_forward(
            ctx,
            x,
            qweight,
            qzeros,
            scales,
            w_bit=4,
            group_size=128,
            bias=None,
            out_features=0,
        ):
            ctx.out_features = out_features

            out_shape = x.shape[:-1] + (out_features,)
            x = x.to(torch.float16)

            out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
            out = torch.matmul(x, out)

            out = out + bias if bias is not None else out
            out = out.reshape(out_shape)

            if len(out.shape) == 2:
                out = out.unsqueeze(0)
            return out

        orig_gemm_forward = WQLinearMMFunction.forward
        WQLinearMMFunction.forward = new_forward
    try:
        yield
    finally:
        if orig_gemm_forward is not None:
            WQLinearMMFunction.forward = orig_gemm_forward
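Taken together, a hedged end-to-end sketch of how these two helpers are meant to be used on a machine without a CUDA device; the AWQ model id is the one from the tests, and installing `autoawq` is assumed.

```python
# Sketch: load an AWQ checkpoint on a CUDA-less machine, then generate with the
# AWQ GEMM forward patched to dequantize and run on CPU.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from whowhatbench.utils import mock_torch_cuda_is_available, patch_awq_for_inference

model_id = "TitanML/tiny-mixtral-AWQ-4bit"  # AWQ checkpoint used in the tests
tokenizer = AutoTokenizer.from_pretrained(model_id)

# The AWQ loading path checks torch.cuda.is_available(), so CUDA availability is
# mocked only for the duration of from_pretrained().
with mock_torch_cuda_is_available(True):
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True
    )
model.eval()

inputs = tokenizer("Hello, my name is", return_tensors="pt")
# Route WQLinearMMFunction.forward through dequantize_gemm so the matmul runs on CPU.
with patch_awq_for_inference(True):
    tokens = model.generate(**inputs, do_sample=False, max_new_tokens=8)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))
```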