
Commit f812e67

Authored Feb 27, 2024
update fp8 implementation, design and implement save&load (#1605)
Signed-off-by: xinhe3 <xinhe3@habana.ai>
1 parent a8d81ca commit f812e67
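
For orientation, below is a minimal sketch of the end-to-end flow the updated run_llm.py wires together. It reuses only the calls visible in the diff that follows (FP8Config, quantize, user_model.save, load); the model id, the calibration texts, and the simplified calibration loop are placeholders rather than part of the commit, and the signatures are those of neural_compressor.torch at this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.quantization import quantize, load
from neural_compressor.torch.quantization.config import FP8Config

model_name = "your/model-id"  # placeholder; run_llm.py takes this from --model
user_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="hpu", torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def calib_func(model):
    # stand-in for the script's calibration loop: run a few short prompts through the model
    for text in ["calibration sample one", "calibration sample two"]:
        inputs = tokenizer(text, return_tensors="pt", max_length=64, padding="max_length", truncation=True)
        model(**inputs)

# static FP8 quantization; keep lm_head in fp32, as the script's --skip_lm_head path does
qconfig = FP8Config(w_dtype="fp8_e4m3", act_dtype="fp8_e4m3", approach="static")
qconfig.set_local("lm_head", FP8Config(w_dtype="fp32", act_dtype="fp32"))

user_model = quantize(user_model, qconfig, calib_func, inplace=True)  # calibrate and convert
user_model.save("saved_results")  # --save: persist the quantized model

# --load: restore the quantized state into a (possibly freshly built) model
user_model = load(user_model, "saved_results")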

31 files changed (+2676 −510 lines)


examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py → examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py (renamed)

+146 −93 lines
@@ -1,30 +1,33 @@
+import os
+os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "False"
+os.environ["USE_GAUDI2_SCALE"] = "True"
+os.environ.pop("USE_GAUDI2_SCALE") # gaudi scale work
+# os.environ["GRAPH_VISUALIZATION"] = "True"
+# import shutil
+# shutil.rmtree(".graph_dumps", ignore_errors=True)
 import argparse
 import time
 import json
 import re
 import torch
-import transformers
-import os
-import deepspeed
-from transformers import AutoModelForCausalLM, AutoTokenizer
 import habana_frameworks.torch.hpex
-from habana_frameworks.torch.hpu import memory_stats
+import torch.nn.functional as F
+import deepspeed
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+import habana_frameworks.torch.core as htcore
 import numpy as np
 import lm_eval
 import lm_eval.tasks
 import lm_eval.evaluator
-torch.set_grad_enabled(False)
+from accelerate import init_empty_weights
+from utils import itrex_bootstrap_stderr, show_msg, save_to_excel
 
 
-def itrex_bootstrap_stderr(f, xs, iters):
-    from lm_eval.metrics import _bootstrap_internal, sample_stddev
-    res = []
-    chunk_size = min(1000, iters)
-    it = _bootstrap_internal(f, chunk_size)
-    for i in range(iters // chunk_size):
-        bootstrap = it((i, xs))
-        res.extend(bootstrap)
-    return sample_stddev(res)
+torch.set_grad_enabled(False)
+htcore.hpu_set_env()
+torch.device('hpu')
+
 
 # to avoid out-of-memory caused by Popen for large language models.
 lm_eval.metrics.bootstrap_stderr = itrex_bootstrap_stderr
@@ -51,22 +54,26 @@ def itrex_bootstrap_stderr(f, xs, iters):
 parser.add_argument("--accuracy", action="store_true")
 parser.add_argument("--performance", action="store_true")
 parser.add_argument("--generate", action="store_true")
+parser.add_argument("--skip_fp8_mm", action="store_true")
+parser.add_argument("--dump_to_excel", action="store_true")
+parser.add_argument("--save", action="store_true")
+parser.add_argument("--load", action="store_true")
 parser.add_argument("--batch_size", default=1, type=int,
                     help="For accuracy measurement only.")
 parser.add_argument("--pad_max_length", default=512, type=int,
                     help="Pad input ids to max length.")
 parser.add_argument("--calib_iters", default=100, type=int,
                     help="calibration iters.")
-parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], type=str, \
-                    choices=["winogrande", "copa", "piqa", "rte", "hellaswag", \
-                             "openbookqa", "lambada_openai", "lambada_standard", "wikitext"],
+parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], \
+                    type=str, choices=["hellaswag", "lambada_openai", "piqa", "winogrande", "copa",
+                                       "rte", "openbookqa", "lambada_standard", "wikitext"],
                     help="tasks list for accuracy validation")
 parser.add_argument("--limit", default=None, type=int,
                     help="the sample num of evaluation.")
 parser.add_argument("--max_new_tokens", default=100, type=int,
                     help="calibration iters.")
 parser.add_argument('--buckets', type=int, nargs='+', \
-                    help="Input length buckets to use with static_shapes", default=[129])
+                    help="Input length buckets to use with static_shapes", default=[256, 512])
 parser.add_argument("--local_rank",
                     type=int,
                     default=-1,
@@ -78,67 +85,65 @@ def itrex_bootstrap_stderr(f, xs, iters):
 world_size = int(os.getenv('WORLD_SIZE', '1'))
 local_rank = int(os.getenv('LOCAL_RANK', '-1'))
 
-#if local_rank == 0:
-# os.environ["ENABLE_CONSOLE"] = 'True'
-# os.environ["LOG_LEVEL_ALL"] = '0'
 
-# model
+model_dtype = torch.float32
 if re.search("llama", args.model.lower()) or re.search("bloom", args.model.lower()):
     from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-    torch.device('hpu')
-    config = AutoConfig.from_pretrained(args.model)
     if world_size > 1:
-        model_dtype = torch.bfloat16
+        config = AutoConfig.from_pretrained(args.model)
+        model_dtype = torch.bfloat16 # RuntimeErrorCastToFp8V2 input must be of float or bfloat16 dtype
         deepspeed.init_distributed(dist_backend="hccl")
         with deepspeed.OnDevice(dtype=model_dtype, device="meta"):
             user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
         import tempfile
         checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
-        from utils import write_checkpoints_json
+        from optimum.habana.checkpoint_utils import write_checkpoints_json # in optimum-habana
         write_checkpoints_json(
-            args.model,
-            local_rank,
-            checkpoints_json,
-            token=None,
+            args.model,
+            local_rank,
+            checkpoints_json,
+            token=None,
         )
-    elif re.search("llama", args.model.lower()):
-        from models.modeling_llama import LlamaForCausalLM
-        user_model = LlamaForCausalLM.from_pretrained(
+    else:
+        if args.load:
+            config = AutoConfig.from_pretrained(args.model)
+            with init_empty_weights():
+                user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
+        else:
+            user_model = AutoModelForCausalLM.from_pretrained(
+                args.model,
+                device_map='hpu',
+                torch_dtype=model_dtype,
+            )
+elif re.search("chatglm", args.model.lower()):
+    if args.load:
+        config = AutoConfig.from_pretrained(args.model, torch_dtype=model_dtype)
+        with init_empty_weights():
+            user_model = AutoModelForCausalLM.from_config(config)
+    else:
+        from models.modeling_chatglm import ChatGLMForConditionalGeneration
+        user_model = ChatGLMForConditionalGeneration.from_pretrained(
             args.model,
+            revision=args.revision,
             device_map='hpu',
+            torch_dtype=model_dtype,
         )
+    # print(user_model.transformer.output_layer.weight.dtype) # always fp16
+    user_model.float() # static fp8 need float32 for graph compiler
+else:
+    if args.load:
+        config = AutoConfig.from_pretrained(args.model)
+        with init_empty_weights():
+            user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
     else:
         user_model = AutoModelForCausalLM.from_pretrained(
             args.model,
+            trust_remote_code=args.trust_remote_code,
+            revision=args.revision,
             device_map='hpu',
+            torch_dtype=model_dtype,
         )
-elif re.search("chatglm", args.model.lower()):
-    from models.modeling_chatglm import ChatGLMForConditionalGeneration
-    user_model = ChatGLMForConditionalGeneration.from_pretrained(
-        args.model,
-        revision=args.revision,
-        device_map='hpu',
-    )
-else:
-    user_model = AutoModelForCausalLM.from_pretrained(
-        args.model,
-        trust_remote_code=args.trust_remote_code,
-        revision=args.revision,
-        device_map='hpu',
-    )
 
-# tokenizer
-if re.search("baichuan", args.model.lower()):
-    from models.tokenization_baichuan import BaichuanTokenizer
-    tokenizer = BaichuanTokenizer.from_pretrained(
-        args.model,
-        trust_remote_code=args.trust_remote_code
-    )
-else:
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.model,
-        trust_remote_code=args.trust_remote_code
-    )
 
 if world_size > 1:
     if re.search("llama", args.model.lower()):
@@ -148,36 +153,44 @@ def itrex_bootstrap_stderr(f, xs, iters):
         from transformers.models.llama.modeling_llama import LlamaDecoderLayer
         ds_inference_kwargs["injection_policy"] = {LlamaDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")}
         ds_inference_kwargs["checkpoint"] = checkpoints_json.name
-
         ds_model = deepspeed.init_inference(user_model, **ds_inference_kwargs)
     else:
         ds_model = deepspeed.init_inference(user_model,
                                             mp_size=world_size,
                                             replace_with_kernel_inject=False)
     user_model = ds_model.module
 
+
+# tokenizer
+if re.search("baichuan", args.model.lower()):
+    from models.tokenization_baichuan import BaichuanTokenizer
+    tokenizer = BaichuanTokenizer.from_pretrained(
+        args.model,
+        trust_remote_code=args.trust_remote_code
+    )
+else:
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model,
+        trust_remote_code=args.trust_remote_code
+    )
+
+
 user_model.eval()
 
-if args.approach in ["dynamic", "static"]:
+
+### dynamic & static quantization ###
+if args.approach in ["dynamic", "static"] and not args.load:
     print("device:", next(user_model.parameters()).device)
-    from neural_compressor.torch.quantization.config import FP8QConfig, get_default_fp8_qconfig
-    from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic
+    from neural_compressor.torch.quantization.config import FP8Config, get_default_fp8_config
     from neural_compressor.torch.quantization import quantize
-    if args.precision == "fp8_e4m3":
-        dtype = torch.float8_e4m3fn
-    else:
-        dtype = torch.float8_e5m2
+    dtype = args.precision
     if args.approach == "dynamic":
-        #user_model = quantize_dynamic(user_model, dtype, inplace=True)
-        qconfig = FP8QConfig(weight_dtype=dtype, act_dtype=dtype, approach="dynamic")
-        if args.skip_lm_head:
-            fp32_config = FP8QConfig(weight_dtype=torch.float32, act_dtype=torch.float32)
-            qconfig.set_local("lm_head", fp32_config)
-        user_model = quantize_dynamic(user_model, qconfig, inplace=True)
+        from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic
+        user_model = quantize_dynamic(user_model, dtype, inplace=True)
     elif args.approach == "static":
-        qconfig = FP8QConfig(weight_dtype=dtype, act_dtype=dtype, approach="static")
+        qconfig = FP8Config(w_dtype=dtype, act_dtype=dtype, approach="static")
         if args.skip_lm_head:
-            fp32_config = FP8QConfig(weight_dtype=torch.float32, act_dtype=torch.float32)
+            fp32_config = FP8Config(w_dtype="fp32", act_dtype="fp32")
             qconfig.set_local("lm_head", fp32_config)
         # dataset
         from datasets import load_dataset
@@ -186,7 +199,13 @@ def itrex_bootstrap_stderr(f, xs, iters):
         calib_data = []
         for examples in calib_dataset:
             calib_data.append(
-                tokenizer(examples["text"], return_tensors="pt", max_length=128)
+                tokenizer(
+                    examples["text"],
+                    return_tensors="pt",
+                    max_length=64,
+                    padding="max_length",
+                    truncation=True
+                )
             )
 
         def calib_func(model):
@@ -199,12 +218,46 @@ def calib_func(model):
                 )
 
         user_model = quantize(user_model, qconfig, calib_func, inplace=True)
-        print(user_model, flush=True)
+        # saving
+        if args.save and local_rank in [-1, 0]:
+            user_model.save("saved_results")
+
+
+if args.load:
+    from neural_compressor.torch.quantization import load
+    user_model = load(user_model, "saved_results")
+
+
+if args.approach in ["dynamic", "static"] or args.load:
+    # It enables weights constant folding
+    from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const
+    _mark_params_as_const(user_model) # can reduce memory allocated and speed up
+    _check_params_as_const(user_model)
+
+
+
+# If torch.matmul and torch.bmm are not replaced by INC module,
+# Below codes can make torch.matmul and torch.bmm run on fp8 by injection.
+if not args.skip_fp8_mm and args.precision in ['fp8_e4m3', 'fp8_e5m2']:
+    def replace_torch_mm_bmm():
+        from neural_compressor.torch.amp.fp8.functions import fp8_matmul
+        torch.matmul = fp8_matmul
+        torch.bmm = fp8_matmul
 
+    replace_torch_mm_bmm()
+
+
+# inference optimization
 if args.to_graph:
     import habana_frameworks.torch.hpu.graphs as htgraphs
     user_model = htgraphs.wrap_in_hpu_graph(user_model)
 
+
+# dump message of HPU after quantization or reloading
+show_msg()
+
+
+### generation, performance and accuracy validation ###
 if args.generate:
     input_prompt = "Here is my prompt"
     print("Prompt sentence:", input_prompt)
@@ -234,6 +287,7 @@ def calib_func(model):
     print("Generated sentence:", output_sentence)
     print("Duration:", eval_end - eval_start)
 
+
 if args.performance:
     eval_start = time.perf_counter()
     input_prompt = "Intel is a company which"
@@ -242,6 +296,7 @@ def calib_func(model):
     outputs = user_model.generate(input_tokens, **generation_config)
     print("Duration of generating 100 tokens :", time.perf_counter() - eval_start)
 
+
 if args.accuracy:
 
     class HabanaModelAdapter(lm_eval.base.BaseLM):
@@ -292,16 +347,14 @@ def find_bucket(self, length):
             return [b for b in self.buckets if b >= length][0]
 
         def _model_call(self, inps):
-            #print(inps.shape)
             seq_length = inps.shape[-1]
+            padding_length = 0
             bucket_length = self.find_bucket(seq_length)
             padding_length = bucket_length - seq_length
-            if True:
-                import torch.nn.functional as F
-                inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
+            inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
+            logits = self.model(inps.to(self._device))["logits"].cpu()
 
-            logits = self.model(inps.to(self._device))['logits']
-            if True and padding_length > 0:
+            if padding_length > 0:
                 logits = logits[:, :-padding_length, :]
             logits = logits.to(torch.float32)
             return logits
@@ -333,18 +386,18 @@ def _model_call(self, inps):
 
 
     dumped = json.dumps(results, indent=2)
+    accu_dict = {}
+    case_name = args.approach + "-" + args.precision
     for task_name in args.tasks:
         if task_name == "wikitext":
             print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]), flush=True)
+            accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["word_perplexity"]]
         else:
             print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]), flush=True)
+            accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["acc"]]
+    if args.dump_to_excel and local_rank in [-1, 0]:
+        save_to_excel(accu_dict)
+
 
-    # show memory usage
-    mem_stats = memory_stats()
-    mem_dict = {
-        "memory_allocated (GB)": np.round(mem_stats["InUse"] / 1024**3, 2),
-        "max_memory_allocated (GB)": np.round(mem_stats["MaxInUse"] / 1024**3, 2),
-        "total_memory_available (GB)": np.round(mem_stats["Limit"] / 1024**3, 2),
-    }
-    for k, v in mem_dict.items():
-        print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v))
+    # dump final message of HPU
+    show_msg()
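
One design point worth noting from the diff: the --load path rebuilds the model skeleton without allocating real weights and then restores the FP8 checkpoint produced by --save into it. A hedged sketch of that path, using only the calls shown above (init_empty_weights, AutoConfig/AutoModelForCausalLM, load); the model id is a placeholder:

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from neural_compressor.torch.quantization import load

model_name = "your/model-id"  # placeholder for the --model argument
config = AutoConfig.from_pretrained(model_name)
with init_empty_weights():
    # build the module without materializing fp32 weights
    user_model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32)
user_model = load(user_model, "saved_results")  # restore the quantized state saved by --save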
