Skip to content

Commit 9338cad

Browse files
Added reproducer
1 parent 683133f commit 9338cad

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

optimum/intel/openvino/modeling_base.py

+1
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ def _reshape(
471471
height: int = None,
472472
width: int = None,
473473
):
474+
return model
474475
shapes = {}
475476
for inputs in model.inputs:
476477
shapes[inputs] = inputs.get_partial_shape()

reshape_reproducer.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import nncf
2+
import numpy as np
3+
import logging
4+
import tempfile
5+
6+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, default_data_collator
7+
from datasets import load_dataset
8+
import evaluate
9+
from optimum.intel import OVModelForSequenceClassification, OVConfig, OVTrainer
10+
11+
12+
# Limit NNCF's logger to ERROR-level messages and above.
nncf.nncf_logger.setLevel(logging.ERROR)
13+
14+
15+
def reshape_model(model):
    """Set the first two dimensions of every model input to ``-1``, in place.

    For each port in ``model.inputs`` the current partial shape is fetched,
    its leading (up to two) dimensions are overwritten with ``-1``
    (presumably the marker for a dynamic dimension -- confirm against the
    OpenVINO ``PartialShape`` documentation), and all collected shapes are
    handed to ``model.reshape`` in a single call.

    Inputs with fewer than two dimensions only have their existing
    dimensions changed, instead of failing on a missing axis.

    Args:
        model: Object exposing ``inputs`` (ports providing
            ``get_partial_shape()``) and ``reshape(shapes)``.

    Returns:
        None. The model is modified through ``model.reshape``.
    """
    shapes = {}
    for port in model.inputs:
        partial_shape = port.get_partial_shape()
        # Touch only the axes that exist so low-rank inputs do not raise.
        for axis in range(min(2, len(partial_shape))):
            partial_shape[axis] = -1
        shapes[port] = partial_shape
    model.reshape(shapes)
22+
23+
24+
def get_num_fqs(model):
    """Count the operations of ``model`` whose type name contains ``FakeQuantize``.

    Args:
        model: Object exposing ``get_ops()``, an iterable of nodes that each
            provide ``get_type_name()``.

    Returns:
        int: Number of nodes whose type name includes the substring
        ``FakeQuantize``.
    """
    return sum("FakeQuantize" in node.get_type_name() for node in model.get_ops())
30+
31+
32+
# Reproducer: train and save a model with OVTrainer, reload it as an OpenVINO
# model, and print the number of FakeQuantize nodes before and after
# reshaping the reloaded model's inputs.
model_id = "distilbert-base-uncased"

# The tokenizer, tokenised dataset and metric do not depend on the loop
# iteration, so they are built once instead of once per run.
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("glue", "sst2")
dataset = dataset.map(
    lambda examples: tokenizer(examples["sentence"], padding="max_length", max_length=128), batched=True
)
metric = evaluate.load("glue", "sst2")

for _ in range(5):
    # Random sample count in [1, 16]; substitute `n_samples = 16` for a
    # fixed-size run.
    n_samples = (np.random.randint(1000) % 16) + 1

    # A fresh model and config are created for every run, as in the original
    # script (presumably because training modifies the model -- confirm).
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    ov_config = OVConfig()
    train_dataset = dataset["train"].select(range(n_samples))
    eval_dataset = dataset["validation"].select(range(n_samples))

    with tempfile.TemporaryDirectory() as tmp_dir:
        trainer = OVTrainer(
            model=model,
            ov_config=ov_config,
            task="sequence-classification",
            args=TrainingArguments(tmp_dir, num_train_epochs=1.0, do_train=True, do_eval=True),
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=lambda p: metric.compute(predictions=np.argmax(p.predictions, 1), references=p.label_ids),
            tokenizer=tokenizer,
            data_collator=default_data_collator,
        )
        trainer.train()
        trainer.evaluate()
        trainer.save_model()

        # Reload while the temporary directory still exists; it is removed
        # as soon as the `with` block exits.
        ov_model = OVModelForSequenceClassification.from_pretrained(tmp_dir)
        fqs_before_reshape = get_num_fqs(ov_model.model)
        reshape_model(ov_model.model)
        fqs_after_reshape = get_num_fqs(ov_model.model)
        print(f"Number of FQ nodes before reshape: {fqs_before_reshape}, after reshape: {fqs_after_reshape}")

0 commit comments

Comments
 (0)