From c40cb2773dc500947c316215dffec4419a48d380 Mon Sep 17 00:00:00 2001 From: sbalandi Date: Tue, 4 Mar 2025 18:09:33 +0000 Subject: [PATCH] [llm_bench] Fix way with relative path of media for json prompts --- tools/llm_bench/llm_bench_utils/model_utils.py | 8 ++++++++ tools/llm_bench/task/image_generation.py | 3 +++ tools/llm_bench/task/visual_language_generation.py | 5 +---- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py index 185a818372..2f885d40a0 100644 --- a/tools/llm_bench/llm_bench_utils/model_utils.py +++ b/tools/llm_bench/llm_bench_utils/model_utils.py @@ -312,3 +312,11 @@ def init_timestamp(num_iters, prompt_list, prompt_idx_list): p_idx = prompt_idx_list[idx] iter_timestamp[num][p_idx] = {} return iter_timestamp + + +def resolve_media_file_path(file_path, prompt_file_path): + if not file_path: + return file_path + if not (file_path.startswith("http://") or file_path.startswith("https://")): + return os.path.join(os.path.dirname(prompt_file_path), file_path.replace("./", "")) + return file_path diff --git a/tools/llm_bench/task/image_generation.py b/tools/llm_bench/task/image_generation.py index 661070f796..4ac489da70 100644 --- a/tools/llm_bench/task/image_generation.py +++ b/tools/llm_bench/task/image_generation.py @@ -277,6 +277,9 @@ def get_image_prompt(args): image_param_list = parse_json_data.parse_image_json_data(output_data_list) if len(image_param_list) > 0: for image_data in image_param_list: + if args['prompt_file'] is not None and len(args['prompt_file']) > 0: + image_data['media'] = model_utils.resolve_media_file_path(image_data.get("media"), args['prompt_file'][0]) + image_data['mask_image'] = model_utils.resolve_media_file_path(image_data.get("mask_image"), args['prompt_file'][0]) input_image_list.append(image_data) else: input_image_list.append(output_data_list[0]) diff --git a/tools/llm_bench/task/visual_language_generation.py b/tools/llm_bench/task/visual_language_generation.py index c11ec6a066..54b4467c14 100644 --- a/tools/llm_bench/task/visual_language_generation.py +++ b/tools/llm_bench/task/visual_language_generation.py @@ -366,10 +366,7 @@ def get_image_text_prompt(args): if len(vlm_param_list) > 0: for vlm_file in vlm_param_list: if args['prompt_file'] is not None and len(args['prompt_file']) > 0: - media_path = vlm_file["media"] - if not (media_path.startswith("http://") or media_path.startswith("https://")): - media_path = os.path.join(os.path.dirname(args['prompt_file'][0]), media_path.replace("./", "")) - vlm_file['media'] = media_path + vlm_file['media'] = model_utils.resolve_media_file_path(vlm_file.get("media"), args['prompt_file'][0]) vlm_file_list.append(vlm_file) else: vlm_file_list.append(output_data_list)