|
24 | 24 | import sys
|
25 | 25 | from dataclasses import dataclass, field
|
26 | 26 | from functools import partial
|
27 |
| -from pathlib import Path |
28 | 27 | from typing import Optional
|
29 | 28 |
|
30 | 29 | import datasets
|
31 | 30 | import transformers
|
32 | 31 | from datasets import load_dataset
|
| 32 | +from evaluate import load |
33 | 33 | from transformers import AutoTokenizer, EvalPrediction, HfArgumentParser, PreTrainedTokenizer, TrainingArguments
|
34 | 34 | from transformers.utils import check_min_version
|
35 | 35 | from transformers.utils.versions import require_version
|
| 36 | +from utils_qa import postprocess_qa_predictions |
36 | 37 |
|
37 |
| -from evaluate import load |
38 | 38 | from optimum.onnxruntime import ORTModelForQuestionAnswering, ORTOptimizer
|
39 |
| -from optimum.onnxruntime.configuration import OptimizationConfig, ORTConfig |
| 39 | +from optimum.onnxruntime.configuration import OptimizationConfig |
40 | 40 | from optimum.onnxruntime.model import ORTModel
|
41 |
| -from trainer_qa import QuestionAnsweringTrainer |
42 |
| -from utils_qa import postprocess_qa_predictions |
43 | 41 |
|
44 | 42 |
|
45 | 43 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
@@ -307,7 +305,6 @@ def main():
|
307 | 305 | )
|
308 | 306 |
|
309 | 307 | os.makedirs(training_args.output_dir, exist_ok=True)
|
310 |
| - model_path = os.path.join(training_args.output_dir, "model.onnx") |
311 | 308 | optimized_model_path = os.path.join(training_args.output_dir, "model_optimized.onnx")
|
312 | 309 |
|
313 | 310 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name or model_args.model_name_or_path)
|
@@ -492,7 +489,7 @@ def compute_metrics(p: EvalPrediction):
|
492 | 489 | metrics = compute_metrics(predictions)
|
493 | 490 |
|
494 | 491 | # Save metrics
|
495 |
| - with open(os.path.join(training_args.output_dir, f"eval_results.json"), "w") as f: |
| 492 | + with open(os.path.join(training_args.output_dir, "eval_results.json"), "w") as f: |
496 | 493 | json.dump(metrics, f, indent=4, sort_keys=True)
|
497 | 494 |
|
498 | 495 | # Prediction
|
@@ -527,7 +524,7 @@ def compute_metrics(p: EvalPrediction):
|
527 | 524 | metrics = compute_metrics(predictions)
|
528 | 525 |
|
529 | 526 | # Save metrics
|
530 |
| - with open(os.path.join(training_args.output_dir, f"predict_results.json"), "w") as f: |
| 527 | + with open(os.path.join(training_args.output_dir, "predict_results.json"), "w") as f: |
531 | 528 | json.dump(metrics, f, indent=4, sort_keys=True)
|
532 | 529 |
|
533 | 530 |
|
|
0 commit comments