Skip to content

Commit e6641b0

Browse files
uniartisanfxmarty
andauthored
Support qwen2 family model (qwen1.5) (#1746)
* Support qwen2 family model (qwen1.5) * update docs * add tests for qwen2 * fix test * ordering --------- Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
1 parent cf82249 commit e6641b0

File tree

9 files changed

+20
-2
lines changed

9 files changed

+20
-2
lines changed

docs/source/exporters/onnx/overview.mdx

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
7777
- Phi
7878
- Pix2Struct
7979
- PoolFormer
80+
- Qwen2(Qwen1.5)
8081
- RegNet
8182
- ResNet
8283
- Roberta

optimum/exporters/onnx/model_configs.py

+4
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,10 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
241241
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
242242

243243

244+
class Qwen2OnnxConfig(LlamaOnnxConfig):
245+
pass
246+
247+
244248
class GemmaOnnxConfig(LlamaOnnxConfig):
245249
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
246250
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator

optimum/exporters/onnx/utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@
8282
"gptj",
8383
"imagegpt",
8484
"llama",
85-
"phi",
8685
"mistral",
86+
"phi",
87+
"qwen2",
8788
}
8889

8990

optimum/exporters/tasks.py

+8
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,14 @@ class TasksManager:
853853
"text-classification",
854854
onnx="OPTOnnxConfig",
855855
),
856+
"qwen2": supported_tasks_mapping(
857+
"feature-extraction",
858+
"feature-extraction-with-past",
859+
"text-generation",
860+
"text-generation-with-past",
861+
"text-classification",
862+
onnx="Qwen2OnnxConfig",
863+
),
856864
"llama": supported_tasks_mapping(
857865
"feature-extraction",
858866
"feature-extraction-with-past",

optimum/onnxruntime/modeling_decoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def prepare_past_key_values(
338338
if self.model_type == "gemma":
339339
num_attention_heads = self.normalized_config.num_key_value_heads
340340
embed_size_per_head = self.normalized_config.head_dim
341-
elif self.model_type in {"gemma", "mistral", "llama"}:
341+
elif self.model_type in {"mistral", "llama", "qwen2"}:
342342
num_attention_heads = self.normalized_config.num_key_value_heads
343343
else:
344344
num_attention_heads = self.normalized_config.num_attention_heads

optimum/utils/normalized_config.py

+1
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ class NormalizedConfigManager:
264264
"whisper": WhisperLikeNormalizedTextConfig,
265265
"xlm-roberta": NormalizedTextConfig,
266266
"yolos": NormalizedVisionConfig,
267+
"qwen2": NormalizedTextConfig,
267268
}
268269

269270
@classmethod

tests/exporters/exporters_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
"pix2struct": "fxmarty/pix2struct-tiny-random",
134134
# "rembert": "google/rembert",
135135
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
136+
"qwen2": "fxmarty/tiny-dummy-qwen2",
136137
"regnet": "hf-internal-testing/tiny-random-RegNetModel",
137138
"resnet": "hf-internal-testing/tiny-random-resnet",
138139
"roberta": "hf-internal-testing/tiny-random-RobertaModel",

tests/onnxruntime/test_modeling.py

+1
Original file line numberDiff line numberDiff line change
@@ -2258,6 +2258,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
22582258
"llama",
22592259
"mistral",
22602260
"mpt",
2261+
"qwen2",
22612262
]
22622263

22632264
FULL_GRID = {

tests/onnxruntime/utils_onnxruntime_tests.py

+1
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@
130130
"perceiver_vision": "hf-internal-testing/tiny-random-vision_perceiver_conv",
131131
"pix2struct": "fxmarty/pix2struct-tiny-random",
132132
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
133+
"qwen2": "fxmarty/tiny-dummy-qwen2",
133134
"resnet": "hf-internal-testing/tiny-random-resnet",
134135
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
135136
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",

0 commit comments

Comments
 (0)