Skip to content

Commit 360ad67

Browse files
better custom config
1 parent 3ac9f99 commit 360ad67

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/openvino/test_export.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -124,17 +124,21 @@ def test_export_custom_model(self):
124124
class BertOnnxConfigWithPooler(BertOnnxConfig):
125125
@property
126126
def outputs(self):
127-
common_outputs = {}
128-
common_outputs["last_hidden_state"] = {0: "batch_size", 1: "sequence_length"}
129-
common_outputs["pooler_output"] = {0: "batch_size"}
127+
if self.task == "feature-extraction-with-pooler":
128+
common_outputs = {}
129+
common_outputs["last_hidden_state"] = {0: "batch_size", 1: "sequence_length"}
130+
common_outputs["pooler_output"] = {0: "batch_size"}
131+
else:
132+
common_outputs = super().outputs
133+
130134
return common_outputs
131135

132136
base_task = "feature-extraction"
133137
custom_task = f"{base_task}-with-pooler"
134138
model_id = "sentence-transformers/all-MiniLM-L6-v2"
135139

136140
config = AutoConfig.from_pretrained(model_id)
137-
custom_export_configs = {"model": BertOnnxConfigWithPooler(config, task=base_task)}
141+
custom_export_configs = {"model": BertOnnxConfigWithPooler(config, task=custom_task)}
138142

139143
with TemporaryDirectory() as tmpdirname:
140144
main_export(

0 commit comments

Comments
 (0)