Update FP8 implementation, design and implement save & load #1605

Merged (22 commits) on Feb 27, 2024
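In short, this PR switches the example to the new FP8Config/quantize entry points and adds a save/load path for statically quantized FP8 models: rank 0 calls user_model.save("saved_results") after quantization, and a later run restores the result with load(). Below is a minimal sketch of that round trip, assuming only the import paths and call signatures used in this diff; the model name, prompts, and the trivial calibration loop are placeholders, not part of the PR.

from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.quantization.config import FP8Config
from neural_compressor.torch.quantization import quantize, load

model_name = "facebook/opt-125m"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
user_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="hpu")

# static FP8 config, with dtype strings matching the script's --precision choices
qconfig = FP8Config(w_dtype="fp8_e4m3", act_dtype="fp8_e4m3", approach="static")

def calib_func(model):
    # placeholder calibration: run a few short prompts so activation scales can be observed
    for text in ["Intel is a company which", "Here is my prompt"]:
        inputs = tokenizer(text, return_tensors="pt").to("hpu")
        model(**inputs)

user_model = quantize(user_model, qconfig, calib_func, inplace=True)
user_model.save("saved_results")                # new in this PR
user_model = load(user_model, "saved_results")  # new in this PR

For a pure reload (--load without quantizing in the same run), the diff first builds the model under accelerate's init_empty_weights() and then calls load(), so the full-precision checkpoint does not need to be loaded beforehand.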
Changes from all commits
@@ -1,30 +1,33 @@
import os
os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "False"
os.environ["USE_GAUDI2_SCALE"] = "True"
os.environ.pop("USE_GAUDI2_SCALE") # gaudi scale work
# os.environ["GRAPH_VISUALIZATION"] = "True"
# import shutil
# shutil.rmtree(".graph_dumps", ignore_errors=True)
import argparse
import time
import json
import re
import torch
import habana_frameworks.torch.hpex
from habana_frameworks.torch.hpu import memory_stats
import torch.nn.functional as F
import deepspeed
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import habana_frameworks.torch.core as htcore
import numpy as np
import lm_eval
import lm_eval.tasks
import lm_eval.evaluator
from accelerate import init_empty_weights
from utils import itrex_bootstrap_stderr, show_msg, save_to_excel

torch.set_grad_enabled(False)
htcore.hpu_set_env()
torch.device('hpu')


# to avoid out-of-memory caused by Popen for large language models.
lm_eval.metrics.bootstrap_stderr = itrex_bootstrap_stderr
@@ -51,22 +54,26 @@ def itrex_bootstrap_stderr(f, xs, iters):
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--performance", action="store_true")
parser.add_argument("--generate", action="store_true")
parser.add_argument("--skip_fp8_mm", action="store_true")
parser.add_argument("--dump_to_excel", action="store_true")
parser.add_argument("--save", action="store_true")
parser.add_argument("--load", action="store_true")
parser.add_argument("--batch_size", default=1, type=int,
help="For accuracy measurement only.")
parser.add_argument("--pad_max_length", default=512, type=int,
help="Pad input ids to max length.")
parser.add_argument("--calib_iters", default=100, type=int,
help="calibration iters.")
parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], type=str, \
choices=["winogrande", "copa", "piqa", "rte", "hellaswag", \
"openbookqa", "lambada_openai", "lambada_standard", "wikitext"],
parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], \
type=str, choices=["hellaswag", "lambada_openai", "piqa", "winogrande", "copa",
"rte", "openbookqa", "lambada_standard", "wikitext"],
help="tasks list for accuracy validation")
parser.add_argument("--limit", default=None, type=int,
help="the sample num of evaluation.")
parser.add_argument("--max_new_tokens", default=100, type=int,
help="calibration iters.")
parser.add_argument('--buckets', type=int, nargs='+', \
help="Input length buckets to use with static_shapes", default=[129])
help="Input length buckets to use with static_shapes", default=[256, 512])
parser.add_argument("--local_rank",
type=int,
default=-1,
@@ -78,67 +85,65 @@ def itrex_bootstrap_stderr(f, xs, iters):
world_size = int(os.getenv('WORLD_SIZE', '1'))
local_rank = int(os.getenv('LOCAL_RANK', '-1'))

#if local_rank == 0:
# os.environ["ENABLE_CONSOLE"] = 'True'
# os.environ["LOG_LEVEL_ALL"] = '0'

# model
model_dtype = torch.float32
if re.search("llama", args.model.lower()) or re.search("bloom", args.model.lower()):
    if world_size > 1:
        config = AutoConfig.from_pretrained(args.model)
        model_dtype = torch.bfloat16 # RuntimeErrorCastToFp8V2 input must be of float or bfloat16 dtype
        deepspeed.init_distributed(dist_backend="hccl")
        with deepspeed.OnDevice(dtype=model_dtype, device="meta"):
            user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
        import tempfile
        checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
        from optimum.habana.checkpoint_utils import write_checkpoints_json # in optimum-habana
        write_checkpoints_json(
            args.model,
            local_rank,
            checkpoints_json,
            token=None,
        )
elif re.search("llama", args.model.lower()):
from models.modeling_llama import LlamaForCausalLM
user_model = LlamaForCausalLM.from_pretrained(
else:
if args.load:
config = AutoConfig.from_pretrained(args.model)
with init_empty_weights():
user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
else:
user_model = AutoModelForCausalLM.from_pretrained(
args.model,
device_map='hpu',
torch_dtype=model_dtype,
)
elif re.search("chatglm", args.model.lower()):
if args.load:
config = AutoConfig.from_pretrained(args.model, torch_dtype=model_dtype)
with init_empty_weights():
user_model = AutoModelForCausalLM.from_config(config)
else:
from models.modeling_chatglm import ChatGLMForConditionalGeneration
user_model = ChatGLMForConditionalGeneration.from_pretrained(
args.model,
revision=args.revision,
device_map='hpu',
torch_dtype=model_dtype,
)
# print(user_model.transformer.output_layer.weight.dtype) # always fp16
user_model.float() # static fp8 need float32 for graph compiler
else:
    if args.load:
        config = AutoConfig.from_pretrained(args.model)
        with init_empty_weights():
            user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
    else:
        user_model = AutoModelForCausalLM.from_pretrained(
            args.model,
            trust_remote_code=args.trust_remote_code,
            revision=args.revision,
            device_map='hpu',
            torch_dtype=model_dtype,
        )


if world_size > 1:
if re.search("llama", args.model.lower()):
@@ -148,36 +153,44 @@ def itrex_bootstrap_stderr(f, xs, iters):
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer
        ds_inference_kwargs["injection_policy"] = {LlamaDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")}
        ds_inference_kwargs["checkpoint"] = checkpoints_json.name

        ds_model = deepspeed.init_inference(user_model, **ds_inference_kwargs)
    else:
        ds_model = deepspeed.init_inference(user_model,
                                            mp_size=world_size,
                                            replace_with_kernel_inject=False)
    user_model = ds_model.module


# tokenizer
if re.search("baichuan", args.model.lower()):
from models.tokenization_baichuan import BaichuanTokenizer
tokenizer = BaichuanTokenizer.from_pretrained(
args.model,
trust_remote_code=args.trust_remote_code
)
else:
tokenizer = AutoTokenizer.from_pretrained(
args.model,
trust_remote_code=args.trust_remote_code
)


user_model.eval()


### dynamic & static quantization ###
if args.approach in ["dynamic", "static"] and not args.load:
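    # Map --precision/--approach onto an INC FP8 config; the static path also runs
    # the calibration loop defined below before converting the model.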
print("device:", next(user_model.parameters()).device)
from neural_compressor.torch.quantization.config import FP8QConfig, get_default_fp8_qconfig
from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic
from neural_compressor.torch.quantization.config import FP8Config, get_default_fp8_config
from neural_compressor.torch.quantization import quantize
if args.precision == "fp8_e4m3":
dtype = torch.float8_e4m3fn
else:
dtype = torch.float8_e5m2
dtype = args.precision
if args.approach == "dynamic":
#user_model = quantize_dynamic(user_model, dtype, inplace=True)
qconfig = FP8QConfig(weight_dtype=dtype, act_dtype=dtype, approach="dynamic")
if args.skip_lm_head:
fp32_config = FP8QConfig(weight_dtype=torch.float32, act_dtype=torch.float32)
qconfig.set_local("lm_head", fp32_config)
user_model = quantize_dynamic(user_model, qconfig, inplace=True)
from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic
user_model = quantize_dynamic(user_model, dtype, inplace=True)
elif args.approach == "static":
qconfig = FP8QConfig(weight_dtype=dtype, act_dtype=dtype, approach="static")
qconfig = FP8Config(w_dtype=dtype, act_dtype=dtype, approach="static")
if args.skip_lm_head:
fp32_config = FP8QConfig(weight_dtype=torch.float32, act_dtype=torch.float32)
fp32_config = FP8Config(w_dtype="fp32", act_dtype="fp32")
qconfig.set_local("lm_head", fp32_config)
        # dataset
        from datasets import load_dataset
@@ -186,7 +199,13 @@ def itrex_bootstrap_stderr(f, xs, iters):
        calib_data = []
        for examples in calib_dataset:
            calib_data.append(
                tokenizer(
                    examples["text"],
                    return_tensors="pt",
                    max_length=64,
                    padding="max_length",
                    truncation=True
                )
            )
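        # calib_data now holds fixed-length (64-token) samples; the calibration pass
        # below runs them through the model so INC can observe activation ranges.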

        def calib_func(model):
@@ -199,12 +218,46 @@ def calib_func(model):
)

        user_model = quantize(user_model, qconfig, calib_func, inplace=True)
    print(user_model, flush=True)
    # saving
    if args.save and local_rank in [-1, 0]:
        user_model.save("saved_results")


if args.load:
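    # Restore a model previously quantized and saved via user_model.save("saved_results").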
    from neural_compressor.torch.quantization import load
    user_model = load(user_model, "saved_results")


if args.approach in ["dynamic", "static"] or args.load:
    # enable weight constant folding
    from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const
    _mark_params_as_const(user_model)  # reduces allocated memory and speeds up execution
    _check_params_as_const(user_model)



# If torch.matmul and torch.bmm are not replaced by an INC module,
# the code below makes them run in fp8 via injection.
if not args.skip_fp8_mm and args.precision in ['fp8_e4m3', 'fp8_e5m2']:
    def replace_torch_mm_bmm():
        from neural_compressor.torch.amp.fp8.functions import fp8_matmul
        torch.matmul = fp8_matmul
        torch.bmm = fp8_matmul

    replace_torch_mm_bmm()


# inference optimization
if args.to_graph:
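    # Capture the model in an HPU graph to cut per-step host launch overhead.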
    import habana_frameworks.torch.hpu.graphs as htgraphs
    user_model = htgraphs.wrap_in_hpu_graph(user_model)


# dump message of HPU after quantization or reloading
show_msg()


### generation, performance and accuracy validation ###
if args.generate:
input_prompt = "Here is my prompt"
print("Prompt sentence:", input_prompt)
@@ -234,6 +287,7 @@ def calib_func(model):
print("Generated sentence:", output_sentence)
print("Duration:", eval_end - eval_start)


if args.performance:
    eval_start = time.perf_counter()
    input_prompt = "Intel is a company which"
@@ -242,6 +296,7 @@ def calib_func(model):
    outputs = user_model.generate(input_tokens, **generation_config)
    print("Duration of generating 100 tokens:", time.perf_counter() - eval_start)


if args.accuracy:

    class HabanaModelAdapter(lm_eval.base.BaseLM):
@@ -292,16 +347,14 @@ def find_bucket(self, length):
            return [b for b in self.buckets if b >= length][0]

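        # Pad each batch up to the nearest configured bucket so the HPU sees a small,
        # fixed set of static input shapes; padded positions are cut from the logits.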
        def _model_call(self, inps):
            seq_length = inps.shape[-1]
            bucket_length = self.find_bucket(seq_length)
            padding_length = bucket_length - seq_length
            inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
            logits = self.model(inps.to(self._device))["logits"].cpu()

            if padding_length > 0:
                logits = logits[:, :-padding_length, :]
            logits = logits.to(torch.float32)
            return logits
@@ -333,18 +386,18 @@ def _model_call(self, inps):


    dumped = json.dumps(results, indent=2)
    accu_dict = {}
    case_name = args.approach + "-" + args.precision
    for task_name in args.tasks:
        if task_name == "wikitext":
            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]), flush=True)
            accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["word_perplexity"]]
        else:
            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]), flush=True)
            accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["acc"]]
    if args.dump_to_excel and local_rank in [-1, 0]:
        save_to_excel(accu_dict)


# show memory usage
mem_stats = memory_stats()
mem_dict = {
"memory_allocated (GB)": np.round(mem_stats["InUse"] / 1024**3, 2),
"max_memory_allocated (GB)": np.round(mem_stats["MaxInUse"] / 1024**3, 2),
"total_memory_available (GB)": np.round(mem_stats["Limit"] / 1024**3, 2),
}
for k, v in mem_dict.items():
print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v))
# dump final message of HPU
show_msg()