
Commit 8bdd98c

update fp8 implementation, design and implement save&load
Signed-off-by: xinhe3 <xinhe3@habana.ai>
1 parent c4010bc commit 8bdd98c

File tree: 18 files changed, +2128 -235 lines

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py (+84 -40)

@@ -1,19 +1,28 @@
+import os
+os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "False"
+# os.environ["GRAPH_VISUALIZATION"] = "True"
+import shutil
+shutil.rmtree(".graph_dumps", ignore_errors=True)
 import argparse
 import time
 import json
 import re
 import torch
-import transformers
-import os
+import habana_frameworks.torch.hpex
+import torch.nn.functional as F
 import deepspeed
+import transformers
 from transformers import AutoModelForCausalLM, AutoTokenizer
-import habana_frameworks.torch.hpex
-from habana_frameworks.torch.hpu import memory_stats
+import habana_frameworks.torch.core as htcore
 import numpy as np
 import lm_eval
 import lm_eval.tasks
 import lm_eval.evaluator
+
+
 torch.set_grad_enabled(False)
+htcore.hpu_set_env()
+torch.device('hpu')


 def itrex_bootstrap_stderr(f, xs, iters):
@@ -57,16 +66,16 @@ def itrex_bootstrap_stderr(f, xs, iters):
                     help="Pad input ids to max length.")
 parser.add_argument("--calib_iters", default=100, type=int,
                     help="calibration iters.")
-parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], type=str, \
-                    choices=["winogrande", "copa", "piqa", "rte", "hellaswag", \
+parser.add_argument("--tasks", nargs='+', default=["hellaswag", "lambada_openai", "piqa", "winogrande"], \
+                    type=str, choices=["winogrande", "copa", "piqa", "rte", "hellaswag", \
                              "openbookqa", "lambada_openai", "lambada_standard", "wikitext"],
                     help="tasks list for accuracy validation")
 parser.add_argument("--limit", default=None, type=int,
                     help="the sample num of evaluation.")
 parser.add_argument("--max_new_tokens", default=100, type=int,
                     help="calibration iters.")
 parser.add_argument('--buckets', type=int, nargs='+', \
-                    help="Input length buckets to use with static_shapes", default=[129])
+                    help="Input length buckets to use with static_shapes", default=[256, 512])
 parser.add_argument("--local_rank",
                     type=int,
                     default=-1,
@@ -78,53 +87,48 @@ def itrex_bootstrap_stderr(f, xs, iters):
 world_size = int(os.getenv('WORLD_SIZE', '1'))
 local_rank = int(os.getenv('LOCAL_RANK', '-1'))

-#if local_rank == 0:
-#    os.environ["ENABLE_CONSOLE"] = 'True'
-#    os.environ["LOG_LEVEL_ALL"] = '0'

-# model
+model_dtype = torch.float32
 if re.search("llama", args.model.lower()) or re.search("bloom", args.model.lower()):
     from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-    torch.device('hpu')
     config = AutoConfig.from_pretrained(args.model)
     if world_size > 1:
-        model_dtype = torch.bfloat16
+        model_dtype = torch.float16
         deepspeed.init_distributed(dist_backend="hccl")
         with deepspeed.OnDevice(dtype=model_dtype, device="meta"):
             user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
         import tempfile
         checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
-        from utils import write_checkpoints_json
+        from optimum.habana.checkpoint_utils import write_checkpoints_json  # in optimum-habana
         write_checkpoints_json(
             args.model,
             local_rank,
             checkpoints_json,
             token=None,
         )
-    elif re.search("llama", args.model.lower()):
-        from models.modeling_llama import LlamaForCausalLM
-        user_model = LlamaForCausalLM.from_pretrained(
-            args.model,
-            device_map='hpu',
-        )
     else:
         user_model = AutoModelForCausalLM.from_pretrained(
             args.model,
             device_map='hpu',
+            torch_dtype=model_dtype,
         )
 elif re.search("chatglm", args.model.lower()):
     from models.modeling_chatglm import ChatGLMForConditionalGeneration
     user_model = ChatGLMForConditionalGeneration.from_pretrained(
         args.model,
         revision=args.revision,
         device_map='hpu',
+        torch_dtype=model_dtype,
     )
+    # print(user_model.transformer.output_layer.weight.dtype)  # always fp16
+    user_model.float()  # static fp8 needs float32 for the graph compiler
 else:
     user_model = AutoModelForCausalLM.from_pretrained(
         args.model,
         trust_remote_code=args.trust_remote_code,
         revision=args.revision,
         device_map='hpu',
+        torch_dtype=model_dtype,
     )

 # tokenizer
@@ -140,6 +144,8 @@ def itrex_bootstrap_stderr(f, xs, iters):
     trust_remote_code=args.trust_remote_code
 )

+tokenizer.pad_token = tokenizer.eos_token
+
 if world_size > 1:
     if re.search("llama", args.model.lower()):
         ds_inference_kwargs = {"dtype": model_dtype}
@@ -160,7 +166,7 @@ def itrex_bootstrap_stderr(f, xs, iters):

 if args.approach in ["dynamic", "static"]:
     print("device:", next(user_model.parameters()).device)
-    from neural_compressor.torch.quantization.config import FP8QConfig, get_default_fp8_qconfig
+    from neural_compressor.torch.quantization.config import FP8Config, get_default_fp8_config
     from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic
     from neural_compressor.torch.quantization import quantize
     if args.precision == "fp8_e4m3":
@@ -169,15 +175,15 @@ def itrex_bootstrap_stderr(f, xs, iters):
         dtype = torch.float8_e5m2
     if args.approach == "dynamic":
         #user_model = quantize_dynamic(user_model, dtype, inplace=True)
-        qconfig = FP8QConfig(weight_dtype=dtype, act_dtype=dtype, approach="dynamic")
+        qconfig = FP8Config(weight_dtype=dtype, act_dtype=dtype, approach="dynamic")
         if args.skip_lm_head:
-            fp32_config = FP8QConfig(weight_dtype=torch.float32, act_dtype=torch.float32)
+            fp32_config = FP8Config(weight_dtype=torch.float32, act_dtype=torch.float32)
             qconfig.set_local("lm_head", fp32_config)
         user_model = quantize_dynamic(user_model, qconfig, inplace=True)
     elif args.approach == "static":
-        qconfig = FP8QConfig(weight_dtype=dtype, act_dtype=dtype, approach="static")
+        qconfig = FP8Config(weight_dtype=dtype, act_dtype=dtype, approach="static")
         if args.skip_lm_head:
-            fp32_config = FP8QConfig(weight_dtype=torch.float32, act_dtype=torch.float32)
+            fp32_config = FP8Config(weight_dtype=torch.float32, act_dtype=torch.float32)
             qconfig.set_local("lm_head", fp32_config)
         # dataset
         from datasets import load_dataset
@@ -186,7 +192,13 @@ def itrex_bootstrap_stderr(f, xs, iters):
         calib_data = []
         for examples in calib_dataset:
             calib_data.append(
-                tokenizer(examples["text"], return_tensors="pt", max_length=128)
+                tokenizer(
+                    examples["text"],
+                    return_tensors="pt",
+                    max_length=64,
+                    padding="max_length",
+                    truncation=True
+                )
             )

         def calib_func(model):
@@ -199,6 +211,17 @@ def calib_func(model):
                 )

         user_model = quantize(user_model, qconfig, calib_func, inplace=True)
+        # replace torch.matmul and torch.bmm by injection
+        def replace_torch_mm_bmm():
+            from neural_compressor.torch.amp.fp8.functions import fp8_matmul
+            torch.matmul = fp8_matmul
+            torch.bmm = fp8_matmul
+
+        replace_torch_mm_bmm()
+        # It enables weights constant folding
+        from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const
+        _mark_params_as_const(user_model)  # can reduce memory allocated and speed up
+        _check_params_as_const(user_model)
     print(user_model, flush=True)

 if args.to_graph:
@@ -244,6 +267,16 @@ def calib_func(model):

 if args.accuracy:

+    def save_to_excel(dict):
+        import pandas as pd
+        df_new = pd.DataFrame(dict)
+        try:
+            df_existing = pd.read_excel('output.xlsx')
+        except FileNotFoundError:
+            df_existing = pd.DataFrame()
+        df_combined = pd.concat([df_existing, df_new], axis=0, ignore_index=True)
+        df_combined.to_excel('output.xlsx', index=False, engine='openpyxl', header=True)
+
     class HabanaModelAdapter(lm_eval.base.BaseLM):
         def __init__(self, tokenizer, model, args, options):
             super().__init__()
@@ -292,16 +325,14 @@ def find_bucket(self, length):
             return [b for b in self.buckets if b >= length][0]

         def _model_call(self, inps):
-            #print(inps.shape)
             seq_length = inps.shape[-1]
+            padding_length = 0
             bucket_length = self.find_bucket(seq_length)
             padding_length = bucket_length - seq_length
-            if True:
-                import torch.nn.functional as F
-                inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
+            inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
+            logits = self.model(inps.to(self._device))["logits"].cpu()

-            logits = self.model(inps.to(self._device))['logits']
-            if True and padding_length > 0:
+            if padding_length > 0:
                 logits = logits[:, :-padding_length, :]
             logits = logits.to(torch.float32)
             return logits
@@ -333,18 +364,31 @@ def _model_call(self, inps):


     dumped = json.dumps(results, indent=2)
+    accu_dict = {}
+    case_name = args.approach + "-" + args.precision
     for task_name in args.tasks:
         if task_name == "wikitext":
             print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]), flush=True)
+            accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["word_perplexity"]]
         else:
             print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]), flush=True)
+            accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["acc"]]
+    save_to_excel(accu_dict)
+

 # show memory usage
-mem_stats = memory_stats()
-mem_dict = {
-    "memory_allocated (GB)": np.round(mem_stats["InUse"] / 1024**3, 2),
-    "max_memory_allocated (GB)": np.round(mem_stats["MaxInUse"] / 1024**3, 2),
-    "total_memory_available (GB)": np.round(mem_stats["Limit"] / 1024**3, 2),
-}
-for k, v in mem_dict.items():
-    print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v))
+def show_msg():
+    import numpy as np
+    import glob
+    from habana_frameworks.torch.hpu import memory_stats
+    print("Number of HPU graphs:", len(glob.glob(".graph_dumps/*PreGraph*")))
+    mem_stats = memory_stats()
+    mem_dict = {
+        "memory_allocated (GB)": np.round(mem_stats["InUse"] / 1024**3, 2),
+        "max_memory_allocated (GB)": np.round(mem_stats["MaxInUse"] / 1024**3, 2),
+        "total_memory_available (GB)": np.round(mem_stats["Limit"] / 1024**3, 2),
+    }
+    for k, v in mem_dict.items():
+        print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v))
+
+show_msg()
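
Pulling the static-quantization pieces of this diff together, the flow looks roughly like the sketch below. It is a condensed illustration, not the script itself: `user_model` and `calib_data` stand in for the model and the tokenized calibration samples the script prepares, and it assumes a Gaudi/HPU environment with the same packages imported as in run_llm.py.

    # Minimal sketch of the static FP8 path wired up above (placeholders: user_model, calib_data).
    import torch
    from neural_compressor.torch.quantization.config import FP8Config
    from neural_compressor.torch.quantization import quantize

    dtype = torch.float8_e5m2  # --precision fp8_e5m2; the script also supports fp8_e4m3
    qconfig = FP8Config(weight_dtype=dtype, act_dtype=dtype, approach="static")
    fp32_config = FP8Config(weight_dtype=torch.float32, act_dtype=torch.float32)
    qconfig.set_local("lm_head", fp32_config)  # keep the LM head in fp32 (--skip_lm_head)

    def calib_func(model):
        # run the tokenized calibration batches on HPU (padded to max_length=64 in the script)
        for inputs in calib_data:  # hypothetical iterable of tokenizer outputs
            model(inputs.input_ids.to("hpu"))

    user_model = quantize(user_model, qconfig, calib_func, inplace=True)

    # route torch.matmul / torch.bmm through the FP8 kernel and fold weight constants
    from neural_compressor.torch.amp.fp8.functions import fp8_matmul
    torch.matmul = fp8_matmul
    torch.bmm = fp8_matmul
    from habana_frameworks.torch.core.quantization import _mark_params_as_const, _check_params_as_const
    _mark_params_as_const(user_model)
    _check_params_as_const(user_model)

Note that assigning `fp8_matmul` to `torch.matmul` and `torch.bmm` is process-global monkey patching: every later matmul in the process, including the lm_eval harness calls, goes through the FP8 kernel.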

neural_compressor/torch/algorithms/habana_fp8/__init__.py (+1)

@@ -13,3 +13,4 @@
 # limitations under the License.

 from .fp8_quant import quantize_dynamic, quantize, white_list
+from .save_load import save, load
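
The commit title mentions a save & load design, but this excerpt only shows `save` and `load` being re-exported from the new `save_load` module; their signatures are not part of the diff. A round trip might look like the sketch below, where the argument names and the on-disk layout are assumptions, not the confirmed API.

    # Hypothetical usage of the newly exported helpers; arguments and checkpoint
    # format are assumptions, since save_load.py is not shown in this diff.
    from neural_compressor.torch.algorithms.habana_fp8 import save, load

    save(user_model, "saved_results")   # assumed: persist the FP8-quantized model
    user_model = load("saved_results")  # assumed: restore it for later inference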
