diff --git a/.gitignore b/.gitignore
index 3bb3247..2e48ba4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,7 @@ venv3/
 logs/
 results/
 .ipynb_checkpoints/
+
+kernels/int_quantization*
+kernels/build/
+kernels/dist/
\ No newline at end of file
diff --git a/inference/inference_sim.py b/inference/inference_sim.py
index 858af71..ac952e7 100644
--- a/inference/inference_sim.py
+++ b/inference/inference_sim.py
@@ -1,7 +1,14 @@
 import os, sys
-dir_path = os.path.dirname(os.path.realpath(__file__))
-root_dir = os.path.join(dir_path, os.path.pardir)
-sys.path.append(root_dir)
+
+# dir_path = os.path.dirname(os.path.realpath(__file__))
+# root_dir = os.path.join(dir_path, os.path.pardir)
+# sys.path.append(root_dir)
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
+sys.path.insert(0, parent_dir)
+
+
 import argparse
 import time
 import logging
@@ -127,8 +134,13 @@ torch.manual_seed(12345)
 
 
+
 class InferenceModel:
     def __init__(self, ml_logger=None):
+
+
+        self.onnx_save = True
+
         self.ml_logger = ml_logger
         global args, best_prec1
@@ -229,6 +241,7 @@ def __init__(self, ml_logger=None):
             num_workers=args.workers, pin_memory=True)
 
     def run(self):
+
         if args.eval_precision:
             elog = EvalLog(['dtype', 'val_prec1', 'val_prec5'])
             print("\nFloat32 no quantization")
@@ -274,8 +287,8 @@ def run(self):
         return val_loss, val_prec1, val_prec5
 
 
-
 def validate(val_loader, model, criterion):
+    onnx_save = True
     batch_time = AverageMeter()
     losses = AverageMeter()
     top1 = AverageMeter()
@@ -304,6 +317,14 @@ def validate(val_loader, model, criterion):
             QM().verbose = True
         input = input.to(args.device)
         target = target.to(args.device)
+        if i == 0 and onnx_save == True:
+            onnx_save = False
+            quantized_model_path = 'quantized_model.pth'
+            quantized_model_path_onnx = 'quantized_model.onnx'
+            torch.onnx.export(model, input, quantized_model_path_onnx)
+            torch.save(model.state_dict(), quantized_model_path)
+            print(f"Quantized model saved to {quantized_model_path}")
+
        if args.dump_dir is not None and i == 5:
             with DM(args.dump_dir):
                 DM().set_tag('batch%d'%i)
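
Note on the ONNX export added in validate(): the patch traces the quantized model with the first validation batch and writes both an .onnx file and a state_dict checkpoint. The sketch below shows the same idea in isolation; it is not part of the patch. The export_once helper, the stand-in nn.Sequential model, the dummy input, and the input/output names and dynamic batch axis are illustrative assumptions, while the file names mirror the ones used in the patch.

import torch
import torch.nn as nn

def export_once(model, sample_input,
                onnx_path='quantized_model.onnx',
                state_path='quantized_model.pth'):
    # Trace the model with one sample batch and write both artifacts to disk.
    model.eval()
    with torch.no_grad():
        torch.onnx.export(
            model, sample_input, onnx_path,
            input_names=['input'], output_names=['output'],
            dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})
    torch.save(model.state_dict(), state_path)
    print(f"Model exported to {onnx_path}, weights saved to {state_path}")

if __name__ == '__main__':
    # Stand-in model and input; the patch instead exports the quantized
    # ImageNet model it is validating, using the first batch of val_loader.
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(),
                        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
    dummy = torch.randn(1, 3, 224, 224)
    export_once(net, dummy)

Exporting inside the first loop iteration (and flipping the onnx_save flag) keeps the export on the exact tensors the model sees during evaluation, at the cost of tying the ONNX graph to that batch size unless dynamic axes are declared as above.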