@@ -141,40 +141,24 @@ def create_text_gen_model(model_path, device, **kwargs):
141
141
if not model_path_existed :
142
142
raise RuntimeError (f'==Failure ==: model path:{ model_path } does not exist' )
143
143
else :
144
- if model_type in ['replit' , 'codegen2' , 'chatglm' ]:
145
- start = time .perf_counter ()
146
- ov_model = model_class .from_pretrained (
147
- model_path ,
148
- device = device ,
149
- ov_config = ov_config ,
150
- config = AutoConfig .from_pretrained (model_path , trust_remote_code = True ),
151
- stateful = kwargs .get ("stateful" , None )
152
- )
153
- end = time .perf_counter ()
154
- elif model_type in ['falcon' , "mpt" ]:
155
- start = time .perf_counter ()
156
- ov_model = model_class .from_pretrained (
157
- model_path ,
158
- device = device ,
159
- ov_config = ov_config ,
160
- stateful = kwargs .get ("stateful" , None ),
161
- trust_remote_code = False
162
- )
163
- end = time .perf_counter ()
164
- else :
165
- start = time .perf_counter ()
166
- config = AutoConfig .from_pretrained (model_path , trust_remote_code = True )
167
- ov_model = model_class .from_pretrained (
168
- model_path ,
169
- device = device ,
170
- ov_config = ov_config ,
171
- config = config ,
172
- compile = False ,
173
- stateful = kwargs .get ("stateful" , None )
174
- )
175
- if not isinstance (ov_model , OV_MODEL_CLASSES_MAPPING ['t5' ]):
176
- patch_inter_processing_and_compile (ov_model , ** kwargs )
177
- end = time .perf_counter ()
144
+ remote_code = False
145
+ try :
146
+ model_config = AutoConfig .from_pretrained (model_path )
147
+ except Exception :
148
+ model_config = AutoConfig .from_pretrained (model_path , trust_remote_code = True )
149
+ remote_code = True
150
+ start = time .perf_counter ()
151
+ ov_model = model_class .from_pretrained (
152
+ model_path ,
153
+ device = device ,
154
+ ov_config = ov_config ,
155
+ config = model_config ,
156
+ stateful = kwargs .get ("stateful" , None ),
157
+ trust_remote_code = remote_code
158
+ )
159
+ if not isinstance (ov_model , OV_MODEL_CLASSES_MAPPING ['t5' ]):
160
+ patch_inter_processing_and_compile (ov_model , ** kwargs )
161
+ end = time .perf_counter ()
178
162
if kwargs ['num_beams' ] > 1 :
179
163
bench_hook = utils .hook_beam_search .BeamSearchHook ()
180
164
else :
0 commit comments