
Commit 2b86e50

support hardware scale for gaudi2 (#1637)

Signed-off-by: xin3he <xin3.he@intel.com>
1 parent e6664b0

File tree: 3 files changed (+15, -5 lines)

  • examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py
  • neural_compressor/torch/algorithms/habana_fp8/modules.py
  • neural_compressor/torch/amp/fp8/functions.py

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py (+4, -2)

```diff
@@ -1,7 +1,8 @@
 import os
 os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "False"
 os.environ["USE_GAUDI2_SCALE"] = "True"
-os.environ.pop("USE_GAUDI2_SCALE")  # gaudi scale work
+# USE_GAUDI2_SCALE requires PT_USE_FP8_AMAX for torch.mm/bmm, or got failure
+os.environ["PT_USE_FP8_AMAX"] = "True"
 # os.environ["GRAPH_VISUALIZATION"] = "True"
 # import shutil
 # shutil.rmtree(".graph_dumps", ignore_errors=True)
@@ -173,7 +174,7 @@
     args.model,
     trust_remote_code=args.trust_remote_code
 )
-
+tokenizer.pad_token = tokenizer.eos_token

 user_model.eval()

@@ -219,6 +220,7 @@ def calib_func(model):

 user_model = quantize(user_model, qconfig, calib_func, inplace=True)
 # saving
+print(user_model)
 if args.save and local_rank in [-1, 0]:
     user_model.save("saved_results")

```
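The script change encodes a coupling between the two flags: per the new comment, enabling USE_GAUDI2_SCALE without PT_USE_FP8_AMAX fails for torch.mm/bmm. Below is a minimal sketch of a guard one could run before quantization; the helper name `check_gaudi2_scale_env` is hypothetical and not part of this commit:

```python
import os

def check_gaudi2_scale_env():
    # Hypothetical guard: USE_GAUDI2_SCALE depends on PT_USE_FP8_AMAX
    # for torch.mm/bmm, as noted in the comment added above.
    if os.getenv("USE_GAUDI2_SCALE") and not os.getenv("PT_USE_FP8_AMAX"):
        raise RuntimeError(
            "USE_GAUDI2_SCALE requires PT_USE_FP8_AMAX for torch.mm/bmm; "
            "set os.environ['PT_USE_FP8_AMAX'] = 'True' before quantizing."
        )

check_gaudi2_scale_env()  # no-op when both flags are set, as in run_llm.py
```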

neural_compressor/torch/algorithms/habana_fp8/modules.py (+4, -1)

```diff
@@ -55,7 +55,7 @@ def forward(self, x):

 ##################### FP8 modules #######################
 def _map_guadi2_scale(scale):
-    USE_GAUDI2_SCALE = os.environ.get("USE_GAUDI2_SCALE")
+    USE_GAUDI2_SCALE = bool(os.getenv("USE_GAUDI2_SCALE", False))
     if USE_GAUDI2_SCALE:
         scale_list = torch.tensor([16, 1, 1 / 16, 1 / 256])
         for i in scale_list:
@@ -135,6 +135,7 @@ def forward(self, inp):
         if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
             if self.use_amax:
                 input_scale = self.dtype_amax / inp.abs().max()
+                input_scale = _map_guadi2_scale(input_scale)
                 input_scale_inv = torch.reciprocal(input_scale)
             else:
                 input_scale, input_scale_inv = None, None
@@ -183,6 +184,7 @@ def forward(self, input1, input2):
         self.out_dtype = input1.dtype
         if self.use_amax:
             input1_scale = self.dtype_amax / input1.data.abs().max()
+            input1_scale = _map_guadi2_scale(input1_scale)
             input1_scale_inv = torch.reciprocal(input1_scale)
         else:
             input1_scale, input1_scale_inv = None, None
@@ -195,6 +197,7 @@ def forward(self, input1, input2):
         self.out_dtype = input2.dtype
         if self.use_amax:
             input2_scale = self.dtype_amax / input2.data.abs().max()
+            input2_scale = _map_guadi2_scale(input2_scale)
             input2_scale_inv = torch.reciprocal(input2_scale)
         else:
             input2_scale, input2_scale_inv = None, None
```
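The helper `_map_guadi2_scale` snaps an amax-derived scale onto the small set of scale values the Gaudi2 FP8 hardware supports (the `scale_list` in the first hunk); the loop body is truncated in this diff. A plausible self-contained sketch, assuming the function returns the largest supported value not exceeding the computed scale and passes the scale through unchanged when the flag is off:

```python
import os
import torch

def _map_guadi2_scale(scale):
    # Sketch only: the loop body is not visible in the diff above.
    # Assumes "largest hardware-supported scale <= computed scale",
    # falling back to the smallest supported value.
    if bool(os.getenv("USE_GAUDI2_SCALE", False)):
        scale_list = torch.tensor([16, 1, 1 / 16, 1 / 256])  # descending, per the diff
        for candidate in scale_list:
            if scale >= candidate:
                return candidate
        return scale_list[-1]
    return scale  # flag unset: keep the exact amax-based scale

# Example: a computed scale of 3.7 maps to the supported value 1.0.
os.environ["USE_GAUDI2_SCALE"] = "1"
print(_map_guadi2_scale(torch.tensor(3.7)))  # tensor(1.)
```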

neural_compressor/torch/amp/fp8/functions.py (+7, -2)

```diff
@@ -20,7 +20,8 @@
 import torch
 from torch.nn import functional as F

-from neural_compressor.common import logger
+from neural_compressor.torch.algorithms.habana_fp8.modules import _map_guadi2_scale
+from neural_compressor.torch.utils import logger

 _F_linear = F.linear
 _torch_matmul = torch.matmul
@@ -32,7 +33,7 @@
 E5M2_AMAX = torch.tensor(57344, dtype=torch.float).to("hpu")

 DTYPE_AMAX = E4M3_AMAX if DATA_TYPE == torch.float8_e4m3fn else E5M2_AMAX
-USE_AMAX = False if os.getenv("PT_USE_FP8_AMAX") is None else True
+USE_AMAX = bool(os.getenv("PT_USE_FP8_AMAX", False))


 def fp8_linear_forward(input, weight, bias=None):
@@ -44,6 +45,7 @@ def fp8_linear_forward(input, weight, bias=None):
     out_dtype = input.dtype
     if USE_AMAX:
         input_scale = DTYPE_AMAX / input.data.abs().max()
+        input_scale = _map_guadi2_scale(input_scale)
         input_scale_inv = torch.reciprocal(input_scale)
     else:
         input_scale, input_scale_inv = None, None
@@ -56,6 +58,7 @@ def fp8_linear_forward(input, weight, bias=None):
     out_dtype = weight.dtype
     if USE_AMAX:
         weight_scale = DTYPE_AMAX / weight.data.abs().max()
+        weight_scale = _map_guadi2_scale(weight_scale)
         weight_scale_inv = torch.reciprocal(weight_scale)
     else:
         weight_scale, weight_scale_inv = None, None
@@ -86,6 +89,7 @@ def fp8_matmul(input1, input2):
     out_dtype = input1.dtype
     if USE_AMAX:
         input1_scale = DTYPE_AMAX / input1.data.abs().max()
+        input1_scale = _map_guadi2_scale(input1_scale)
         input1_scale_inv = torch.reciprocal(input1_scale)
     else:
         input1_scale, input1_scale_inv = None, None
@@ -98,6 +102,7 @@ def fp8_matmul(input1, input2):
     out_dtype = input2.dtype
     if USE_AMAX:
         input2_scale = DTYPE_AMAX / input2.data.abs().max()
+        input2_scale = _map_guadi2_scale(input2_scale)
         input2_scale_inv = torch.reciprocal(input2_scale)
     else:
         input2_scale, input2_scale_inv = None, None
```
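Every patched call site follows the same pattern: derive the scale from the tensor's absolute maximum (scale = DTYPE_AMAX / max|x|), snap it with `_map_guadi2_scale` on Gaudi2, and keep the reciprocal for dequantization. A standalone sketch of that round trip, using the standard E4M3 maximum of 448 as a stand-in since the file's E4M3_AMAX constant is defined outside the hunks shown (runs on CPU with a PyTorch build that has float8 dtypes; no HPU needed for the illustration):

```python
import torch

E4M3_AMAX = torch.tensor(448.0)  # stand-in; the real E4M3_AMAX is defined above the shown hunks

def fp8_round_trip(x, dtype_amax=E4M3_AMAX):
    # Scale maps the tensor's absolute maximum onto the FP8 format's max magnitude.
    scale = dtype_amax / x.abs().max()
    # On Gaudi2 this is where _map_guadi2_scale(scale) would snap the value.
    scale_inv = torch.reciprocal(scale)
    x_fp8 = (x * scale).to(torch.float8_e4m3fn)   # quantize
    return x_fp8.to(x.dtype) * scale_inv          # dequantize for comparison

x = torch.randn(4, 4)
print((x - fp8_round_trip(x)).abs().max())  # small quantization error
```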
