
Commit 291000b

fix bug for linear
Signed-off-by: xinhe3 <xinhe3@habana.ai>
1 parent 2406762 commit 291000b

File tree

3 files changed: +110 −41 lines changed


examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py (+57 −16)

@@ -1,7 +1,7 @@
 import os
 os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "False"
 os.environ["USE_GAUDI2_SCALE"] = "True"
-os.environ.pop("USE_GAUDI2_SCALE")  # gaudi2 scale does not work
+os.environ.pop("USE_GAUDI2_SCALE")  # gaudi scale works
 # os.environ["GRAPH_VISUALIZATION"] = "True"
 import shutil
 shutil.rmtree(".graph_dumps", ignore_errors=True)
@@ -14,12 +14,13 @@
 import torch.nn.functional as F
 import deepspeed
 import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 import habana_frameworks.torch.core as htcore
 import numpy as np
 import lm_eval
 import lm_eval.tasks
 import lm_eval.evaluator
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch


 torch.set_grad_enabled(False)
@@ -110,11 +111,16 @@ def itrex_bootstrap_stderr(f, xs, iters):
             token=None,
         )
     else:
-        user_model = AutoModelForCausalLM.from_pretrained(
-            args.model,
-            device_map='hpu',
-            torch_dtype=model_dtype,
-        )
+        if args.load:
+            config = AutoConfig.from_pretrained(args.model, torch_dtype=model_dtype)
+            with init_empty_weights():
+                user_model = AutoModelForCausalLM.from_config(config)
+        else:
+            user_model = AutoModelForCausalLM.from_pretrained(
+                args.model,
+                device_map='hpu',
+                torch_dtype=model_dtype,
+            )
 elif re.search("chatglm", args.model.lower()):
     from models.modeling_chatglm import ChatGLMForConditionalGeneration
     user_model = ChatGLMForConditionalGeneration.from_pretrained(
@@ -126,13 +132,18 @@ def itrex_bootstrap_stderr(f, xs, iters):
     # print(user_model.transformer.output_layer.weight.dtype)  # always fp16
     user_model.float()  # static fp8 needs float32 for graph compiler
 else:
-    user_model = AutoModelForCausalLM.from_pretrained(
-        args.model,
-        trust_remote_code=args.trust_remote_code,
-        revision=args.revision,
-        device_map='hpu',
-        torch_dtype=model_dtype,
-    )
+    if args.load:
+        config = AutoConfig.from_pretrained(args.model, torch_dtype=model_dtype)
+        with init_empty_weights():
+            user_model = AutoModelForCausalLM.from_config(config)
+    else:
+        user_model = AutoModelForCausalLM.from_pretrained(
+            args.model,
+            trust_remote_code=args.trust_remote_code,
+            revision=args.revision,
+            device_map='hpu',
+            torch_dtype=model_dtype,
+        )

 # tokenizer
 if re.search("baichuan", args.model.lower()):
@@ -219,11 +230,40 @@ def replace_torch_mm_bmm():
     _check_params_as_const(user_model)
     # saving
     user_model.save("saved_results")
-    print(user_model, flush=True)
+    # print(user_model, flush=True)
+    def show_msg():
+        import numpy as np
+        import glob
+        from habana_frameworks.torch.hpu import memory_stats
+        print("Number of HPU graphs:", len(glob.glob(".graph_dumps/*PreGraph*")))
+        mem_stats = memory_stats()
+        mem_dict = {
+            "memory_allocated (GB)": np.round(mem_stats["InUse"] / 1024**3, 2),
+            "max_memory_allocated (GB)": np.round(mem_stats["MaxInUse"] / 1024**3, 2),
+            "total_memory_available (GB)": np.round(mem_stats["Limit"] / 1024**3, 2),
+        }
+        for k, v in mem_dict.items():
+            print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v))
+    show_msg()

 if args.load:
+    def show_msg():
+        import numpy as np
+        import glob
+        from habana_frameworks.torch.hpu import memory_stats
+        print("Number of HPU graphs:", len(glob.glob(".graph_dumps/*PreGraph*")))
+        mem_stats = memory_stats()
+        mem_dict = {
+            "memory_allocated (GB)": np.round(mem_stats["InUse"] / 1024**3, 2),
+            "max_memory_allocated (GB)": np.round(mem_stats["MaxInUse"] / 1024**3, 2),
+            "total_memory_available (GB)": np.round(mem_stats["Limit"] / 1024**3, 2),
+        }
+        for k, v in mem_dict.items():
+            print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v))
+    show_msg()
     from neural_compressor.torch.quantization import load
     user_model = load(user_model, "saved_results")
+    show_msg()
     # replace torch.matmul and torch.bmm by injection
     def replace_torch_mm_bmm():
         from neural_compressor.torch.amp.fp8.functions import fp8_matmul
@@ -235,7 +275,8 @@ def replace_torch_mm_bmm():
     from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const
     _mark_params_as_const(user_model)  # can reduce memory allocated and speed up
     _check_params_as_const(user_model)
-    print(user_model, flush=True)
+    # print(user_model, flush=True)
+    show_msg()

 if args.to_graph:
     import habana_frameworks.torch.hpu.graphs as htgraphs
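Note on the new args.load branch above: building the model from its AutoConfig inside init_empty_weights() yields a skeleton whose parameters live on the meta device, so no weights are fetched or placed on the HPU until the quantized checkpoint is attached later by neural_compressor's load(). A minimal sketch of that pattern, assuming a placeholder Hugging Face model id ("facebook/opt-125m") and bfloat16 dtype rather than the script's actual arguments:

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Build only the architecture; every parameter is a meta tensor with no storage.
config = AutoConfig.from_pretrained("facebook/opt-125m", torch_dtype=torch.bfloat16)
with init_empty_weights():
    skeleton = AutoModelForCausalLM.from_config(config)

print(next(skeleton.parameters()).device)  # meta -> nothing allocated on host or HPU yet
# Real (quantized) tensors are attached later, e.g. by
# neural_compressor.torch.quantization.load(skeleton, "saved_results").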

neural_compressor/torch/algorithms/habana_fp8/modules.py (+51 −24)

@@ -212,11 +212,10 @@ class FP8Linear(torch.nn.Module):
     def __init__(self, org_module, dtype) -> None:
         super().__init__()
         # attributes
-        org_module.to("hpu")
-        self.dtype = dtype
-        self.dtype_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX
         self.in_features = org_module.in_features
         self.out_features = org_module.out_features
+        self.dtype = dtype
+        self.dtype_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX
         self.weight_dtype = self.dtype
         self.out_dtype = org_module.weight.dtype
         self.register_buffer(
@@ -228,50 +227,78 @@ def __init__(self, org_module, dtype) -> None:
                 dtype=self.weight_dtype,
             ),
         )
+        if org_module.bias is not None:
+            self.register_buffer(
+                "bias",
+                torch.empty(
+                    self.out_features,
+                    device="hpu",
+                    dtype=self.out_dtype,
+                ),
+            )
+        else:
+            self.bias = None
+        input_scale = _map_guadi2_scale(org_module.scale) if hasattr(org_module, "scale") else torch.tensor(1.0)
         self.register_buffer(
-            "bias",
-            torch.empty(
-                self.out_features,
+            "input_scale",
+            torch.tensor(
+                input_scale,
                 device="hpu",
-                dtype=self.out_dtype,
+                dtype=torch.float32,
             ),
         )
-        scale = org_module.scale if hasattr(org_module, "scale") else 1.0
         self.register_buffer(
-            "scale",
+            "input_scale_inv",
             torch.tensor(
-                scale,
+                torch.reciprocal(input_scale),
                 device="hpu",
                 dtype=torch.float32,
             ),
         )
-
-        self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max()
-        self.weight_scale = _map_guadi2_scale(self.weight_scale)
-        self.weight_scale_inv = torch.reciprocal(self.weight_scale)
-        self.weight.data.copy_(
-            torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[0]
+        if not org_module.weight.device.type == "meta":
+            weight_scale = self.dtype_amax / org_module.weight.data.abs().max()
+            weight_scale = _map_guadi2_scale(weight_scale)
+        else:
+            weight_scale = torch.tensor(1.0)
+        self.register_buffer(
+            "weight_scale",
+            torch.tensor(
+                weight_scale,
+                device="hpu",
+                dtype=torch.float32,
+            )
+        )
+        self.register_buffer(
+            "weight_scale_inv",
+            torch.tensor(
+                torch.reciprocal(weight_scale),
+                device="hpu",
+                dtype=torch.float32,
+            ),
         )
+        # copy weight and bias
+        if not org_module.weight.device.type == "meta":
+            org_module.to("hpu")
+            self.weight.data.copy_(
+                torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[0]
+            )
+            if org_module.bias is not None:
+                self.bias.data.copy_(org_module.bias.data.type(self.out_dtype))

-        if org_module.bias is not None:
-            self.bias.data.copy_(org_module.bias.data.type(self.out_dtype))
-        else:
-            self.bias = None

     def forward(self, inp):
         assert inp.shape[-1] == self.in_features, "GEMM not possible"
         org_middle_shape = inp.shape[1:-1]
         inp = inp.view((-1, self.in_features))
-        inp = torch.ops.hpu.cast_to_fp8_v2(inp, self.scale, False, False, self.dtype)[0]
-        self.scale_inv = torch.reciprocal(self.scale)
+        inp = torch.ops.hpu.cast_to_fp8_v2(inp, self.input_scale, False, False, self.dtype)[0]
         out = torch.ops.hpu.fp8_gemm_v2(
             inp,
             False,
             self.weight,
             True,
             None,
             self.out_dtype,
-            self.scale_inv,  # inv is used to recover scale
+            self.input_scale_inv,  # inv is used to recover scale
             self.weight_scale_inv,
             self.bias,
             False,
@@ -284,7 +311,7 @@ def extra_repr(self) -> str:
             self.in_features,
             self.out_features,
             self.bias is not None,
-            self.scale,
+            self.input_scale,
             self.dtype,
         )
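Note on the scale bookkeeping in FP8Linear: the per-tensor weight scale maps the weight's absolute maximum onto the FP8 format's largest representable value, and the stored reciprocals (input_scale_inv, weight_scale_inv) are what fp8_gemm_v2 uses to restore the original magnitude after the FP8 GEMM. A CPU-only sketch of that arithmetic, assuming E4M3_AMAX is 240 (treat the exact value as an assumption) and using plain tensor math in place of the HPU cast/GEMM ops:

import torch

E4M3_AMAX = torch.tensor(240.0)  # assumed Gaudi E4M3 maximum; OCP E4M3 tops out at 448

weight = torch.randn(64, 128)
weight_scale = E4M3_AMAX / weight.abs().max()      # stretch |W| to fill the FP8 range
weight_scale_inv = torch.reciprocal(weight_scale)  # stored so the GEMM can undo the scaling

scaled = weight * weight_scale                     # roughly what cast_to_fp8_v2 quantizes
recovered = scaled * weight_scale_inv              # what the inverse scale restores in the GEMM

print(torch.allclose(weight, recovered))           # True (exact here; real FP8 adds rounding error)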

neural_compressor/torch/algorithms/habana_fp8/save_load.py (+2 −1)

@@ -92,7 +92,8 @@ def load(model, output_dir="./saved_results"):
         module = FP8Cast(dtype=dtype)
         set_module(model, op_name, module)
     htcore.mark_step()
-    model.load_state_dict(stat_dict)
+    model.load_state_dict(stat_dict, assign=True)
+    model.to('hpu')
     htcore.mark_step()
     logger.info("Quantized model loading successful.")
     return model
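Note on the assign=True change: with PyTorch 2.1 or newer, load_state_dict(..., assign=True) replaces the module's parameters with the tensors from the state dict instead of copying into existing storage, which is what lets the caller pass in a meta-device skeleton built with init_empty_weights(); the following model.to('hpu') then places everything on the device. A small stand-alone illustration using a plain CPU Linear layer in place of the quantized model:

import torch

# Parameters created under the meta device have no storage to copy into.
with torch.device("meta"):
    lin = torch.nn.Linear(4, 2)

state = {"weight": torch.randn(2, 4), "bias": torch.zeros(2)}

# A copy-based load cannot populate meta tensors; assign=True swaps the
# parameters for the real tensors from the state dict instead.
lin.load_state_dict(state, assign=True)
print(lin.weight.device, lin(torch.ones(1, 4)).shape)  # cpu torch.Size([1, 2])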
