@@ -74,7 +74,7 @@ def init_model_configs():
74
74
75
75
76
76
@register_in_tasks_manager ("baichuan" , * ["text-generation" , "text-generation-with-past" ], library_name = "transformers" )
77
- class BaichaunOpenVINOConfig (TextDecoderOnnxConfig ):
77
+ class BaichaunOpenVINOConfig (TextDecoderWithPositionIdsOnnxConfig ):
78
78
DEFAULT_ONNX_OPSET = 13
79
79
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig .with_args (
80
80
num_layers = "num_hidden_layers" , num_attention_heads = "num_attention_heads" , hidden_size = "hidden_size"
@@ -400,3 +400,21 @@ class Starcoder2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
400
400
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator , MistralDummyPastKeyValuesGenerator )
401
401
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
402
402
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
403
+
404
+
405
+ @register_in_tasks_manager ("internlm2" , * ["text-generation" , "text-generation-with-past" ], library_name = "transformers" )
406
+ class InternLM2OpenVINOConfig (TextDecoderWithPositionIdsOnnxConfig ):
407
+ DEFAULT_ONNX_OPSET = 14
408
+
409
+ DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator , MistralDummyPastKeyValuesGenerator )
410
+ DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
411
+ NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
412
+
413
+
414
+ @register_in_tasks_manager ("orion" , * ["text-generation" , "text-generation-with-past" ], library_name = "transformers" )
415
+ class OrionOpenVINOConfig (TextDecoderWithPositionIdsOnnxConfig ):
416
+ DEFAULT_ONNX_OPSET = 14
417
+
418
+ DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator , MistralDummyPastKeyValuesGenerator )
419
+ DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
420
+ NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
0 commit comments