Skip to content

Commit 23cb607

Browse files
authored
[llm_bench] Fix way with relative path of media for json prompts (#1843)
1 parent 0c94961 commit 23cb607

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

tools/llm_bench/llm_bench_utils/model_utils.py

+8
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,11 @@ def init_timestamp(num_iters, prompt_list, prompt_idx_list):
312312
p_idx = prompt_idx_list[idx]
313313
iter_timestamp[num][p_idx] = {}
314314
return iter_timestamp
315+
316+
317+
def resolve_media_file_path(file_path, prompt_file_path):
318+
if not file_path:
319+
return file_path
320+
if not (file_path.startswith("http://") or file_path.startswith("https://")):
321+
return os.path.join(os.path.dirname(prompt_file_path), file_path.replace("./", ""))
322+
return file_path

tools/llm_bench/task/image_generation.py

+3
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,9 @@ def get_image_prompt(args):
277277
image_param_list = parse_json_data.parse_image_json_data(output_data_list)
278278
if len(image_param_list) > 0:
279279
for image_data in image_param_list:
280+
if args['prompt_file'] is not None and len(args['prompt_file']) > 0:
281+
image_data['media'] = model_utils.resolve_media_file_path(image_data.get("media"), args['prompt_file'][0])
282+
image_data['mask_image'] = model_utils.resolve_media_file_path(image_data.get("mask_image"), args['prompt_file'][0])
280283
input_image_list.append(image_data)
281284
else:
282285
input_image_list.append(output_data_list[0])

tools/llm_bench/task/visual_language_generation.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -366,10 +366,7 @@ def get_image_text_prompt(args):
366366
if len(vlm_param_list) > 0:
367367
for vlm_file in vlm_param_list:
368368
if args['prompt_file'] is not None and len(args['prompt_file']) > 0:
369-
media_path = vlm_file["media"]
370-
if not (media_path.startswith("http://") or media_path.startswith("https://")):
371-
media_path = os.path.join(os.path.dirname(args['prompt_file'][0]), media_path.replace("./", ""))
372-
vlm_file['media'] = media_path
369+
vlm_file['media'] = model_utils.resolve_media_file_path(vlm_file.get("media"), args['prompt_file'][0])
373370
vlm_file_list.append(vlm_file)
374371
else:
375372
vlm_file_list.append(output_data_list)

0 commit comments

Comments
 (0)