@@ -19,7 +19,6 @@
 import math
 import os
 import random
-import tempfile
 from copy import deepcopy
 from functools import partial
 from io import BytesIO
@@ -34,7 +33,7 @@
 import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, LMSDiscreteScheduler, StableDiffusionPipeline
 from diffusers.optimization import get_scheduler
@@ -44,20 +43,12 @@
 from nncf.torch import create_compressed_model, register_default_init_args
 from nncf.torch.initialization import PTInitializingDataLoader
 from nncf.torch.layer_utils import CompressionParameter
-from openvino._offline_transformations import apply_moc_transformations, compress_quantize_weights_transformation
 from PIL import Image
 from requests.packages.urllib3.exceptions import InsecureRequestWarning
 from torchvision import transforms
 from tqdm import tqdm

-from optimum.exporters.onnx import export_models, get_stable_diffusion_models_for_export
-from optimum.intel import OVStableDiffusionPipeline
-from optimum.utils import (
-    DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
-    DIFFUSION_MODEL_UNET_SUBFOLDER,
-    DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
-    DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
-)
+from optimum.exporters.openvino import export_from_model


 requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
@@ -583,47 +574,6 @@ def get_noise_scheduler(args):
     return noise_scheduler


-def export_to_onnx(pipeline, save_dir):
-    unet = pipeline.unet
-    vae = pipeline.vae
-    text_encoder = pipeline.text_encoder
-
-    unet.eval().cpu()
-    vae.eval().cpu()
-    text_encoder.eval().cpu()
-
-    ONNX_WEIGHTS_NAME = "model.onnx"
-
-    output_names = [
-        os.path.join(DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, ONNX_WEIGHTS_NAME),
-        os.path.join(DIFFUSION_MODEL_UNET_SUBFOLDER, ONNX_WEIGHTS_NAME),
-        os.path.join(DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, ONNX_WEIGHTS_NAME),
-        os.path.join(DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, ONNX_WEIGHTS_NAME),
-    ]
-
-    with torch.no_grad():
-        models_and_onnx_configs = get_stable_diffusion_models_for_export(pipeline)
-        pipeline.save_config(save_dir)
-        export_models(
-            models_and_onnx_configs=models_and_onnx_configs, output_dir=Path(save_dir), output_names=output_names
-        )
-
-
-def export_to_openvino(pipeline, onnx_dir, save_dir):
-    ov_pipe = OVStableDiffusionPipeline.from_pretrained(
-        model_id=onnx_dir,
-        from_onnx=True,
-        model_save_dir=save_dir,
-        tokenizer=pipeline.tokenizer,
-        scheduler=pipeline.scheduler,
-        feature_extractor=pipeline.feature_extractor,
-        compile=False,
-    )
-    apply_moc_transformations(ov_pipe.unet.model, cf=False)
-    compress_quantize_weights_transformation(ov_pipe.unet.model)
-    ov_pipe.save_pretrained(save_dir)
-
-
 class UnetInitDataset(torch.utils.data.Dataset):
     def __init__(self, data):
         super().__init__()
@@ -700,7 +650,7 @@ def get_nncf_config(pipeline, dataloader, args):
                 "ignored_scopes": [
                     "{re}.*__add___[0-2]",
                     "{re}.*layer_norm_0",
-                    "{re}.*Attention.*/bmm_0",
+                    # "{re}.*Attention.*/bmm_0",
                     "{re}.*__truediv__*",
                     "{re}.*group_norm_0",
                     "{re}.*mul___[0-2]",
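Context for the hunk above: patterns listed under NNCF's "ignored_scopes" are excluded from quantization, so commenting out the Attention bmm pattern brings those batched matmuls back into the quantized scope. A minimal, hedged sketch of the config structure this key lives in (the input shape and remaining patterns are illustrative, not the script's full config):

```python
from nncf import NNCFConfig

# Hedged sketch: mirrors the shape of the script's config, values are illustrative.
nncf_config = NNCFConfig.from_dict(
    {
        "input_info": [{"sample_size": [1, 4, 64, 64]}],
        "compression": {
            "algorithm": "quantization",
            # Ops matching these patterns are *excluded* from quantization; removing the
            # Attention bmm pattern means those matmuls are now quantized as well.
            "ignored_scopes": ["{re}.*layer_norm_0", "{re}.*group_norm_0"],
        },
    }
)
```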
@@ -771,11 +721,13 @@ def main():

     logging_dir = os.path.join(args.output_dir, args.logging_dir)

+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
-        logging_dir=logging_dir,
+        project_config=accelerator_project_config,
     )

     logging.basicConfig(
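Context for the Accelerator change: newer Accelerate releases moved the `logging_dir` setting off the `Accelerator` constructor and into `ProjectConfiguration`, which is the pattern the hunk above adopts. A minimal standalone sketch of that pattern (directory names and tracker name are illustrative, not taken from the script):

```python
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

# Illustrative paths; the training script derives them from its CLI arguments.
project_config = ProjectConfiguration(project_dir="qat-output", logging_dir="qat-output/logs")
accelerator = Accelerator(log_with="tensorboard", project_config=project_config)
accelerator.init_trackers("text2image-qat")  # tracker files end up under logging_dir
```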
@@ -922,7 +874,7 @@ def tokenize_captions(examples, is_train=True):

     with accelerator.main_process_first():
         if args.max_train_samples is not None:
-            dataset["train"] = dataset["train"].shuffle(seed=42, buffer_size=args.max_train_samples)
+            dataset["train"] = dataset["train"].shuffle(seed=42).select(range(args.max_train_samples))
         # Set the training transforms
         train_dataset = dataset["train"]

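A note on the dataset change above: `buffer_size` is a parameter of `IterableDataset.shuffle()` (streaming mode) and does not cap the number of samples; for a map-style `datasets.Dataset`, shuffling and then selecting a range is the usual way to limit the training set. A small sketch under that assumption (the dataset id and sample count are illustrative):

```python
from datasets import load_dataset

ds = load_dataset("lambdalabs/pokemon-blip-captions", split="train")  # illustrative dataset id
max_train_samples = 100
# Shuffle the index mapping with a fixed seed, then keep the first N examples.
subset = ds.shuffle(seed=42).select(range(min(max_train_samples, len(ds))))
print(len(subset), subset.column_names)
```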
@@ -1132,9 +1084,8 @@ def collate_fn(examples):
             feature_extractor=pipeline.feature_extractor,
         )

-        with tempfile.TemporaryDirectory() as tmpdirname:
-            export_to_onnx(export_pipeline, tmpdirname)
-            export_to_openvino(export_pipeline, tmpdirname, Path(args.output_dir) / "openvino")
+        save_directory = Path(args.output_dir) / "openvino"
+        export_from_model(export_pipeline, output=save_directory, task="stable-diffusion")


 if __name__ == "__main__":
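With `export_from_model` from `optimum.exporters.openvino`, the intermediate ONNX round trip (temporary directory plus the removed `export_to_onnx` / `export_to_openvino` helpers) is no longer needed: the tuned pipeline is written directly as OpenVINO IR under `<output_dir>/openvino`. A hedged sketch of loading that folder back for inference with optimum-intel (path, prompt, and step count are illustrative):

```python
from optimum.intel import OVStableDiffusionPipeline

# "qat-output/openvino" stands in for Path(args.output_dir) / "openvino" from the script.
ov_pipe = OVStableDiffusionPipeline.from_pretrained("qat-output/openvino")
image = ov_pipe("a watercolor painting of a fox", num_inference_steps=25).images[0]
image.save("fox.png")
```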