
Commit e87c95f

Kaihui-intel and pre-commit-ci[bot] authored Apr 23, 2024
Fix weight_only algorithms import (#1742)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0ba5732 commit e87c95f

7 files changed, +31 -34 lines changed
 

‎examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/requirements.txt

+1 -1

@@ -9,5 +9,5 @@ wandb
 einops
 neural-compressor
 intel-extension-for-transformers
-git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
+lm-eval
 peft

‎examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py

+21 -15

@@ -50,8 +50,7 @@
                     help="Pad input ids to max length.")
 parser.add_argument("--calib_iters", default=512, type=int,
                     help="calibration iters.")
-parser.add_argument("--tasks", nargs='+', default=["lambada_openai",
-                    "hellaswag", "winogrande", "piqa", "wikitext"],
+parser.add_argument("--tasks", default="lambada_openai,hellaswag,winogrande,piqa,wikitext",
                     type=str, help="tasks list for accuracy validation")
 parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
 # ============SmoothQuant configs==============
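The --tasks option now takes a single comma-separated string rather than an argparse list, so every consumer splits it before iterating. A minimal standalone sketch of the new pattern, using only the argument definition shown above (the sample parse_args input is hypothetical):

import argparse

parser = argparse.ArgumentParser()
# Comma-separated task string, matching the updated run_clm_no_trainer.py interface.
parser.add_argument("--tasks", default="lambada_openai,hellaswag,winogrande,piqa,wikitext",
                    type=str, help="tasks list for accuracy validation")
args = parser.parse_args(["--tasks", "lambada_openai,piqa"])

# Downstream loops iterate over the split string instead of a nargs='+' list.
for task_name in args.tasks.split(","):
    print(task_name)  # prints lambada_openai, then piqa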
@@ -390,24 +389,27 @@ def run_fn(model):
 
 if args.accuracy:
     user_model.eval()
-    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate
-
-    results = evaluate(
-        model="hf-causal",
+    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
+    eval_args = LMEvalParser(
+        model="hf",
         model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
         user_model=user_model,
+        tokenizer=tokenizer,
         batch_size=args.batch_size,
         tasks=args.tasks,
+        device="cpu",
     )
+    results = evaluate(eval_args)
+
     dumped = json.dumps(results, indent=2)
     if args.save_accuracy_path:
         with open(args.save_accuracy_path, "w") as f:
             f.write(dumped)
-    for task_name in args.tasks:
+    for task_name in args.tasks.split(","):
         if task_name == "wikitext":
-            acc = results["results"][task_name]["word_perplexity"]
+            acc = results["results"][task_name]["word_perplexity,none"]
         else:
-            acc = results["results"][task_name]["acc"]
+            acc = results["results"][task_name]["acc,none"]
         print("Accuracy: %.5f" % acc)
     print('Batch size = %d' % args.batch_size)
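With the newer lm-eval-style harness exposed through intel-extension-for-transformers, evaluation is configured through an LMEvalParser object and launched with a single evaluate(eval_args) call, and metric keys carry a ",none" filter suffix. A condensed sketch of that pattern, assuming the wrapper behaves exactly as the diff above uses it (model_name, user_model, and tokenizer are placeholders):

from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser

def eval_accuracy(user_model, tokenizer, model_name, tasks, batch_size=8):
    """Run the lm-eval harness on an in-memory model and return per-task scores."""
    eval_args = LMEvalParser(
        model="hf",  # generic HF backend replaces the old "hf-causal" name
        model_args="pretrained=" + model_name + ",tokenizer=" + model_name + ",dtype=float32",
        user_model=user_model,
        tokenizer=tokenizer,
        batch_size=batch_size,
        tasks=tasks,
        device="cpu",
    )
    results = evaluate(eval_args)
    scores = {}
    for task_name in tasks.split(","):
        # The harness appends the filter name ("none") to each metric key.
        key = "word_perplexity,none" if task_name == "wikitext" else "acc,none"
        scores[task_name] = results["results"][task_name][key]
    return scores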

@@ -417,21 +419,25 @@ def run_fn(model):
     import time
 
     samples = args.iters * args.batch_size
-    start = time.time()
-    results = evaluate(
-        model="hf-causal",
+    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
+    eval_args = LMEvalParser(
+        model="hf",
         model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
         user_model=user_model,
+        tokenizer=tokenizer,
         batch_size=args.batch_size,
         tasks=args.tasks,
         limit=samples,
+        device="cpu",
     )
+    start = time.time()
+    results = evaluate(eval_args)
     end = time.time()
-    for task_name in args.tasks:
+    for task_name in args.tasks.split(","):
         if task_name == "wikitext":
-            acc = results["results"][task_name]["word_perplexity"]
+            acc = results["results"][task_name]["word_perplexity,none"]
         else:
-            acc = results["results"][task_name]["acc"]
+            acc = results["results"][task_name]["acc,none"]
         print("Accuracy: %.5f" % acc)
     print('Throughput: %.3f samples/sec' % (samples / (end - start)))
     print('Latency: %.3f ms' % ((end - start) * 1000 / samples))
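The benchmark path builds the same LMEvalParser arguments but adds limit=samples, and the timer now starts after the parser object is constructed, so only the evaluate(eval_args) call is measured. A small sketch of the throughput/latency arithmetic (run_eval and samples are placeholders):

import time

def report_perf(run_eval, samples):
    """Time one evaluation pass and derive throughput and per-sample latency."""
    start = time.time()
    run_eval()  # e.g. a callable wrapping evaluate(eval_args) with limit=samples
    end = time.time()
    elapsed = end - start
    print('Throughput: %.3f samples/sec' % (samples / elapsed))
    print('Latency: %.3f ms' % (elapsed * 1000 / samples))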

‎neural_compressor/torch/algorithms/weight_only/__init__.py

-9

@@ -11,12 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from .rtn import rtn_quantize
-from .gptq import gptq_quantize
-from .awq import awq_quantize
-from .teq import teq_quantize
-from .autoround import autoround_quantize
-from .hqq import hqq_quantize
-from .modules import WeightOnlyLinear
-from .utility import *
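Because the package __init__ no longer re-exports these symbols, callers import them from the concrete submodules instead; that is the pattern applied throughout the rest of this commit. A short sketch of the updated import style (the paths below are the ones used later in this diff):

# Import directly from the algorithm submodules; the weight_only package no longer re-exports them.
from neural_compressor.torch.algorithms.weight_only.rtn import rtn_quantize
from neural_compressor.torch.algorithms.weight_only.gptq import gptq_quantize
from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear
from neural_compressor.torch.algorithms.weight_only.utility import FLOAT_MAPPING, INT_MAPPING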

‎neural_compressor/torch/algorithms/weight_only/modules.py

+1 -1

@@ -69,7 +69,7 @@ def __init__(
             bits = self.dtype.lstrip("int")
             self.dtype = "int"
         if "int" not in self.dtype:  # for nf4, fp4
-            from neural_compressor.torch.algorithms.weight_only import FLOAT_MAPPING, INT_MAPPING
+            from neural_compressor.torch.algorithms.weight_only.utility import FLOAT_MAPPING, INT_MAPPING
 
             self.use_optimum_format = False  # optimum_format doesn't suit for symmetric nf4 fp4.
             float_list = FLOAT_MAPPING[self.dtype]
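The mapping constants now come from weight_only.utility, matching where they live after the package re-exports were removed. A hedged sketch of the lookup pattern the module relies on (the actual contents of FLOAT_MAPPING and INT_MAPPING are not shown in this diff; only the keyed-by-dtype access is):

from neural_compressor.torch.algorithms.weight_only.utility import FLOAT_MAPPING, INT_MAPPING

dtype = "nf4"  # example dtype handled by the non-int branch above
float_list = FLOAT_MAPPING[dtype]  # representable float values for this dtype
int_list = INT_MAPPING[dtype]      # the corresponding integer codes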

‎neural_compressor/torch/quantization/algorithm_entry.py

+6 -6

@@ -40,7 +40,7 @@ def rtn_entry(
     model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], RTNConfig], *args, **kwargs
 ) -> torch.nn.Module:
     """The main entry to apply rtn quantization."""
-    from neural_compressor.torch.algorithms.weight_only import rtn_quantize
+    from neural_compressor.torch.algorithms.weight_only.rtn import rtn_quantize
 
     # rebuild weight_config for rtn_quantize function
     weight_config = {}
@@ -75,7 +75,7 @@ def gptq_entry(
     model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], GPTQConfig], *args, **kwargs
 ) -> torch.nn.Module:
     logger.info("Quantize model with the GPTQ algorithm.")
-    from neural_compressor.torch.algorithms.weight_only import gptq_quantize
+    from neural_compressor.torch.algorithms.weight_only.gptq import gptq_quantize
 
     # rebuild weight_config for gptq_quantize function
     weight_config = {}
@@ -228,7 +228,7 @@ def awq_quantize_entry(
     model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], AWQConfig], *args, **kwargs
 ) -> torch.nn.Module:
     logger.info("Quantize model with the AWQ algorithm.")
-    from neural_compressor.torch.algorithms.weight_only import awq_quantize
+    from neural_compressor.torch.algorithms.weight_only.awq import awq_quantize
 
     weight_config = {}
     for (op_name, op_type), op_config in configs_mapping.items():
@@ -288,7 +288,7 @@ def awq_quantize_entry(
 def teq_quantize_entry(
     model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], TEQConfig], *args, **kwargs
 ) -> torch.nn.Module:
-    from neural_compressor.torch.algorithms.weight_only import teq_quantize
+    from neural_compressor.torch.algorithms.weight_only.teq import teq_quantize
 
     logger.info("Quantize model with the TEQ algorithm.")
     weight_config = {}
@@ -338,7 +338,7 @@ def teq_quantize_entry(
 def autoround_quantize_entry(
     model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], AutoRoundConfig], *args, **kwargs
 ) -> torch.nn.Module:
-    from neural_compressor.torch.algorithms.weight_only import autoround_quantize
+    from neural_compressor.torch.algorithms.weight_only.autoround import autoround_quantize
 
     logger.info("Quantize model with the AutoRound algorithm.")
     calib_func = kwargs.get("run_fn", None)
@@ -407,7 +407,7 @@ def autoround_quantize_entry(
 def hqq_entry(
     model: torch.nn.Module, configs_mapping: Dict[Tuple[str, Callable], HQQConfig], *args, **kwargs
 ) -> torch.nn.Module:
-    from neural_compressor.torch.algorithms.weight_only import hqq_quantize
+    from neural_compressor.torch.algorithms.weight_only.hqq import hqq_quantize
 
     logger.info("Quantize model with the HQQ algorithm.")
     q_model = hqq_quantize(model, configs_mapping)
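All six entry functions keep the same lazy-import structure; only the import target moves from the weight_only package to the algorithm's own submodule. A schematic sketch of that pattern (the signature is abbreviated and the trailing call is illustrative, not the exact entry body):

import torch

def rtn_entry_sketch(model: torch.nn.Module, configs_mapping, *args, **kwargs) -> torch.nn.Module:
    """Schematic entry point: the algorithm import is deferred until the entry is invoked."""
    # Import from the concrete submodule; the weight_only package no longer re-exports it.
    from neural_compressor.torch.algorithms.weight_only.rtn import rtn_quantize

    weight_config = {}  # rebuilt from configs_mapping in the real entry
    # ... rtn_quantize is then called with the model and the rebuilt weight_config ...
    return model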

‎test/3x/torch/quantization/weight_only/test_gptq.py

+1 -1

@@ -4,7 +4,7 @@
 import torch
 import transformers
 
-from neural_compressor.torch.algorithms.weight_only import WeightOnlyLinear
+from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear
 from neural_compressor.torch.quantization import GPTQConfig, get_default_gptq_config, get_default_rtn_config, quantize
 
 
‎test/3x/torch/quantization/weight_only/test_rtn.py

+1 -1

@@ -4,7 +4,7 @@
 import torch
 import transformers
 
-from neural_compressor.torch.algorithms.weight_only import WeightOnlyLinear
+from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear
 from neural_compressor.torch.quantization import (
     RTNConfig,
     get_default_double_quant_config,
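Both test files now pull WeightOnlyLinear from weight_only.modules. A hedged sketch of how a test can use that class once a model has been quantized (the helper below is illustrative, not part of the test suite):

import torch
from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear

def count_packed_linears(model: torch.nn.Module) -> int:
    """Count modules replaced by WeightOnlyLinear after weight-only quantization."""
    return sum(isinstance(m, WeightOnlyLinear) for m in model.modules())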
