 from typing import Optional

 import datasets
+import evaluate
 import numpy as np
 import torch
 import transformers
+from accelerate import Accelerator
 from datasets import load_dataset
+from neural_compressor import DistillationConfig, QuantizationAwareTrainingConfig, WeightPruningConfig
 from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm
+from trainer_qa import QuestionAnsweringINCTrainer
 from transformers import (
     AutoConfig,
     AutoModelForQuestionAnswering,
     AutoTokenizer,
     DataCollatorWithPadding,
     EvalPrediction,
     HfArgumentParser,
-    PreTrainedModel,
     PreTrainedTokenizerFast,
     TrainingArguments,
     default_data_collator,
     set_seed,
 )
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
+from utils_qa import postprocess_qa_predictions

-import evaluate
-from accelerate import Accelerator
-from neural_compressor import DistillationConfig, QuantizationAwareTrainingConfig, WeightPruningConfig
 from optimum.intel.neural_compressor import INCModelForQuestionAnswering
-from trainer_qa import QuestionAnsweringINCTrainer
-from utils_qa import postprocess_qa_predictions


 # Will be removed when neural-compressor next release is out
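Context for the neural_compressor imports consolidated above: the three config classes are typically instantiated once and handed to the INC trainer. A minimal sketch of that wiring, with every argument value illustrative rather than taken from this PR (the script itself builds these objects from its CLI arguments):

import torch
from neural_compressor import DistillationConfig, QuantizationAwareTrainingConfig, WeightPruningConfig

# Quantization-aware training with library defaults (assumed sufficient for a sketch).
quantization_config = QuantizationAwareTrainingConfig()

# Magnitude pruning to an illustrative 90% target sparsity over the first 100 steps.
pruning_config = WeightPruningConfig(target_sparsity=0.9, pruning_type="magnitude", start_step=0, end_step=100)

# Distillation from a stand-in teacher module; a real run passes the loaded teacher model.
teacher = torch.nn.Linear(8, 2)
distillation_config = DistillationConfig(teacher_model=teacher)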
@@ -553,7 +552,10 @@ def move_input_to_device(input, device):
         )
         teacher_model_qa = QAModel(teacher_model)
         teacher_model_qa = accelerator.prepare(teacher_model_qa)
-        num_param = lambda model: sum(p.numel() for p in model.parameters())
+
+        def num_param(model):
+            return sum(p.numel() for p in model.parameters())
+
         logger.info(
             "***** Number of teacher model parameters: {:.2f}M *****".format(num_param(teacher_model_qa) / 10**6)
         )
@@ -662,9 +664,33 @@ def prepare_validation_features(examples):
         load_from_cache_file=not data_args.overwrite_cache,
         desc="Running tokenizer on validation dataset",
     )
+
     if data_args.max_eval_samples is not None:
         # Feature creation can increase the number of samples, so select the required samples again
-        eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+        max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
+        eval_dataset = eval_dataset.select(range(max_eval_samples))
+
+    if training_args.do_predict:
+        if "test" not in raw_datasets:
+            raise ValueError("--do_predict requires a test dataset")
+        predict_examples = raw_datasets["test"]
+        if data_args.max_predict_samples is not None:
+            # Select samples from the whole dataset before feature creation
+            predict_examples = predict_examples.select(range(data_args.max_predict_samples))
+        # Prediction feature creation
+        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
+            predict_dataset = predict_examples.map(
+                prepare_validation_features,
+                batched=True,
+                num_proc=data_args.preprocessing_num_workers,
+                remove_columns=column_names,
+                load_from_cache_file=not data_args.overwrite_cache,
+                desc="Running tokenizer on prediction dataset",
+            )
+        if data_args.max_predict_samples is not None:
+            # Feature creation can increase the number of samples, so select the required samples again
+            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
+            predict_dataset = predict_dataset.select(range(max_predict_samples))

     # Post-processing:
     def post_processing_function(examples, features, predictions, stage="eval"):
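The min(len(...), cap) pattern introduced above matters because the feature-creation map can change the dataset length, and Dataset.select raises when an index runs past the end. A minimal sketch of the failure mode it guards against, using a throwaway in-memory dataset (names are illustrative):

from datasets import Dataset

ds = Dataset.from_dict({"id": list(range(5))})  # only 5 rows
# ds.select(range(10)) would raise an out-of-range error here
safe_n = min(len(ds), 10)
subset = ds.select(range(safe_n))  # safely selects all 5 rows
print(len(subset))  # 5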