import os
import argparse
import tqdm

# Ensure that unnecessary memory is released during quantization.
os.environ.setdefault("PT_HPU_WEIGHT_SHARING", "0")
if int(os.getenv("WORLD_SIZE", "0")) > 0:
    os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0")
    os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "true")


import torch
import habana_frameworks.torch.core as htcore

from neural_compressor.torch.quantization import (
    FP8Config,
    prepare,
    convert,
    finalize_calibration,
    save,
    load,
)
from neural_compressor.torch.utils import get_used_hpu_mem_MB, get_used_cpu_mem_MB, logger, forward_wrapper
from neural_compressor.torch.utils.block_wise import block_wise_calibration
from neural_compressor.torch.utils.llm_utility import (
    initialize_model_and_tokenizer,
    get_default_llm_dataloader,
    llm_benchmark,
)

# Use no_grad mode for quantization.
torch.set_grad_enabled(False)
htcore.hpu_set_env()
hpu_mem_0 = get_used_hpu_mem_MB()
cpu_mem_0 = get_used_cpu_mem_MB()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Habana FP8 quantization.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--model_name_or_path", type=str, default="meta-llama/Meta-Llama-3.1-405B", help="model name or path")
    parser.add_argument("--quantize", action="store_true", help="whether to quantize the model")
    parser.add_argument("--scale_method", type=str, default="maxabs_hw", help="scale method", choices=[
        # per-tensor
        "unit_scale", "hw_aligned_single_scale", "maxabs_hw", "maxabs_pow2",
        "maxabs_arbitrary", "maxabs_hw_opt_weight", "maxabs_pow2_opt_weight",
        # per-channel
        "act_maxabs_hw_weights_pcs_maxabs_pow2", "act_maxabs_hw_weights_pcs_opt_pow2",
        "act_maxabs_pow2_weights_pcs_maxabs_pow2", "act_maxabs_pow2_weights_pcs_opt_pow2",
    ])
    parser.add_argument("--use_hpu_graph", action="store_true", help="whether to use HPU graph mode to accelerate performance")
    parser.add_argument("--enable_block_wise_calibration", action="store_true", help="whether to use block-wise calibration")
    parser.add_argument("--disable_optimum_habana", action="store_true", help="whether to disable adapt_transformers_to_gaudi")
    parser.add_argument("--save", action="store_true", help="whether to save the quantized model")
    parser.add_argument("--load", action="store_true", help="whether to load the quantized model")
    parser.add_argument("--save_path", type=str, default="saved_results", help="path to save the quantized model")
    parser.add_argument("--accuracy", action="store_true", help="accuracy measurement")
    parser.add_argument("--performance", action="store_true", help="performance measurement")
    parser.add_argument("--local_rank", type=int, default=0, metavar="N", help="local process rank")
    parser.add_argument("--batch_size", default=1, type=int, help="batch size for the calibration dataloader and accuracy/performance measurement")
    parser.add_argument("--num_fewshot", default=0, type=int, help="num_fewshot for lm_eval")
    parser.add_argument("--dump_stats_path", type=str, default="./hqt_output/measure", help="path and prefix of the calibration info file")
    parser.add_argument("--tasks", default="lambada_openai", type=str,
                        help="tasks for accuracy validation; text-generation and code-generation tasks are different")
    parser.add_argument("--dataset_name", type=str, default="NeelNanda/pile-10k", help="dataset name for the calibration dataloader")
    parser.add_argument("--nsamples", type=int, default=128, help="number of samples for the calibration dataloader")
    parser.add_argument("--seq_len", type=int, default=128, help="sequence length for the calibration dataloader and benchmarking")
    args = parser.parse_args()
    if not args.disable_optimum_habana:
        # Tweak generation so that it runs faster on Gaudi.
        import transformers
        from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

        if args.quantize:
            orig_check_support_param_buffer_assignment = transformers.modeling_utils.check_support_param_buffer_assignment
            adapt_transformers_to_gaudi()
            # Restore the original function to protect memory-mapping usage during quantization.
            transformers.modeling_utils.check_support_param_buffer_assignment = orig_check_support_param_buffer_assignment
        else:
            adapt_transformers_to_gaudi()

    model, tokenizer = initialize_model_and_tokenizer(args.model_name_or_path, use_load=args.load, device="hpu")
    # show used memory
    logger.info(f"After loading model, used HPU memory: {round((get_used_hpu_mem_MB() - hpu_mem_0)/1024, 3)} GiB")
    logger.info(f"After loading model, used CPU memory: {round((get_used_cpu_mem_MB() - cpu_mem_0)/1024, 3)} GiB")

    if args.quantize:
        if args.enable_block_wise_calibration:
            logger.warning("Block-wise calibration is enabled, lm_head will be excluded from calibration.")

        # prepare
        qconfig = FP8Config(
            fp8_config="E4M3",
            scale_method=args.scale_method,
            blocklist={"names": ["lm_head"]} if args.enable_block_wise_calibration else {},  # block-wise calibration cannot calibrate lm_head
            measure_on_hpu=not args.enable_block_wise_calibration,  # measure on CPU to avoid device mapping of the model
            dump_stats_path=args.dump_stats_path,
        )
        if args.scale_method in ["unit_scale", "hw_aligned_single_scale"]:
            # these scale methods need no calibration data
            model = convert(model, qconfig)
        else:
            model = prepare(model, qconfig)

            # calibration
            dataloader = get_default_llm_dataloader(
                tokenizer,
                dataset_name=args.dataset_name,
                bs=args.batch_size,
                nsamples=args.nsamples,
                seq_len=args.seq_len,
                seed=42,
            )
            if args.enable_block_wise_calibration:
                block_wise_calibration(model, dataloader)
            else:
                if args.use_hpu_graph:
                    from habana_frameworks.torch.hpu import wrap_in_hpu_graph
                    model = wrap_in_hpu_graph(model)
                logger.info("Calibration started")
                for data in tqdm.tqdm(dataloader):
                    forward_wrapper(model, data)
                logger.info("Calibration done")

            # convert
            model = convert(model)

        # show used memory
        logger.info(f"Used HPU memory: {round((get_used_hpu_mem_MB() - hpu_mem_0)/1024, 3)} GiB")
        logger.info(f"Used CPU memory: {round((get_used_cpu_mem_MB() - cpu_mem_0)/1024, 3)} GiB")
        if args.save:
            logger.info(f"Saving quantized model to {args.save_path}")
            save(model, args.save_path, format="huggingface")
            tokenizer.save_pretrained(args.save_path)
            logger.info(f"Saved quantized model to {args.save_path}")
            exit(0)  # the model is wrapped during calibration, so exit before accuracy and performance measurement

    # preprocess model for accuracy and performance measurement
    if not args.load:
        # compare FP8 with BF16, not FP32
        model = model.to(torch.bfloat16)
    model = model.eval().to("hpu")
    if args.use_hpu_graph:
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph
        model = wrap_in_hpu_graph(model)
    htcore.hpu_inference_initialize(model, mark_only_scales_as_const=True)

    if args.accuracy:
        from neural_compressor.evaluation.lm_eval import evaluate, LMEvalParser
        eval_args = LMEvalParser(
            model="hf",
            user_model=model,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            tasks=args.tasks,
            device="hpu",
            pad_to_buckets=True,
            num_fewshot=args.num_fewshot,
        )
        results = evaluate(eval_args)
        # show used memory
        logger.info(f"Used HPU memory: {round((get_used_hpu_mem_MB() - hpu_mem_0)/1024, 3)} GiB")
        logger.info(f"Used CPU memory: {round((get_used_cpu_mem_MB() - cpu_mem_0)/1024, 3)} GiB")

    if args.performance:
        llm_benchmark(model, args.batch_size, args.seq_len)
        # show used memory
        logger.info(f"Used HPU memory: {round((get_used_hpu_mem_MB() - hpu_mem_0)/1024, 3)} GiB")
        logger.info(f"Used CPU memory: {round((get_used_cpu_mem_MB() - cpu_mem_0)/1024, 3)} GiB")
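
# Illustrative usage sketch based on the arguments defined above (the script filename
# "quantize.py" is a placeholder, not part of the source, and the exact launch command
# may differ for multi-card runs):
#
#   # calibrate on the default dataset, convert the model to FP8, and save the result
#   python quantize.py --model_name_or_path meta-llama/Meta-Llama-3.1-405B \
#       --quantize --scale_method maxabs_hw --use_hpu_graph --save --save_path saved_results
#
#   # reload the saved FP8 model and run accuracy and performance measurement
#   # (assumes initialize_model_and_tokenizer loads from --model_name_or_path when --load is set)
#   python quantize.py --model_name_or_path saved_results --load --use_hpu_graph --accuracy --performance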