File tree 1 file changed +8
-4
lines changed
1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -124,17 +124,21 @@ def test_export_custom_model(self):
124
124
class BertOnnxConfigWithPooler (BertOnnxConfig ):
125
125
@property
126
126
def outputs (self ):
127
- common_outputs = {}
128
- common_outputs ["last_hidden_state" ] = {0 : "batch_size" , 1 : "sequence_length" }
129
- common_outputs ["pooler_output" ] = {0 : "batch_size" }
127
+ if self .task == "feature-extraction-with-pooler" :
128
+ common_outputs = {}
129
+ common_outputs ["last_hidden_state" ] = {0 : "batch_size" , 1 : "sequence_length" }
130
+ common_outputs ["pooler_output" ] = {0 : "batch_size" }
131
+ else :
132
+ common_outputs = super ().outputs
133
+
130
134
return common_outputs
131
135
132
136
base_task = "feature-extraction"
133
137
custom_task = f"{ base_task } -with-pooler"
134
138
model_id = "sentence-transformers/all-MiniLM-L6-v2"
135
139
136
140
config = AutoConfig .from_pretrained (model_id )
137
- custom_export_configs = {"model" : BertOnnxConfigWithPooler (config , task = base_task )}
141
+ custom_export_configs = {"model" : BertOnnxConfigWithPooler (config , task = custom_task )}
138
142
139
143
with TemporaryDirectory () as tmpdirname :
140
144
main_export (
You can’t perform that action at this time.
0 commit comments