Skip to content

Commit f4d50db

Browse files
authored
fix export for phi1.5 (#400)
1 parent e9825f2 commit f4d50db

File tree

2 files changed

+57
-34
lines changed

2 files changed

+57
-34
lines changed

llm_bench/python/convert.py

+39
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,44 @@ def convert_falcon(args):
12011201
unpatch_gptq(cuda, post_init)
12021202

12031203

1204+
def convert_phi(args):
    """Export a phi (e.g. phi-1.5) causal-LM checkpoint to OpenVINO IR.

    Loads the HF config (falling back to ``trust_remote_code=True`` when the
    model ships custom code), optionally patches GPTQ quantization hooks,
    loads the PyTorch model unless an already-converted OV model allows a
    compression-only pass, and delegates the actual export to
    ``convert_optimum_causallm_base``.

    Args:
        args: parsed CLI namespace; uses ``model_id``, ``precision``,
            ``output_dir``, ``compress_weights``, ``force_convert``.
    """
    trust_remote_code = False
    try:
        config = AutoConfig.from_pretrained(args.model_id)
    except Exception:
        # Model repo ships custom modeling code; retry with remote code enabled
        # and remember that flag for the model load below.
        config = AutoConfig.from_pretrained(args.model_id, trust_remote_code=True)
        trust_remote_code = True
    cuda, post_init = patch_gptq(config)
    model_kwargs = {}
    if trust_remote_code:
        model_kwargs["trust_remote_code"] = trust_remote_code
    # Skip the PyTorch load entirely when an OV model already exists and only
    # weight compression was requested.
    compression_only = (
        args.compress_weights
        and not args.force_convert
        and not is_torch_compression(args)
        and is_ov_model_provided(args.model_id, args.output_dir, args.precision)
    )
    if post_init is not None:
        # GPTQ patching requires fp32 weights on load.
        model_kwargs["torch_dtype"] = torch.float32
    pt_model = None
    gptq_applied = is_gptq(config)
    # NOTE(review): ``precision`` is computed here but not used in this body —
    # kept to match sibling convert_* functions; confirm whether it is needed.
    precision = args.precision if not gptq_applied else GPTQ_DIR.format(precision=args.precision)
    if not compression_only:
        pt_model = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            # Reuse the config loaded above. The original re-fetched it with a
            # bare AutoConfig.from_pretrained(args.model_id), which raises for
            # trust_remote_code models and defeats the try/except fallback.
            config=config,
            **model_kwargs,
        )
        pt_model.config.use_cache = True
        pt_model.eval()

    convert_optimum_causallm_base(pt_model, args, config, compression_only)

    if post_init is not None:
        unpatch_gptq(cuda, post_init)
1240+
1241+
12041242
def convert_baichaun(args):
12051243
config = AutoConfig.from_pretrained(args.model_id, trust_remote_code=True)
12061244
cuda, post_init = patch_gptq(config)
@@ -1304,6 +1342,7 @@ def convert_aquilachat(args):
13041342
"lcm": convert_lcm,
13051343
"ldm": convert_ldm_super_res,
13061344
"mpt": convert_mpt,
1345+
"phi-": convert_phi,
13071346
"replit": convert_mpt,
13081347
"chatglm2": convert_causal_lm,
13091348
"chatglm3": convert_causal_lm,

llm_bench/python/utils/ov_utils.py

+18-34
Original file line numberDiff line numberDiff line change
@@ -141,40 +141,24 @@ def create_text_gen_model(model_path, device, **kwargs):
141141
if not model_path_existed:
142142
raise RuntimeError(f'==Failure ==: model path:{model_path} does not exist')
143143
else:
144-
if model_type in ['replit', 'codegen2', 'chatglm']:
145-
start = time.perf_counter()
146-
ov_model = model_class.from_pretrained(
147-
model_path,
148-
device=device,
149-
ov_config=ov_config,
150-
config=AutoConfig.from_pretrained(model_path, trust_remote_code=True),
151-
stateful=kwargs.get("stateful", None)
152-
)
153-
end = time.perf_counter()
154-
elif model_type in ['falcon', "mpt"]:
155-
start = time.perf_counter()
156-
ov_model = model_class.from_pretrained(
157-
model_path,
158-
device=device,
159-
ov_config=ov_config,
160-
stateful=kwargs.get("stateful", None),
161-
trust_remote_code=False
162-
)
163-
end = time.perf_counter()
164-
else:
165-
start = time.perf_counter()
166-
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
167-
ov_model = model_class.from_pretrained(
168-
model_path,
169-
device=device,
170-
ov_config=ov_config,
171-
config=config,
172-
compile=False,
173-
stateful=kwargs.get("stateful", None)
174-
)
175-
if not isinstance(ov_model, OV_MODEL_CLASSES_MAPPING['t5']):
176-
patch_inter_processing_and_compile(ov_model, **kwargs)
177-
end = time.perf_counter()
144+
remote_code = False
145+
try:
146+
model_config = AutoConfig.from_pretrained(model_path)
147+
except Exception:
148+
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
149+
remote_code = True
150+
start = time.perf_counter()
151+
ov_model = model_class.from_pretrained(
152+
model_path,
153+
device=device,
154+
ov_config=ov_config,
155+
config=model_config,
156+
stateful=kwargs.get("stateful", None),
157+
trust_remote_code=remote_code
158+
)
159+
if not isinstance(ov_model, OV_MODEL_CLASSES_MAPPING['t5']):
160+
patch_inter_processing_and_compile(ov_model, **kwargs)
161+
end = time.perf_counter()
178162
if kwargs['num_beams'] > 1:
179163
bench_hook = utils.hook_beam_search.BeamSearchHook()
180164
else:

0 commit comments

Comments
 (0)