19
19
from transformers .utils import is_tf_available
20
20
21
21
from optimum .exporters .onnx .config import TextDecoderOnnxConfig , TextDecoderWithPositionIdsOnnxConfig
22
- from optimum .exporters .onnx .model_configs import FalconOnnxConfig , GemmaOnnxConfig , LlamaOnnxConfig , PhiOnnxConfig
22
+ from optimum .exporters .onnx .model_configs import (
23
+ FalconOnnxConfig ,
24
+ GemmaOnnxConfig ,
25
+ LlamaOnnxConfig ,
26
+ MPTOnnxConfig ,
27
+ PhiOnnxConfig ,
28
+ UNetOnnxConfig ,
29
+ VaeDecoderOnnxConfig ,
30
+ VaeEncoderOnnxConfig ,
31
+ )
23
32
from optimum .exporters .tasks import TasksManager
24
33
from optimum .utils import DEFAULT_DUMMY_SHAPES
25
34
from optimum .utils .input_generators import (
35
44
BaichuanModelPatcher ,
36
45
ChatGLMModelPatcher ,
37
46
GemmaModelPatcher ,
47
+ InternLMPatcher ,
38
48
LlamaModelPatcher ,
39
49
MixtralModelPatcher ,
50
+ MPTModelPatcher ,
40
51
Phi3ModelPatcher ,
41
52
QwenModelPatcher ,
42
53
)
@@ -431,6 +442,11 @@ class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
431
442
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
432
443
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
433
444
445
+ def patch_model_for_export (
446
+ self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
447
+ ) -> "ModelPatcher" :
448
+ return InternLMPatcher (self , model , model_kwargs = model_kwargs )
449
+
434
450
435
451
@register_in_tasks_manager ("orion" , * ["text-generation" , "text-generation-with-past" ], library_name = "transformers" )
436
452
class OrionOpenVINOConfig (TextDecoderWithPositionIdsOnnxConfig ):
@@ -447,6 +463,16 @@ class OlmoOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
447
463
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
448
464
449
465
466
@register_in_tasks_manager(
    "mpt", *["text-generation", "text-generation-with-past", "text-classification"], library_name="transformers"
)
class MPTOpenVINOConfig(MPTOnnxConfig):
    """OpenVINO export configuration for MPT models.

    Inherits the ONNX config for MPT and overrides only the patcher hook so
    that ``MPTModelPatcher`` is applied to the model during export.
    """

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        """Wrap *model* in the MPT-specific export patcher."""
        return MPTModelPatcher(self, model, model_kwargs=model_kwargs)
+
475
+
450
476
@register_in_tasks_manager (
451
477
"phi3" ,
452
478
* [
@@ -510,3 +536,59 @@ class FalconOpenVINOConfig(FalconOnnxConfig):
510
536
OVFalconDummyPastKeyValuesGenerator ,
511
537
) + TextDecoderOnnxConfig .DUMMY_INPUT_GENERATOR_CLASSES
512
538
DUMMY_PKV_GENERATOR_CLASS = OVFalconDummyPastKeyValuesGenerator
539
+
540
+
541
@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
class UNetOpenVINOConfig(UNetOnnxConfig):
    """OpenVINO export configuration for diffusers UNet models.

    Declares the dynamic axes of the UNet inputs/outputs; optional inputs are
    enabled based on attributes found on the normalized model config.
    """

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Dynamic-axis mapping for the UNet inputs.

        Always includes ``sample``, ``timestep`` and ``encoder_hidden_states``;
        conditionally adds SDXL-style additional embeddings and the timestep
        conditioning input when the model config declares them.
        """
        dynamic_inputs: Dict[str, Dict[int, str]] = {
            "sample": {0: "batch_size", 2: "height", 3: "width"},
            "timestep": {0: "steps"},
            "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
        }

        # TODO : add text_image, image and image_embeds
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            dynamic_inputs["text_embeds"] = {0: "batch_size"}
            dynamic_inputs["time_ids"] = {0: "batch_size"}

        if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
            dynamic_inputs["timestep_cond"] = {0: "batch_size"}

        return dynamic_inputs

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Single ``out_sample`` output with dynamic batch and spatial axes."""
        return {"out_sample": {0: "batch_size", 2: "height", 3: "width"}}
565
+
566
+
567
@register_in_tasks_manager("vae-encoder", *["semantic-segmentation"], library_name="diffusers")
class VaeEncoderOpenVINOConfig(VaeEncoderOnnxConfig):
    """OpenVINO export configuration for the diffusers VAE encoder.

    Maps an image-space ``sample`` input to a ``latent_sample`` output, with
    batch and spatial dimensions kept dynamic.
    """

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Image-space input with dynamic batch, height and width axes."""
        dynamic_axes = {0: "batch_size", 2: "height", 3: "width"}
        return {"sample": dynamic_axes}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Latent-space output with dynamic batch and latent spatial axes."""
        dynamic_axes = {0: "batch_size", 2: "height_latent", 3: "width_latent"}
        return {"latent_sample": dynamic_axes}
580
+
581
+
582
@register_in_tasks_manager("vae-decoder", *["semantic-segmentation"], library_name="diffusers")
class VaeDecoderOpenVINOConfig(VaeDecoderOnnxConfig):
    """OpenVINO export configuration for the diffusers VAE decoder.

    The inverse of the encoder config: takes a ``latent_sample`` input and
    produces an image-space ``sample`` output, all axes dynamic except channels.
    """

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Latent-space input with dynamic batch and latent spatial axes."""
        dynamic_axes = {0: "batch_size", 2: "height_latent", 3: "width_latent"}
        return {"latent_sample": dynamic_axes}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Image-space output with dynamic batch, height and width axes."""
        dynamic_axes = {0: "batch_size", 2: "height", 3: "width"}
        return {"sample": dynamic_axes}
0 commit comments