Commit 45a2c1a

Authored by yiliu30 and Yi4Liu
Add qdq eval (#2121)
* add eval
  Change-Id: I7ce64ede965976dd79e979aace82f4d251cc6803
  Signed-off-by: Yi Liu <yiliu4@habana.ai>
* fix
  Change-Id: I72305d9d6ef6e3588bc8361f62baeeca06f42848
  Signed-off-by: Yi Liu <yiliu4@habana.ai>
* add float model
  Change-Id: Ia46444d77d349b1a976e6d7031d06bb621d6d7e4
  Signed-off-by: Yi Liu <yiliu4@habana.ai>
* add prompt
  Change-Id: Ie7b35f45d8f67a655dc9fb06eda824eb8a7f56c1
  Signed-off-by: Yi Liu <yiliu4@habana.ai>

---------

Signed-off-by: Yi Liu <yiliu4@habana.ai>
Co-authored-by: Yi Liu <yiliu4@habana.ai>
1 parent 54a88b7 commit 45a2c1a

File tree

2 files changed, +145 -2 lines changed


examples/ds/eval.py

+143
@@ -0,0 +1,143 @@
import os
import torch
import tqdm
from loguru import logger
import logging
import safetensors
from safetensors import safe_open
from safetensors.torch import save_file
import json

logging.basicConfig(level=logging.DEBUG)
torch.set_grad_enabled(False)

# CONSTANTS
SAFETENSORS = "safetensors"
WEIGHT_SCALE_NAME = "scale_weight"
INPUT_SCALE_NAME = "scale_input"
SCALE_DTYPE = torch.bfloat16
SCALE_FILE_NAME = f"scales.{SAFETENSORS}"
FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max
WEIGHT_BACKOFF = 0.5
QUANT_MODULE_TYPES = (torch.nn.Linear,)
SKIP_WEIGHT_LST = {
    "model.norm",
    "layernorm",
    "e_score_correction_bias",
    # "lm_head.weight",
    "embed_tokens",
    "mlp.gate.weight",  # mlp.gate is not linear
}
"""
# https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html?highlight=backoff#supported-json-config-file-options
Similarly, the maxabs value of a weight is scaled to weight_backoff*FP8_143_FULLSCALE. The default values are input_backoff=0.25 and weight_backoff=0.5.
"""
MODEL_STATE_DICT_MAPPING_FILENAME = "model.safetensors.index.json"


def skip_weight(weight_name):
    return any([skip_name in weight_name for skip_name in SKIP_WEIGHT_LST])


def get_cpu_mem_size_in_gb():
    import psutil

    mem = psutil.virtual_memory()
    return mem.available


from quant import quant_tensor


from torch import nn


# Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/1d044fd82b15f1cedb197a288e50cc96a2c27205/inference/model.py#L91-L108
class FP8QDQLinear(torch.nn.Linear):
    dtype = torch.bfloat16
    fp8_dtype = torch.float8_e4m3fn

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None):
        super().__init__(in_features, out_features, bias=bias)
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=FP8QDQLinear.fp8_dtype), requires_grad=True
        )
        self.scale_weight = nn.Parameter(torch.tensor(0, dtype=FP8QDQLinear.dtype), requires_grad=False)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def dequant_weight_online(self):
        fp8_weight = self.weight
        qdq_weight = fp8_weight.to(FP8QDQLinear.dtype) * self.scale_weight
        return qdq_weight

    def qdq_input(self, bf16_input: torch.Tensor):
        input_scale, input_fp8 = quant_tensor(bf16_input)
        qdq_input_bf16 = input_fp8.to(FP8QDQLinear.dtype) * input_scale
        return qdq_input_bf16

    @classmethod
    def create_from_linear(cls, linear: nn.Linear):
        qdq_linear = cls(linear.in_features, linear.out_features)
        qdq_linear.weight.data = linear.weight.data
        if linear.bias is not None:
            qdq_linear.bias = linear.bias
        return qdq_linear

    def forward(self, bf16_input: torch.Tensor) -> torch.Tensor:
        qdq_input = self.qdq_input(bf16_input)
        qdq_weight = self.dequant_weight_online()
        out = torch.nn.functional.linear(qdq_input, qdq_weight, self.bias)
        return out


def patch_lin():
    logger.warning("Patching torch.nn.Linear to FP8QDQLinear")
    torch.nn.Linear = FP8QDQLinear


def qdq_eval(model_path, not_patch_lin=False):
    import transformers
    from transformers.modeling_utils import no_init_weights
    from patch_for_ds import patch_transformers

    if not not_patch_lin:
        patch_lin()

    def _patch__initialize_weights(self, module):
        print(f"Skipping init_weights ")
        module._is_hf_initialized = True

    transformers.modeling_utils.PreTrainedModel._initialize_weights = _patch__initialize_weights
    patch_transformers()
    with no_init_weights():
        model = transformers.AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype="auto",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
    logger.info(f"Patched model: {model}")
    model.eval()
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
    prompt = "Hi, who"
    encode = tokenizer.encode(prompt, return_tensors="pt")
    with torch.no_grad():
        output_tokens = model.generate(encode, max_length=10)
    output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    logger.info(f"Prompt: {prompt}")
    logger.info(f"Output: {output}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--qmodel_path", type=str, required=True)
    parser.add_argument("--not_patch_lin", action="store_true", help="Measure float model")
    args = parser.parse_args()
    qdq_eval(args.qmodel_path, not_patch_lin=args.not_patch_lin)
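
Note: `quant_tensor` is imported from a local `quant` module that is not part of this commit. Judging from its call sites above (it returns a `(scale, fp8_tensor)` pair whose product reconstructs the bf16 tensor) and the backoff note quoted in the constants block, a minimal per-tensor maxabs sketch might look like the following; the function name, signature, and epsilon are assumptions, not the committed implementation.

import torch

FULL_RANGE = torch.finfo(torch.float8_e4m3fn).max  # FP8 E4M3 full scale
WEIGHT_BACKOFF = 0.5  # default weight backoff per the Habana docs quoted above


def quant_tensor_sketch(tensor: torch.Tensor, backoff: float = WEIGHT_BACKOFF):
    # Per-tensor maxabs scaling: maxabs is mapped to backoff * FULL_RANGE, so
    # dequantization is simply `tensor_fp8.to(torch.bfloat16) * scale`, matching
    # how qdq_input() and dequant_weight_online() consume the result.
    maxabs = tensor.abs().max().to(torch.float32).clamp(min=1e-12)
    scale = maxabs / (backoff * FULL_RANGE)
    tensor_fp8 = (tensor.to(torch.float32) / scale).clamp(-FULL_RANGE, FULL_RANGE).to(torch.float8_e4m3fn)
    return scale.to(torch.bfloat16), tensor_fp8

With the script above, the quantized checkpoint is evaluated with `python eval.py --qmodel_path <path-to-qdq-model>`; passing `--not_patch_lin` skips the `FP8QDQLinear` patch so the float model is measured instead, as the argparse help string indicates.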

examples/ds/patch_for_ds.py

+2 -2
@@ -1,5 +1,5 @@
 # ==--------------------------------------------------------------------------==
-# Patch for loading DS models
+# Patch for loading DS models from transformers
 from typing import Union, Optional
 import torch
 import os
@@ -101,7 +101,7 @@ def load_state_dict(
                 "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
             )

-
+# https://github.com/huggingface/transformers/pull/35493
 def set_initialized_submodules(model, state_dict_keys):
     """
     Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
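
For a quick sanity check of the QDQ path in isolation (outside `qdq_eval`, and not part of this commit), one could wrap a single bf16 `nn.Linear` with `FP8QDQLinear.create_from_linear`, fill in the FP8 weight and `scale_weight` that the full flow would normally load from the quantized checkpoint, and compare against the bf16 reference. The imports below assume eval.py and the local quant module are on the path, and the exact return shape of `quant_tensor` is inferred from its call sites, so treat this as a sketch rather than the author's workflow.

import torch
from torch import nn

from eval import FP8QDQLinear      # class defined in examples/ds/eval.py above
from quant import quant_tensor     # local helper, not included in this diff

lin = nn.Linear(16, 8, bias=True).to(torch.bfloat16)

# Wrap the float layer, then populate the FP8 weight and its scale by hand;
# in qdq_eval() these come from the quantized checkpoint instead.
qdq_lin = FP8QDQLinear.create_from_linear(lin)
weight_scale, weight_fp8 = quant_tensor(lin.weight.data)
qdq_lin.weight.data = weight_fp8
qdq_lin.scale_weight.data = weight_scale.to(torch.bfloat16)

x = torch.randn(2, 16, dtype=torch.bfloat16)
ref = lin(x)                      # bf16 reference
out = qdq_lin(x)                  # quantize-dequantize path
print((ref - out).abs().max())    # expect a small quantization error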
