Skip to content

Commit 6c5fe5a

Browse files
committed
fix trainer
1 parent 1328c9e commit 6c5fe5a

File tree

4 files changed

+31
-11
lines changed

4 files changed

+31
-11
lines changed

optimum/intel/neural_compressor/trainer.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from transformers import Trainer
4040
from transformers.data.data_collator import DataCollator
4141
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
42+
from transformers.feature_extraction_utils import FeatureExtractionMixin
4243
from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype, unwrap_model
4344
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
4445
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -104,7 +105,7 @@
104105
from neural_compressor.config import _BaseQuantizationConfig
105106

106107

107-
__version__ = "4.22.2"
108+
__version__ = "4.46.0"
108109

109110

110111
logger = logging.get_logger(__name__)
@@ -122,8 +123,9 @@ def __init__(
122123
data_collator: Optional[DataCollator] = None,
123124
train_dataset: Optional[Dataset] = None,
124125
eval_dataset: Optional[Dataset] = None,
125-
tokenizer: Optional[PreTrainedTokenizerBase] = None,
126+
processing_class: Optional[Union[PreTrainedTokenizerBase, FeatureExtractionMixin]] = None,
126127
model_init: Callable[[], PreTrainedModel] = None,
128+
compute_loss_func: Optional[Callable] = None,
127129
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
128130
callbacks: Optional[List[TrainerCallback]] = None,
129131
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
@@ -132,6 +134,7 @@ def __init__(
132134
pruning_config: Optional[_BaseQuantizationConfig] = None,
133135
distillation_config: Optional[_BaseQuantizationConfig] = None,
134136
task: Optional[str] = None,
137+
**kwargs,
135138
):
136139
self.neftune_noise_alpha = None
137140

@@ -141,12 +144,12 @@ def __init__(
141144
data_collator,
142145
train_dataset,
143146
eval_dataset,
144-
tokenizer,
145-
model_init,
146-
compute_metrics,
147-
callbacks,
148-
optimizers,
149-
preprocess_logits_for_metrics,
147+
processing_class or kwargs.get("tokenizer", None),
148+
model_init=model_init,
149+
compute_metrics=compute_metrics,
150+
callbacks=callbacks,
151+
optimizers=optimizers,
152+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
150153
)
151154

152155
if self.args.device.type == "cuda" and not is_neural_compressor_version(">", "2.0.0"):
@@ -766,7 +769,7 @@ def _get_logits(model_outputs):
766769
output_names = ["logits", "start_logits", "end_logits"]
767770
return tuple(model_outputs.get(name) for name in output_names if name in model_outputs)
768771

769-
def compute_loss(self, model, inputs, return_outputs=False):
772+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
770773
"""
771774
How the loss is computed by Trainer. By default, all models return the loss in the first element.
772775
"""

optimum/intel/openvino/trainer.py

+5
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,11 @@ def __init__(
216216
logger.warning("OVTrainer is deprecated and will be removed in optimum-intel v1.22.0.")
217217

218218
if is_transformers_version(">=", "4.45.0"):
219+
if is_transformers_version(">=", "4.46.0"):
220+
raise ImportError(
221+
f"Unsupported transformers version found is {_transformers_version} which is not supported by the OVTrainer. Please downgrade to v4.44"
222+
)
223+
219224
logger.warning(
220225
f"The transformers version found is {_transformers_version} which is not officially supported by the OVTrainer, use at your own risk"
221226
)

tests/openvino/test_quantization.py

+3
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,9 @@ class OVTrainerTest(unittest.TestCase):
771771
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (("albert", 64, 39),)
772772

773773
@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
774+
@unittest.skipIf(
775+
is_transformers_version(">=", "4.46"), reason="OVTrainer is not compatible with transformers>=v4.46"
776+
)
774777
def test_aware_training_quantization(self, model_name, expected_fake_quantize, expected_int8):
775778
model_id = MODEL_NAMES[model_name]
776779
model = AutoModelForSequenceClassification.from_pretrained(model_id, attn_implementation="eager")

tests/openvino/test_training.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,10 @@ class OVTrainerTextClassificationTrainingTest(OVTrainerBaseTrainingTest):
475475
task = "sequence-classification"
476476

477477
@parameterized.expand(OVTRAINER_TEXT_CLASSIFICATION_TEST_DESCRIPTORS.items())
478-
@unittest.skipIf(is_transformers_version("<", "4.41.0"), reason="Mismatch in expected fake quantized op")
478+
@unittest.skipIf(
479+
is_transformers_version("<", "4.41") or is_transformers_version(">=", "4.46"),
480+
reason="Mismatch in expected fake quantized op and incompatible with transformers v4.46",
481+
)
479482
def test_training(self, _, desc: OVTrainerTestDescriptor):
480483
self.run_ovtrainer_training_checks(desc)
481484

@@ -627,7 +630,10 @@ class OVTrainerImageClassificationTrainingTest(OVTrainerBaseTrainingTest):
627630
@parameterized.expand(OVTRAINER_IMAGE_CLASSIFICATION_TEST_DESCRIPTORS.items())
628631
@pytest.mark.run_slow
629632
@slow
630-
@unittest.skipIf(is_transformers_version("<", "4.41.0"), reason="Mismatch in expected fake quantized op")
633+
@unittest.skipIf(
634+
is_transformers_version("<", "4.41") or is_transformers_version(">=", "4.46"),
635+
reason="Mismatch in expected fake quantized op and incompatible with transformers v4.46",
636+
)
631637
def test_training(self, _, desc: OVTrainerTestDescriptor):
632638
self.run_ovtrainer_training_checks(desc)
633639

@@ -808,6 +814,9 @@ class OVTrainerAudioClassificationTrainingTest(OVTrainerBaseTrainingTest):
808814
@parameterized.expand(OVTRAINER_AUDIO_CLASSIFICATION_TEST_DESCRIPTORS.items())
809815
@pytest.mark.run_slow
810816
@slow
817+
@unittest.skipIf(
818+
is_transformers_version(">=", "4.46"), reason="OVTrainer is not compatible with transformers>=v4.46"
819+
)
811820
def test_training(self, _, desc: OVTrainerTestDescriptor):
812821
self.run_ovtrainer_training_checks(desc)
813822

0 commit comments

Comments (0)