
Commit f9c14b2

design and implement save&load

Signed-off-by: xinhe3 <xinhe3@habana.ai>
1 parent e40fa02 · commit f9c14b2

File tree: 14 files changed (+1900 -131 lines)

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py (+5 -5)

@@ -166,7 +166,7 @@ def itrex_bootstrap_stderr(f, xs, iters):
 
 if args.approach in ["dynamic", "static"]:
     print("device:", next(user_model.parameters()).device)
-    from neural_compressor.torch.quantization.config import FP8QConfig, get_default_fp8_qconfig
+    from neural_compressor.torch.quantization.config import FP8Config, get_default_fp8_config
     from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic
     from neural_compressor.torch.quantization import quantize
     if args.precision == "fp8_e4m3":
@@ -175,15 +175,15 @@ def itrex_bootstrap_stderr(f, xs, iters):
         dtype = torch.float8_e5m2
     if args.approach == "dynamic":
         #user_model = quantize_dynamic(user_model, dtype, inplace=True)
-        qconfig = FP8QConfig(weight_dtype=dtype, act_dtype=dtype, approach="dynamic")
+        qconfig = FP8Config(weight_dtype=dtype, act_dtype=dtype, approach="dynamic")
         if args.skip_lm_head:
-            fp32_config = FP8QConfig(weight_dtype=torch.float32, act_dtype=torch.float32)
+            fp32_config = FP8Config(weight_dtype=torch.float32, act_dtype=torch.float32)
            qconfig.set_local("lm_head", fp32_config)
         user_model = quantize_dynamic(user_model, qconfig, inplace=True)
     elif args.approach == "static":
-        qconfig = FP8QConfig(weight_dtype=dtype, act_dtype=dtype, approach="static")
+        qconfig = FP8Config(weight_dtype=dtype, act_dtype=dtype, approach="static")
         if args.skip_lm_head:
-            fp32_config = FP8QConfig(weight_dtype=torch.float32, act_dtype=torch.float32)
+            fp32_config = FP8Config(weight_dtype=torch.float32, act_dtype=torch.float32)
             qconfig.set_local("lm_head", fp32_config)
         # dataset
         from datasets import load_dataset
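For reference, a minimal usage sketch of the renamed config class, based only on the calls visible in this hunk; user_model stands for any torch.nn.Module already placed on the HPU device and is a placeholder here:

import torch
from neural_compressor.torch.quantization.config import FP8Config
from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic

# dynamic FP8 quantization with lm_head kept in fp32, mirroring the run_llm.py flow above
qconfig = FP8Config(weight_dtype=torch.float8_e4m3fn, act_dtype=torch.float8_e4m3fn, approach="dynamic")
fp32_config = FP8Config(weight_dtype=torch.float32, act_dtype=torch.float32)
qconfig.set_local("lm_head", fp32_config)
user_model = quantize_dynamic(user_model, qconfig, inplace=True)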

neural_compressor/torch/algorithms/habana_fp8/fp8_quant.py (+28 -12)

@@ -17,6 +17,7 @@
 import os
 
 import habana_frameworks.torch.core as htcore
+from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const
 import torch
 from deepspeed.module_inject import LinearAllreduce, LinearLayer
 from deepspeed.module_inject.layers import LmHeadLinearAllreduce
@@ -40,6 +41,10 @@
     FP8LinearLayer,
     FP8LmHeadLinearAllreduce,
     FP8Matmul,
+    # dtype amax
+    E4M3_AMAX,
+    E5M2_AMAX,
+    _map_guadi2_scale,
 )
 
 quantization_mapping = {
@@ -55,20 +60,20 @@
 white_list = tuple(quantization_mapping.keys())
 
 
-# without scale factor 0.9, the output will be abnormal.
-E4M3_AMAX = torch.tensor(240 * 0.9, dtype=torch.float).to("hpu")
-E5M2_AMAX = torch.tensor(57344 * 0.9, dtype=torch.float).to("hpu")
-FP8_DTYPE = [torch.float8_e5m2, torch.float8_e4m3fn]
+FP8_DTYPE = [torch.float8_e5m2, torch.float8_e4m3fn, "fp8_e5m2", "fp8_e4m3"]
+dtype_mapping = {"fp8_e5m2": torch.float8_e5m2, "fp8_e4m3": torch.float8_e4m3fn}
+# enable inference optimizations
+htcore.hpu_initialize()
 
 
 def _replace_module(module, qconfig):
     if qconfig.approach == "static":
         if isinstance(module, white_list):
             QModule = quantization_mapping[type(module)]
-            assert qconfig.weight_dtype == qconfig.act_dtype, "weight and activation should be the same dtype."
-            module = QModule(module, qconfig.act_dtype)
+            assert qconfig.w_dtype == qconfig.act_dtype, "weight and activation should be the same dtype."
+            module = QModule(module, dtype_mapping[qconfig.act_dtype])
     elif qconfig.approach == "dynamic":
-        dtype = qconfig.act_dtype
+        dtype = dtype_mapping[qconfig.act_dtype]
         if isinstance(module, torch.nn.Linear):
             # need module for initialization
             module = FP8DynamicLinear(module, dtype)
@@ -84,6 +89,8 @@ def _replace_module(module, qconfig):
 
 def quantize_dynamic(model, dtype=torch.float8_e4m3fn, inplace=True):
     q_model = model if inplace else copy.deepcopy(model)
+    if isinstance(dtype, str):
+        dtype = dtype_mapping[dtype]
     for n, m in q_model.named_modules():
         if isinstance(m, torch.nn.Linear):
             new_m = FP8DynamicLinear(m, dtype)  # need m for init
@@ -98,6 +105,8 @@ def quantize_dynamic(model, dtype=torch.float8_e4m3fn, inplace=True):
             new_m = FP8Cast(dtype=dtype)
             set_module(q_model, n, new_m)
     htcore.mark_step()
+    _mark_params_as_const(q_model)
+    _check_params_as_const(q_model)
     return q_model
 
 
@@ -133,7 +142,7 @@ def _remove_observer(module, qconfig):
     import deepspeed.comm as dist
     from torch.distributed import ReduceOp
 
-    HF_max = E4M3_AMAX if qconfig.act_dtype == torch.float8_e4m3fn else E5M2_AMAX
+    HF_max = E4M3_AMAX if qconfig.act_dtype == "fp8_e4m3" else E5M2_AMAX
     if hasattr(module, "input_activation_post_process"):
         if hasattr(module.input_activation_post_process, "_non_linear_param_search"):  # kl
             min_val, max_val = module.input_activation_post_process._non_linear_param_search()
@@ -145,7 +154,11 @@ def _remove_observer(module, qconfig):
         amax = amax.to("hpu")
         dist.all_reduce(amax, op=ReduceOp.MAX)
         scale = HF_max / amax
-        module.register_parameter("scale", torch.nn.Parameter(scale))
+        scale = _map_guadi2_scale(scale)
+        if hasattr(module, "input_activation_post_process1"):
+            module.register_parameter("scale1", torch.nn.Parameter(scale))
+        else:
+            module.register_parameter("scale", torch.nn.Parameter(scale))
         delattr(module, "input_activation_post_process")
     if hasattr(module, "input_activation_post_process1"):
         if hasattr(module.input_activation_post_process1, "_non_linear_param_search"):
@@ -158,7 +171,8 @@ def _remove_observer(module, qconfig):
         amax = amax.to("hpu")
         dist.all_reduce(amax, op=ReduceOp.MAX)
         scale = HF_max / amax
-        module.register_parameter("scale1", torch.nn.Parameter(scale))
+        scale = _map_guadi2_scale(scale)
+        module.register_parameter("scale2", torch.nn.Parameter(scale))
         delattr(module, "input_activation_post_process1")
 
     # remove observer hooks
@@ -175,7 +189,7 @@ def prepare(model, qconfig_mapping):
     for (op_name, op_type), qconfig in qconfig_mapping.items():
        if qconfig.approach == "dynamic":
            continue
-        if qconfig.weight_dtype not in FP8_DTYPE:
+        if qconfig.w_dtype not in FP8_DTYPE:
            continue
        module = fetch_module(model, op_name)
        if module is None:
@@ -188,7 +202,7 @@ def prepare(model, qconfig_mapping):
 
 def convert(model, qconfig_mapping):
     for (op_name, op_type), qconfig in qconfig_mapping.items():
-        if qconfig.weight_dtype not in FP8_DTYPE:
+        if qconfig.w_dtype not in FP8_DTYPE:
            continue
        module = fetch_module(model, op_name)
        if module is None:
@@ -211,4 +225,6 @@ def quantize(model, qconfig_mapping, run_fn=None, run_args=None, inplace=True):
     else:
         run_fn(q_model)
     q_model = convert(q_model, qconfig_mapping)
+    _mark_params_as_const(q_model)
+    _check_params_as_const(q_model)
     return q_model
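As a quick usage sketch of the string-dtype handling added above (dtype_mapping plus the isinstance(dtype, str) branch in quantize_dynamic); the toy module and the .to("hpu") placement are hypothetical and only for illustration:

import torch
from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic

toy_model = torch.nn.Sequential(torch.nn.Linear(16, 16)).to("hpu")  # hypothetical toy model
# "fp8_e4m3" is now accepted directly and mapped to torch.float8_e4m3fn via dtype_mapping
q_model = quantize_dynamic(toy_model, dtype="fp8_e4m3", inplace=True)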

neural_compressor/torch/algorithms/habana_fp8/modules.py (+29 -13)

@@ -53,6 +53,17 @@ def forward(self, x):
 
 
 ##################### FP8 modules #######################
+def _map_guadi2_scale(scale):
+    USE_GUADI2_SCALE = os.environ.get("USE_GUADI2_SCALE")
+    if USE_GUADI2_SCALE:
+        scale_list = torch.tensor([16, 1, 1/16, 1/256])
+        for i in scale_list:
+            if scale > i or i == torch.tensor(1/256):
+                return i
+    else:
+        return scale
+
+
 class FP8DynamicLinear(torch.nn.Module):
     def __init__(self, org_module, dtype=torch.float8_e4m3fn) -> None:
         super().__init__()
@@ -86,6 +97,7 @@ def __init__(self, org_module, dtype=torch.float8_e4m3fn) -> None:
         # scale = HF_max /amax
         if self.use_amax:
             self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max()
+            self.weight_scale = _map_guadi2_scale(self.weight_scale)
             self.weight_scale_inv = torch.reciprocal(self.weight_scale)
         else:
             self.weight_scale = None
@@ -233,9 +245,9 @@ def __init__(self, org_module, dtype) -> None:
                 dtype=torch.float32,
             ),
         )
-        self.scale_inv = torch.reciprocal(self.scale)
 
         self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max()
+        self.weight_scale = _map_guadi2_scale(self.weight_scale)
         self.weight_scale_inv = torch.reciprocal(self.weight_scale)
         self.weight.data.copy_(
             torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[0]
@@ -251,6 +263,7 @@ def forward(self, inp):
         org_middle_shape = inp.shape[1:-1]
         inp = inp.view((-1, self.in_features))
         inp = torch.ops.hpu.cast_to_fp8_v2(inp, self.scale, False, False, self.dtype)[0]
+        self.scale_inv = torch.reciprocal(self.scale)
         out = torch.ops.hpu.fp8_gemm_v2(
             inp,
             False,
@@ -283,26 +296,24 @@ def __init__(self, org_module, dtype) -> None:
         self.dtype = dtype
         self.dtype_amax = E4M3_AMAX if self.dtype == torch.float8_e4m3fn else E5M2_AMAX
         self.out_dtype = torch.float32
-        scale = org_module.scale if hasattr(org_module, "scale") else 1.0
         scale1 = org_module.scale1 if hasattr(org_module, "scale1") else 1.0
+        scale2 = org_module.scale2 if hasattr(org_module, "scale2") else 1.0
         self.register_buffer(
-            "scale",
+            "scale1",
             torch.tensor(
-                scale,
+                scale1,
                 device="hpu",
                 dtype=self.out_dtype,
             ),
         )
         self.register_buffer(
-            "scale1",
+            "scale2",
             torch.tensor(
-                scale1,
+                scale2,
                 device="hpu",
                 dtype=self.out_dtype,
             ),
         )
-        self.input1_scale_inv = torch.reciprocal(self.scale)
-        self.input2_scale_inv = torch.reciprocal(self.scale1)
 
     def forward(self, input1, input2):
         dim1 = input1.shape[-1]
@@ -311,12 +322,14 @@ def forward(self, input1, input2):
 
         if input1.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
             self.out_dtype = input1.dtype
-            input1 = torch.ops.hpu.cast_to_fp8_v2(input1, self.scale, False, False, self.dtype)[0]
+            input1 = torch.ops.hpu.cast_to_fp8_v2(input1, self.scale1, False, False, self.dtype)[0]
+            self.input1_scale_inv = torch.reciprocal(self.scale1)
         else:
             self.input1_scale_inv = None
         if input2.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
             self.out_dtype = input2.dtype
-            input2 = torch.ops.hpu.cast_to_fp8_v2(input2, self.scale1, False, False, self.dtype)[0]
+            input2 = torch.ops.hpu.cast_to_fp8_v2(input2, self.scale2, False, False, self.dtype)[0]
+            self.input2_scale_inv = torch.reciprocal(self.scale2)
         else:
             self.input2_scale_inv = None
         out = torch.ops.hpu.fp8_gemm_v2(
@@ -407,10 +420,10 @@ def __init__(self, org_module, dtype) -> None:
                 dtype=torch.float32,
             ),
         )
-        self.scale_inv = 1.0 / self.scale
         # user configuration
         # scale = HF_max /amax
         self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max()
+        self.weight_scale = _map_guadi2_scale(self.weight_scale)
         self.weight_scale_inv = 1.0 / self.weight_scale
         self.weight = torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[
             0
@@ -432,6 +445,7 @@ def forward(self, inp):
         assert inp.shape[-1] == self.in_features, "GEMM not possible"
         inputmat = inp.view((-1, self.in_features))
         inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.scale, False, False, self.dtype)[0]
+        self.scale_inv = torch.reciprocal(self.scale)
         out = torch.ops.hpu.fp8_gemm_v2(
             inputmat,
             False,
@@ -487,10 +501,10 @@ def __init__(self, org_module, dtype) -> None:
                 dtype=torch.float32,
             ),
         )
-        self.scale_inv = 1.0 / self.scale
         # user configuration
         # scale = HF_max /amax
         self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max()
+        self.weight_scale = _map_guadi2_scale(self.weight_scale)
         self.weight_scale_inv = 1.0 / self.weight_scale
         self.weight = torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[
             0
@@ -513,6 +527,7 @@ def forward(self, inp):
         assert inp.shape[-1] == self.in_features, "GEMM not possible"
         inputmat = inp.view((-1, self.in_features))
         inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.scale, False, False, self.dtype)[0]
+        self.scale_inv = torch.reciprocal(self.scale)
         out = torch.ops.hpu.fp8_gemm_v2(
             inputmat,
             False,
@@ -572,10 +587,10 @@ def __init__(self, org_module, dtype) -> None:
                 dtype=torch.float32,
             ),
         )
-        self.scale_inv = 1.0 / self.scale
         # user configuration
         # scale = HF_max /amax
         self.weight_scale = self.dtype_amax / org_module.weight.data.abs().max()
+        self.weight_scale = _map_guadi2_scale(self.weight_scale)
         self.weight_scale_inv = 1.0 / self.weight_scale
         self.weight = torch.ops.hpu.cast_to_fp8_v2(org_module.weight.data, self.weight_scale, False, False, self.dtype)[
             0
@@ -608,6 +623,7 @@ def forward(self, inp):
         input_shard = inp.shape[-1] // self.world_size
         inputmat = inp[:, :, self.rank * input_shard : (self.rank + 1) * input_shard]
         inputmat = torch.ops.hpu.cast_to_fp8_v2(inputmat, self.scale, False, False, self.dtype)[0]
+        self.scale_inv = torch.reciprocal(self.scale)
         out = torch.ops.hpu.fp8_gemm_v2(
             inputmat,
             False,
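To make the effect of the new Gaudi2 scale mapping concrete, here is a small standalone sketch of the same bucketing idea; it re-implements the helper for illustration (thresholds taken from the diff above) rather than importing the real _map_guadi2_scale, and the function name here is hypothetical:

import os
import torch

def map_gaudi2_scale(scale: torch.Tensor) -> torch.Tensor:
    # Snap a computed scale to the nearest hardware-friendly bucket (16, 1, 1/16, 1/256),
    # mirroring _map_guadi2_scale; only active when USE_GUADI2_SCALE is set in the environment.
    if os.environ.get("USE_GUADI2_SCALE"):
        for candidate in torch.tensor([16.0, 1.0, 1 / 16, 1 / 256]):
            if scale > candidate or candidate == torch.tensor(1 / 256):
                return candidate
    return scale

print(map_gaudi2_scale(torch.tensor(3.7)))  # 1.0 when USE_GUADI2_SCALE is set, otherwise 3.7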

neural_compressor/torch/algorithms/habana_fp8/observer.py (+1 -3)

@@ -17,9 +17,7 @@
 import torch
 from torch.ao.quantization.observer import *
 
-# without scale factor 0.9, the output will be abnormal.
-E4M3_AMAX = torch.tensor(240 * 0.9, dtype=torch.float).to("hpu")
-E5M2_AMAX = torch.tensor(57344 * 0.9, dtype=torch.float).to("hpu")
+from .modules import E4M3_AMAX, E5M2_AMAX
 
 
 class FP8HistogramObserver(HistogramObserver):

neural_compressor/torch/algorithms/habana_fp8/save_load.py (+22 -5)

@@ -14,17 +14,27 @@
     FP8DynamicMatmul,
     FP8Cast,
 )
+from .fp8_quant import FP8_DTYPE, dtype_mapping
+
 
 def save(model, output_dir="./saved_results"):
     if not os.path.exists(output_dir):
         os.mkdir(output_dir)
     qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), "quantized_model.pt")
     qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), "qconfig.json")
     # saving process
-    torch.save(model.stat_dict(), qmodel_file_path)
-    logger.info("Save state_dict of quantized model to {}.".format(qmodel_file_path))
     with open(qconfig_file_path, "w") as f:
         json.dump(model.qconfig, f, indent=4)
+
+    import fp8_convert
+    stat_dict = {}
+    for k, v in model.state_dict().items():
+        if v.dtype in FP8_DTYPE:
+            v = fp8_convert.to_u8(v.to('cpu'))
+        stat_dict[k] = v.to('cpu')
+    torch.save(stat_dict, qmodel_file_path)
+
+    logger.info("Save state_dict of quantized model to {}.".format(qmodel_file_path))
     logger.info("Save configuration of quantized model to {}.".format(qconfig_file_path))
 
 
@@ -36,12 +46,17 @@ def load(model, output_dir="./saved_results"):
     with open(qconfig_file_path, "r") as f:
         model_qconfig = json.load(f)
     # load quantization configuration
-    from .fp8_quant import FP8_DTYPE
-    for (op_name, op_type), op_qconfig in model_qconfig.items():
-        dtype = op_qconfig['weight_dtype']
+    stat_dict = torch.load(qmodel_file_path)
+    import fp8_convert
+    for op_name, op_qconfig in model_qconfig["per_module_qconfig"].items():
+        dtype = op_qconfig['w_dtype']
+        choice = 1 if dtype=="fp8_e4m3" else 0
+        if op_name+".weight" in stat_dict:
+            stat_dict[op_name+".weight"] = fp8_convert.from_u8(stat_dict[op_name+".weight"], choice)
         if dtype not in FP8_DTYPE:
             continue
         module = fetch_module(model, op_name)
+        dtype = dtype_mapping[dtype]
         if op_qconfig['approach'] == "static":
             if isinstance(module, white_list):
                 QModule = quantization_mapping[type(module)]
@@ -58,5 +73,7 @@ def load(model, output_dir="./saved_results"):
             module = FP8Cast(dtype=dtype)
         set_module(model, op_name, module)
         htcore.mark_step()
+    model.load_state_dict(stat_dict)
+    htcore.mark_step()
     logger.info("Quantized model loading successful.")
     return model
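A minimal round-trip sketch of the new save/load flow introduced by this commit, assuming save and load are importable from the module path of the file above and that q_model/user_model come from the quantization flow shown earlier (both names are placeholders):

from neural_compressor.torch.algorithms.habana_fp8.save_load import save, load

# q_model: an FP8-quantized model returned by quantize()/quantize_dynamic()
save(q_model, output_dir="./saved_results")        # writes quantized_model.pt and qconfig.json
# user_model: a fresh FP32 instance of the same architecture, to be repopulated from disk
loaded_model = load(user_model, output_dir="./saved_results")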
New file (+13 -0)

@@ -0,0 +1,13 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
