+ import os
+ os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "False"
+ os.environ["USE_GAUDI2_SCALE"] = "True"
+ os.environ.pop("USE_GAUDI2_SCALE")  # gaudi scale work
+ # os.environ["GRAPH_VISUALIZATION"] = "True"
+ # import shutil
+ # shutil.rmtree(".graph_dumps", ignore_errors=True)
import argparse
import time
import json
import re
import torch
- import transformers
- import os
- import deepspeed
- from transformers import AutoModelForCausalLM, AutoTokenizer
import habana_frameworks.torch.hpex
- from habana_frameworks.torch.hpu import memory_stats
+ import torch.nn.functional as F
+ import deepspeed
+ import transformers
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+ import habana_frameworks.torch.core as htcore
import numpy as np
import lm_eval
import lm_eval.tasks
import lm_eval.evaluator
- torch.set_grad_enabled(False)
+ from accelerate import init_empty_weights
+ from utils import itrex_bootstrap_stderr, show_msg, save_to_excel


- def itrex_bootstrap_stderr(f, xs, iters):
-     from lm_eval.metrics import _bootstrap_internal, sample_stddev
-     res = []
-     chunk_size = min(1000, iters)
-     it = _bootstrap_internal(f, chunk_size)
-     for i in range(iters // chunk_size):
-         bootstrap = it((i, xs))
-         res.extend(bootstrap)
-     return sample_stddev(res)
+ torch.set_grad_enabled(False)
+ htcore.hpu_set_env()
+ torch.device('hpu')
+

# to avoid out-of-memory caused by Popen for large language models.
lm_eval.metrics.bootstrap_stderr = itrex_bootstrap_stderr
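# Note: lm_eval.metrics.bootstrap_stderr is monkey-patched above so that stderr is
# computed in-process instead of through a multiprocessing Pool (Popen), which can
# run out of memory with large models. A minimal sketch of the helper now imported
# from utils, assuming it keeps the body this diff removes:
def itrex_bootstrap_stderr(f, xs, iters):
    from lm_eval.metrics import _bootstrap_internal, sample_stddev
    res = []
    chunk_size = min(1000, iters)
    it = _bootstrap_internal(f, chunk_size)
    for i in range(iters // chunk_size):
        res.extend(it((i, xs)))
    return sample_stddev(res)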
@@ -51,22 +54,26 @@ def itrex_bootstrap_stderr(f, xs, iters):
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--performance", action="store_true")
parser.add_argument("--generate", action="store_true")
+ parser.add_argument("--skip_fp8_mm", action="store_true")
+ parser.add_argument("--dump_to_excel", action="store_true")
+ parser.add_argument("--save", action="store_true")
+ parser.add_argument("--load", action="store_true")
parser.add_argument("--batch_size", default=1, type=int,
                    help="For accuracy measurement only.")
parser.add_argument("--pad_max_length", default=512, type=int,
                    help="Pad input ids to max length.")
parser.add_argument("--calib_iters", default=100, type=int,
                    help="calibration iters.")
- parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], type=str, \
-                     choices=["winogrande", "copa", "piqa", "rte", "hellaswag", \
-                              "openbookqa", "lambada_openai", "lambada_standard", "wikitext"],
+ parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], \
+                     type=str, choices=["hellaswag", "lambada_openai", "piqa", "winogrande", "copa",
+                                        "rte", "openbookqa", "lambada_standard", "wikitext"],
                    help="tasks list for accuracy validation")
parser.add_argument("--limit", default=None, type=int,
                    help="the sample num of evaluation.")
parser.add_argument("--max_new_tokens", default=100, type=int,
                    help="max number of new tokens to generate.")
parser.add_argument('--buckets', type=int, nargs='+', \
-                    help="Input length buckets to use with static_shapes", default=[129])
+                    help="Input length buckets to use with static_shapes", default=[256, 512])
parser.add_argument("--local_rank",
                    type=int,
                    default=-1,
@@ -78,67 +85,65 @@ def itrex_bootstrap_stderr(f, xs, iters):
world_size = int(os.getenv('WORLD_SIZE', '1'))
local_rank = int(os.getenv('LOCAL_RANK', '-1'))

- #if local_rank == 0:
- #    os.environ["ENABLE_CONSOLE"] = 'True'
- #    os.environ["LOG_LEVEL_ALL"] = '0'

- # model
+ model_dtype = torch.float32
if re.search("llama", args.model.lower()) or re.search("bloom", args.model.lower()):
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-     torch.device('hpu')
-     config = AutoConfig.from_pretrained(args.model)
    if world_size > 1:
-         model_dtype = torch.bfloat16
+         config = AutoConfig.from_pretrained(args.model)
+         model_dtype = torch.bfloat16  # RuntimeError: CastToFp8V2 input must be of float or bfloat16 dtype
        deepspeed.init_distributed(dist_backend="hccl")
        with deepspeed.OnDevice(dtype=model_dtype, device="meta"):
            user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
        import tempfile
        checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
-         from utils import write_checkpoints_json
+         from optimum.habana.checkpoint_utils import write_checkpoints_json  # in optimum-habana
        write_checkpoints_json(
-                 args.model,
-                 local_rank,
-                 checkpoints_json,
-                 token=None,
+             args.model,
+             local_rank,
+             checkpoints_json,
+             token=None,
        )
-     elif re.search("llama", args.model.lower()):
-         from models.modeling_llama import LlamaForCausalLM
-         user_model = LlamaForCausalLM.from_pretrained(
+     else:
+         if args.load:
+             config = AutoConfig.from_pretrained(args.model)
+             with init_empty_weights():
+                 user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
+         else:
+             user_model = AutoModelForCausalLM.from_pretrained(
+                 args.model,
+                 device_map='hpu',
+                 torch_dtype=model_dtype,
+             )
+ elif re.search("chatglm", args.model.lower()):
+     if args.load:
+         config = AutoConfig.from_pretrained(args.model, torch_dtype=model_dtype)
+         with init_empty_weights():
+             user_model = AutoModelForCausalLM.from_config(config)
+     else:
+         from models.modeling_chatglm import ChatGLMForConditionalGeneration
+         user_model = ChatGLMForConditionalGeneration.from_pretrained(
            args.model,
+             revision=args.revision,
            device_map='hpu',
+             torch_dtype=model_dtype,
        )
+     # print(user_model.transformer.output_layer.weight.dtype)  # always fp16
+     user_model.float()  # static fp8 need float32 for graph compiler
+ else:
+     if args.load:
+         config = AutoConfig.from_pretrained(args.model)
+         with init_empty_weights():
+             user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)
    else:
        user_model = AutoModelForCausalLM.from_pretrained(
            args.model,
+             trust_remote_code=args.trust_remote_code,
+             revision=args.revision,
            device_map='hpu',
+             torch_dtype=model_dtype,
        )
- elif re.search("chatglm", args.model.lower()):
-     from models.modeling_chatglm import ChatGLMForConditionalGeneration
-     user_model = ChatGLMForConditionalGeneration.from_pretrained(
-         args.model,
-         revision=args.revision,
-         device_map='hpu',
-     )
- else:
-     user_model = AutoModelForCausalLM.from_pretrained(
-         args.model,
-         trust_remote_code=args.trust_remote_code,
-         revision=args.revision,
-         device_map='hpu',
-     )

- # tokenizer
- if re.search("baichuan", args.model.lower()):
-     from models.tokenization_baichuan import BaichuanTokenizer
-     tokenizer = BaichuanTokenizer.from_pretrained(
-         args.model,
-         trust_remote_code=args.trust_remote_code
-     )
- else:
-     tokenizer = AutoTokenizer.from_pretrained(
-         args.model,
-         trust_remote_code=args.trust_remote_code
-     )

if world_size > 1:
    if re.search("llama", args.model.lower()):
@@ -148,36 +153,44 @@ def itrex_bootstrap_stderr(f, xs, iters):
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer
        ds_inference_kwargs["injection_policy"] = {LlamaDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")}
        ds_inference_kwargs["checkpoint"] = checkpoints_json.name
-
        ds_model = deepspeed.init_inference(user_model, **ds_inference_kwargs)
    else:
        ds_model = deepspeed.init_inference(user_model,
                                            mp_size=world_size,
                                            replace_with_kernel_inject=False)
    user_model = ds_model.module

+
+ # tokenizer
+ if re.search("baichuan", args.model.lower()):
+     from models.tokenization_baichuan import BaichuanTokenizer
+     tokenizer = BaichuanTokenizer.from_pretrained(
+         args.model,
+         trust_remote_code=args.trust_remote_code
+     )
+ else:
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.model,
+         trust_remote_code=args.trust_remote_code
+     )
+
+
user_model.eval()

- if args.approach in ["dynamic", "static"]:
+
+ ### dynamic & static quantization ###
+ if args.approach in ["dynamic", "static"] and not args.load:
    print("device:", next(user_model.parameters()).device)
-     from neural_compressor.torch.quantization.config import FP8QConfig, get_default_fp8_qconfig
-     from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic
+     from neural_compressor.torch.quantization.config import FP8Config, get_default_fp8_config
    from neural_compressor.torch.quantization import quantize
-     if args.precision == "fp8_e4m3":
-         dtype = torch.float8_e4m3fn
-     else:
-         dtype = torch.float8_e5m2
+     dtype = args.precision
    if args.approach == "dynamic":
-         #user_model = quantize_dynamic(user_model, dtype, inplace=True)
-         qconfig = FP8QConfig(weight_dtype=dtype, act_dtype=dtype, approach="dynamic")
-         if args.skip_lm_head:
-             fp32_config = FP8QConfig(weight_dtype=torch.float32, act_dtype=torch.float32)
-             qconfig.set_local("lm_head", fp32_config)
-         user_model = quantize_dynamic(user_model, qconfig, inplace=True)
+         from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic
+         user_model = quantize_dynamic(user_model, dtype, inplace=True)
    elif args.approach == "static":
-         qconfig = FP8QConfig(weight_dtype=dtype, act_dtype=dtype, approach="static")
+         qconfig = FP8Config(w_dtype=dtype, act_dtype=dtype, approach="static")
        if args.skip_lm_head:
-             fp32_config = FP8QConfig(weight_dtype=torch.float32, act_dtype=torch.float32)
+             fp32_config = FP8Config(w_dtype="fp32", act_dtype="fp32")
            qconfig.set_local("lm_head", fp32_config)
        # dataset
        from datasets import load_dataset
@@ -186,7 +199,13 @@ def itrex_bootstrap_stderr(f, xs, iters):
        calib_data = []
        for examples in calib_dataset:
            calib_data.append(
-                 tokenizer(examples["text"], return_tensors="pt", max_length=128)
+                 tokenizer(
+                     examples["text"],
+                     return_tensors="pt",
+                     max_length=64,
+                     padding="max_length",
+                     truncation=True
+                 )
            )

        def calib_func(model):
@@ -199,12 +218,46 @@ def calib_func(model):
                )

        user_model = quantize(user_model, qconfig, calib_func, inplace=True)
-     print(user_model, flush=True)
+     # saving
+     if args.save and local_rank in [-1, 0]:
+         user_model.save("saved_results")
+
+
+ if args.load:
+     from neural_compressor.torch.quantization import load
+     user_model = load(user_model, "saved_results")
+
+
+ if args.approach in ["dynamic", "static"] or args.load:
+     # It enables weights constant folding
+     from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const
+     _mark_params_as_const(user_model)  # can reduce memory allocated and speed up
+     _check_params_as_const(user_model)
+
+
+
+ # If torch.matmul and torch.bmm are not replaced by an INC module,
+ # the code below makes torch.matmul and torch.bmm run in fp8 by injection.
+ if not args.skip_fp8_mm and args.precision in ['fp8_e4m3', 'fp8_e5m2']:
+     def replace_torch_mm_bmm():
+         from neural_compressor.torch.amp.fp8.functions import fp8_matmul
+         torch.matmul = fp8_matmul
+         torch.bmm = fp8_matmul

+     replace_torch_mm_bmm()
+
+
+ # inference optimization
if args.to_graph:
    import habana_frameworks.torch.hpu.graphs as htgraphs
    user_model = htgraphs.wrap_in_hpu_graph(user_model)

+
+ # dump message of HPU after quantization or reloading
+ show_msg()
+
+
+ ### generation, performance and accuracy validation ###
if args.generate:
    input_prompt = "Here is my prompt"
    print("Prompt sentence:", input_prompt)
@@ -234,6 +287,7 @@ def calib_func(model):
    print("Generated sentence:", output_sentence)
    print("Duration:", eval_end - eval_start)

+
if args.performance:
    eval_start = time.perf_counter()
    input_prompt = "Intel is a company which"
@@ -242,6 +296,7 @@ def calib_func(model):
    outputs = user_model.generate(input_tokens, **generation_config)
    print("Duration of generating 100 tokens :", time.perf_counter() - eval_start)

+
if args.accuracy:

    class HabanaModelAdapter(lm_eval.base.BaseLM):
@@ -292,16 +347,14 @@ def find_bucket(self, length):
            return [b for b in self.buckets if b >= length][0]

        def _model_call(self, inps):
-             #print(inps.shape)
            seq_length = inps.shape[-1]
+             padding_length = 0
            bucket_length = self.find_bucket(seq_length)
            padding_length = bucket_length - seq_length
-             if True:
-                 import torch.nn.functional as F
-                 inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
+             inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id)
+             logits = self.model(inps.to(self._device))["logits"].cpu()

-             logits = self.model(inps.to(self._device))['logits']
-             if True and padding_length > 0:
+             if padding_length > 0:
                logits = logits[:, :-padding_length, :]
            logits = logits.to(torch.float32)
            return logits
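# A minimal, self-contained sketch of the bucket-pad-then-trim pattern used by
# _model_call above (hypothetical values; 0 stands in for pad_token_id): inputs are
# padded up to the nearest bucket so the HPU sees only a few static shapes, and the
# padded positions are trimmed from the logits afterwards.
import torch
import torch.nn.functional as F

buckets = [256, 512]                                              # the new --buckets default
inps = torch.randint(0, 100, (1, 300))                            # hypothetical batch, seq_length = 300
bucket_length = [b for b in buckets if b >= inps.shape[-1]][0]    # -> 512
padding_length = bucket_length - inps.shape[-1]
padded = F.pad(inps, (0, padding_length), value=0)                # value would be pad_token_id
# ... run the model on `padded`, then drop the padded tail:
# logits = logits[:, :-padding_length, :]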
@@ -333,18 +386,18 @@ def _model_call(self, inps):


    dumped = json.dumps(results, indent=2)
+     accu_dict = {}
+     case_name = args.approach + "-" + args.precision
    for task_name in args.tasks:
        if task_name == "wikitext":
            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]), flush=True)
+             accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["word_perplexity"]]
        else:
            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]), flush=True)
+             accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["acc"]]
+     if args.dump_to_excel and local_rank in [-1, 0]:
+         save_to_excel(accu_dict)
+

-     # show memory usage
-     mem_stats = memory_stats()
-     mem_dict = {
-         "memory_allocated (GB)": np.round(mem_stats["InUse"] / 1024**3, 2),
-         "max_memory_allocated (GB)": np.round(mem_stats["MaxInUse"] / 1024**3, 2),
-         "total_memory_available (GB)": np.round(mem_stats["Limit"] / 1024**3, 2),
-     }
-     for k, v in mem_dict.items():
-         print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v))
+     # dump final message of HPU
+     show_msg()
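# show_msg() is imported from the example's utils module, which this diff does not
# show; a sketch under the assumption that it wraps the HPU memory dump removed
# above (the exact helper may differ):
import numpy as np
from habana_frameworks.torch.hpu import memory_stats

def show_msg():
    mem_stats = memory_stats()
    mem_dict = {
        "memory_allocated (GB)": np.round(mem_stats["InUse"] / 1024**3, 2),
        "max_memory_allocated (GB)": np.round(mem_stats["MaxInUse"] / 1024**3, 2),
        "total_memory_available (GB)": np.round(mem_stats["Limit"] / 1024**3, 2),
    }
    for k, v in mem_dict.items():
        print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v))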