@@ -15,6 +15,7 @@
 import copy
 import inspect
 import logging
+import warnings
 from enum import Enum
 from itertools import chain
 from pathlib import Path
@@ -30,16 +31,25 @@
 from neural_compressor.quantization import fit
 from torch.utils.data import DataLoader, RandomSampler
 from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForMaskedLM,
+    AutoModelForMultipleChoice,
+    AutoModelForQuestionAnswering,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
+    AutoModelForVision2Seq,
     DataCollator,
     PretrainedConfig,
     PreTrainedModel,
+    XLNetLMHeadModel,
     default_data_collator,
 )

 from optimum.exporters import TasksManager
 from optimum.exporters.onnx import OnnxConfig
 from optimum.onnxruntime import ORTModel
-from optimum.onnxruntime.modeling_decoder import ORTModelDecoder
+from optimum.onnxruntime.modeling_decoder import ORTModelForCausalLM
 from optimum.onnxruntime.modeling_seq2seq import ORTModelForConditionalGeneration
 from optimum.onnxruntime.utils import ONNX_DECODER_NAME
 from optimum.quantization_base import OptimumQuantizer
@@ -256,7 +266,7 @@ def quantize(
         if isinstance(self._original_model, ORTModelForConditionalGeneration):
             raise RuntimeError("ORTModelForConditionalGeneration not supported for quantization")

-        if isinstance(self._original_model, ORTModelDecoder):
+        if isinstance(self._original_model, ORTModelForCausalLM):
             model_or_path = self._original_model.onnx_paths
             if len(model_or_path) > 1:
                 raise RuntimeError(
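The renamed check above rejects decoders split across multiple ONNX files. A minimal usage sketch of the path it now takes, assuming the surrounding class is optimum-intel's `INCQuantizer`; the model id and config values are illustrative, and `use_cache=False` keeps the export to a single decoder file so the length check passes:

from neural_compressor.config import PostTrainingQuantConfig
from optimum.intel import INCQuantizer
from optimum.onnxruntime import ORTModelForCausalLM

# use_cache=False exports a single decoder ONNX file; the hunk above raises
# a RuntimeError when onnx_paths contains more than one file.
model = ORTModelForCausalLM.from_pretrained("distilgpt2", export=True, use_cache=False)

quantizer = INCQuantizer.from_pretrained(model)
quantizer.quantize(
    quantization_config=PostTrainingQuantConfig(approach="dynamic"),
    save_directory="quantized-distilgpt2",
)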
@@ -528,3 +538,49 @@ def _apply_quantization_from_config(q_config: Dict, model: torch.nn.Module) -> t
     q_model = convert(q_model, mapping=q_mapping, inplace=True)

     return q_model
+
+
+class IncQuantizedModel(INCModel):
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        warnings.warn(
+            f"The class `{cls.__name__}` has been deprecated and will be removed in optimum-intel v1.12, please use "
+            f"`{cls.__name__.replace('IncQuantized', 'INC')}` instead."
+        )
+        return super().from_pretrained(*args, **kwargs)
+
+
+class IncQuantizedModelForQuestionAnswering(IncQuantizedModel):
+    auto_model_class = AutoModelForQuestionAnswering
+
+
+class IncQuantizedModelForSequenceClassification(IncQuantizedModel):
+    auto_model_class = AutoModelForSequenceClassification
+
+
+class IncQuantizedModelForTokenClassification(IncQuantizedModel):
+    auto_model_class = AutoModelForTokenClassification
+
+
+class IncQuantizedModelForMultipleChoice(IncQuantizedModel):
+    auto_model_class = AutoModelForMultipleChoice
+
+
+class IncQuantizedModelForSeq2SeqLM(IncQuantizedModel):
+    auto_model_class = AutoModelForSeq2SeqLM
+
+
+class IncQuantizedModelForCausalLM(IncQuantizedModel):
+    auto_model_class = AutoModelForCausalLM
+
+
+class IncQuantizedModelForMaskedLM(IncQuantizedModel):
+    auto_model_class = AutoModelForMaskedLM
+
+
+class IncQuantizedModelForXLNetLM(IncQuantizedModel):
+    auto_model_class = XLNetLMHeadModel
+
+
+class IncQuantizedModelForVision2Seq(IncQuantizedModel):
+    auto_model_class = AutoModelForVision2Seq
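The aliases above keep the old `IncQuantized*` entry points importable while pointing callers at their `INC*` replacements. A minimal migration sketch, assuming the replacement implied by `cls.__name__.replace('IncQuantized', 'INC')` is exported from `optimum.intel`; the checkpoint name is illustrative:

from optimum.intel import INCModelForSequenceClassification

# Loading through the deprecated alias would emit the warning defined above;
# the INC* class loads the same quantized checkpoint without it.
model = INCModelForSequenceClassification.from_pretrained(
    "Intel/distilbert-base-uncased-finetuned-sst-2-english-int8-static"  # illustrative
)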