Skip to content

Commit c513437

Browse files
authored
Add transformers 4.36 tests (#2085)
* add transformers 4.36 tests * add test depending on transformers version * add min transformers required version for gemma * update macos * fix whisper test * add opt * fix mpt * add comment * add granite test when supported by transformers
1 parent e8b0332 commit c513437

File tree

4 files changed

+33
-22
lines changed

4 files changed

+33
-22
lines changed

.github/workflows/test_onnxruntime.yml

+3-1
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@ jobs:
1818
fail-fast: false
1919
matrix:
2020
transformers-version: ["latest"]
21-
os: [ubuntu-20.04, windows-2019, macos-13]
21+
os: [ubuntu-20.04, windows-2019, macos-15]
2222
include:
23+
- transformers-version: "4.36.*"
24+
os: ubuntu-20.04
2325
- transformers-version: "4.45.*"
2426
os: ubuntu-20.04
2527

optimum/exporters/onnx/model_configs.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ class Qwen2OnnxConfig(LlamaOnnxConfig):
295295
class GemmaOnnxConfig(LlamaOnnxConfig):
296296
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
297297
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
298-
pass
298+
MIN_TRANSFORMERS_VERSION = version.parse("4.38.0")
299299

300300

301301
class GraniteOnnxConfig(LlamaOnnxConfig):
@@ -348,6 +348,8 @@ def patch_model_for_export(
348348
class MPTOnnxConfig(TextDecoderOnnxConfig):
349349
# MPT does not require position_ids input.
350350
DEFAULT_ONNX_OPSET = 13
351+
# TODO: fix inference for transformers < v4.41 for beam_search > 1
352+
MIN_TRANSFORMERS_VERSION = version.parse("4.41.0")
351353
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
352354
num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"
353355
)

setup.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
"datasets>=1.2.1",
5555
"evaluate",
5656
"protobuf>=3.20.1",
57-
"transformers<4.47.0",
57+
"transformers>=4.36,<4.47.0",
5858
],
5959
"onnxruntime-gpu": [
6060
"onnx",
@@ -63,19 +63,19 @@
6363
"evaluate",
6464
"protobuf>=3.20.1",
6565
"accelerate", # ORTTrainer requires it.
66-
"transformers<4.47.0",
66+
"transformers>=4.36,<4.47.0",
6767
],
6868
"exporters": [
6969
"onnx",
7070
"onnxruntime",
7171
"timm",
72-
"transformers<4.47.0",
72+
"transformers>=4.36,<4.47.0",
7373
],
7474
"exporters-gpu": [
7575
"onnx",
7676
"onnxruntime-gpu",
7777
"timm",
78-
"transformers<4.47.0",
78+
"transformers>=4.36,<4.47.0",
7979
],
8080
"exporters-tf": [
8181
"tensorflow>=2.4,<=2.12.1",
@@ -86,7 +86,7 @@
8686
"h5py",
8787
"numpy<1.24.0",
8888
"datasets<=2.16",
89-
"transformers>=4.26,<4.38",
89+
"transformers>=4.36,<4.38",
9090
],
9191
"diffusers": ["diffusers"],
9292
"intel": "optimum-intel>=1.18.0",

tests/onnxruntime/test_modeling.py

+22-15
Original file line numberDiff line numberDiff line change
@@ -2318,21 +2318,28 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
23182318
"bloom",
23192319
"codegen",
23202320
"falcon",
2321-
"gemma",
23222321
"gpt2",
23232322
"gpt_bigcode",
23242323
"gpt_neo",
23252324
"gpt_neox",
23262325
"gptj",
2327-
"granite",
23282326
"llama",
23292327
"mistral",
2330-
"mpt",
23312328
"opt",
23322329
]
23332330

2334-
if check_if_transformers_greater("4.40"):
2335-
SUPPORTED_ARCHITECTURES.extend(["gemma", "phi3", "qwen2"])
2331+
if check_if_transformers_greater("4.37"):
2332+
SUPPORTED_ARCHITECTURES.append("qwen2")
2333+
2334+
if check_if_transformers_greater("4.38"):
2335+
SUPPORTED_ARCHITECTURES.append("gemma")
2336+
2337+
# TODO: fix "mpt" for which inference fails for transformers < v4.41
2338+
if check_if_transformers_greater("4.41"):
2339+
SUPPORTED_ARCHITECTURES.extend(["phi3", "mpt"])
2340+
2341+
if check_if_transformers_greater("4.45"):
2342+
SUPPORTED_ARCHITECTURES.append("granite")
23362343

23372344
FULL_GRID = {
23382345
"model_arch": SUPPORTED_ARCHITECTURES,
@@ -2445,7 +2452,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
24452452
transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
24462453
transformers_model = transformers_model.eval()
24472454
tokenizer = get_preprocessor(model_id)
2448-
tokens = tokenizer("This is a sample output", return_tensors="pt")
2455+
tokens = tokenizer("This is a sample input", return_tensors="pt")
24492456
position_ids = None
24502457
if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
24512458
input_shape = tokens["input_ids"].shape
@@ -2467,7 +2474,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
24672474
# Compare batched generation.
24682475
tokenizer.pad_token_id = tokenizer.eos_token_id
24692476
tokenizer.padding_side = "left"
2470-
tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
2477+
tokens = tokenizer(["This is", "This is a sample input"], return_tensors="pt", padding=True)
24712478
onnx_model.generation_config.eos_token_id = None
24722479
transformers_model.generation_config.eos_token_id = None
24732480
onnx_model.config.eos_token_id = None
@@ -4598,14 +4605,14 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str):
45984605
)
45994606

46004607
self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
4601-
self.assertEqual(
4602-
outputs_model_with_pkv.shape[1],
4603-
self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1,
4604-
)
4605-
self.assertEqual(
4606-
outputs_model_without_pkv.shape[1],
4607-
self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1,
4608-
)
4608+
4609+
if model_arch == "whisper" and check_if_transformers_greater("4.43"):
4610+
gen_length = self.GENERATION_LENGTH + 2
4611+
else:
4612+
gen_length = self.GENERATION_LENGTH + 1
4613+
4614+
self.assertEqual(outputs_model_with_pkv.shape[1], gen_length)
4615+
self.assertEqual(outputs_model_without_pkv.shape[1], gen_length)
46094616

46104617
self.GENERATION_LENGTH = generation_length
46114618
if os.environ.get("TEST_LEVEL", 0) == "1":

0 commit comments

Comments (0)