
Commit 9b6467c

Added support of OpenVINO GenAI for Stable Diffusion (SD)

1 parent 7cfedba

File tree

4 files changed: +136 −52 lines
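This commit teaches who_what_benchmark (wwb) to drive Stable Diffusion through openvino_genai's Text2ImagePipeline in addition to the Optimum-Intel pipelines. As a minimal sketch of the code path being added (not part of the commit itself; "model_dir" is a placeholder for an exported OpenVINO SD model, and the generate() keyword names follow genai_gen_image in the wwb.py diff below):

    import openvino_genai
    from PIL import Image

    # Load an exported Stable Diffusion model through the GenAI API
    pipe = openvino_genai.Text2ImagePipeline("model_dir", "CPU")
    # generate() returns a tensor; data[0] is the first image as an array
    tensor = pipe.generate("a sail boat at sunset", width=512, height=512,
                           num_inference_steps=4)
    Image.fromarray(tensor.data[0]).save("out.png")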

tools/who_what_benchmark/whowhatbench/registry.py (+4)

@@ -35,3 +35,7 @@ def score(self, model, **kwargs):
     @abstractmethod
     def worst_examples(self, top_k: int = 5, metric="similarity"):
         pass
+
+    @abstractmethod
+    def get_generation_fn(self):
+        raise NotImplementedError("generation_fn should be returned")
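The new abstract hook obliges every evaluator to expose the generation function it was configured with, so main() can reuse it for the target model (see the last wwb.py hunk below). A tiny illustration, not in the commit, of what the contract enforces, assuming score and worst_examples are the other abstract hooks:

    from whowhatbench.registry import BaseEvaluator

    class Incomplete(BaseEvaluator):
        def score(self, model, **kwargs): ...
        def worst_examples(self, top_k: int = 5, metric="similarity"): ...
        # get_generation_fn is missing, so instantiating Incomplete now fails:
        # TypeError: Can't instantiate abstract class Incomplete ...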

tools/who_what_benchmark/whowhatbench/text2image_evaluator.py (+33 −9)

@@ -5,6 +5,7 @@
 from tqdm import tqdm
 from transformers import set_seed
 import torch
+import openvino_genai
 
 from .registry import register_evaluator, BaseEvaluator
 
@@ -26,6 +27,17 @@
 }
 
 
+class Generator(openvino_genai.Generator):
+    def __init__(self, seed, rng, mu=0.0, sigma=1.0):
+        openvino_genai.Generator.__init__(self)
+        self.mu = mu
+        self.sigma = sigma
+        self.rng = rng
+
+    def next(self):
+        return torch.normal(torch.tensor(self.mu), self.sigma, generator=self.rng)
+
+
 @register_evaluator("text-to-image")
 class Text2ImageEvaluator(BaseEvaluator):
     def __init__(
@@ -41,6 +53,7 @@ def __init__(
         num_samples=None,
         gen_image_fn=None,
         seed=42,
+        is_genai=False,
     ) -> None:
         assert (
             base_model is not None or gt_data is not None
@@ -57,13 +70,19 @@ def __init__(
         self.similarity = ImageSimilarity(similarity_model_id)
         self.last_cmp = None
         self.gt_dir = os.path.dirname(gt_data)
+        self.generation_fn = gen_image_fn
+        self.is_genai = is_genai
+
         if base_model:
             self.gt_data = self._generate_data(
                 base_model, gen_image_fn, os.path.join(self.gt_dir, "reference")
             )
         else:
             self.gt_data = pd.read_csv(gt_data, keep_default_na=False)
 
+    def get_generation_fn(self):
+        return self.generation_fn
+
     def dump_gt(self, csv_name: str):
         self.gt_data.to_csv(csv_name)
 
@@ -99,13 +118,15 @@ def worst_examples(self, top_k: int = 5, metric="similarity"):
         return res
 
     def _generate_data(self, model, gen_image_fn=None, image_dir="reference"):
+        model.resolution = self.resolution
         if hasattr(model, "reshape") and self.resolution is not None:
-            model.reshape(
-                batch_size=1,
-                height=self.resolution[0],
-                width=self.resolution[1],
-                num_images_per_prompt=1,
-            )
+            if gen_image_fn is None:
+                model.reshape(
+                    batch_size=1,
+                    height=self.resolution[0],
+                    width=self.resolution[1],
+                    num_images_per_prompt=1,
+                )
 
         def default_gen_image_fn(model, prompt, num_inference_steps, generator=None):
             output = model(
@@ -118,7 +139,7 @@ def default_gen_image_fn(model, prompt, num_inference_steps, generator=None):
             )
             return output.images[0]
 
-        gen_image_fn = gen_image_fn or default_gen_image_fn
+        generation_fn = gen_image_fn or default_gen_image_fn
 
         if self.test_data:
             if isinstance(self.test_data, str):
@@ -144,13 +165,16 @@ def default_gen_image_fn(model, prompt, num_inference_steps, generator=None):
 
         if not os.path.exists(image_dir):
            os.makedirs(image_dir)
+
+        print(gen_image_fn)
         for i, prompt in tqdm(enumerate(prompts), desc="Evaluate pipeline"):
             set_seed(self.seed)
-            image = gen_image_fn(
+            rng = rng.manual_seed(self.seed)
+            image = generation_fn(
                 model,
                 prompt,
                 self.num_inference_steps,
-                generator=rng.manual_seed(self.seed),
+                generator=Generator(self.seed, rng) if self.is_genai else rng
             )
             image_path = os.path.join(image_dir, f"{i}.png")
             image.save(image_path)
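The Generator subclass adapts a seeded torch RNG to the noise-source interface openvino_genai expects, so the GenAI pipeline and the diffusers pipeline can draw from the same reproducible stream. A minimal usage sketch, not part of the commit:

    import torch
    from whowhatbench.text2image_evaluator import Generator

    # Seeded torch RNG wrapped in the openvino_genai-compatible adapter
    rng = torch.Generator().manual_seed(42)
    noise = Generator(42, rng)
    sample = noise.next()  # one N(0, 1) draw; identical on every run with seed 42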

tools/who_what_benchmark/whowhatbench/text_evaluator.py (+4)

@@ -121,6 +121,7 @@ def __init__(
         self.generation_config = generation_config
         self.generation_config_base = generation_config
         self.seqs_per_request = seqs_per_request
+        self.generation_fn = gen_answer_fn
         if self.generation_config is not None:
             assert self.seqs_per_request is not None
 
@@ -151,6 +152,9 @@ def __init__(
 
         self.last_cmp = None
 
+    def get_generation_fn(self):
+        return self.generation_fn
+
     def dump_gt(self, csv_name: str):
         self.gt_data.to_csv(csv_name)
 
tools/who_what_benchmark/whowhatbench/wwb.py (+95 −43)

@@ -3,6 +3,7 @@
 import os
 import json
 import pandas as pd
+from PIL import Image
 import logging
 from datasets import load_dataset
 from diffusers import DiffusionPipeline
@@ -35,9 +36,14 @@ class GenAIModelWrapper:
     A helper class to store additional attributes for GenAI models
     """
 
-    def __init__(self, model, model_dir):
+    def __init__(self, model, model_dir, model_type):
         self.model = model
-        self.config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
+        self.model_type = model_type
+
+        if model_type == "text":
+            self.config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
+        elif model_type == "text-to-image":
+            self.config = DiffusionPipeline.load_config(model_dir, trust_remote_code=True)
 
     def __getattr__(self, attr):
         if attr in self.__dict__:
@@ -53,40 +59,41 @@ def load_text_genai_pipeline(model_dir, device="CPU"):
         logger.error("Failed to import openvino_genai package. Please install it.")
         exit(-1)
     logger.info("Using OpenVINO GenAI API")
-    return GenAIModelWrapper(openvino_genai.LLMPipeline(model_dir, device), model_dir)
+    return GenAIModelWrapper(openvino_genai.LLMPipeline(model_dir, device), model_dir, "text")
 
 
 def load_text_model(
     model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False
 ):
-    if use_hf:
-        logger.info("Using HF Transformers API")
-        return AutoModelForCausalLM.from_pretrained(
-            model_id, trust_remote_code=True, device_map=device.lower()
-        )
-
-    if use_genai:
-        return load_text_genai_pipeline(model_id, device)
-
     if ov_config:
         with open(ov_config) as f:
             ov_options = json.load(f)
     else:
         ov_options = None
-    try:
-        model = OVModelForCausalLM.from_pretrained(
-            model_id, trust_remote_code=True, device=device, ov_config=ov_options
-        )
-    except ValueError:
-        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
-        model = OVModelForCausalLM.from_pretrained(
-            model_id,
-            config=config,
-            trust_remote_code=True,
-            use_cache=True,
-            device=device,
-            ov_config=ov_options,
+
+    if use_hf:
+        logger.info("Using HF Transformers API")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, trust_remote_code=True, device_map=device.lower()
         )
+    elif use_genai:
+        model = load_text_genai_pipeline(model_id, device)
+    else:
+        try:
+            model = OVModelForCausalLM.from_pretrained(
+                model_id, trust_remote_code=True, device=device, ov_config=ov_options
+            )
+        except ValueError:
+            config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+            model = OVModelForCausalLM.from_pretrained(
+                model_id,
+                config=config,
+                trust_remote_code=True,
+                use_cache=True,
+                device=device,
+                ov_config=ov_options,
+            )
+
     return model
 
 
@@ -95,6 +102,20 @@ def load_text_model(
 }
 
 
+def load_text2image_genai_pipeline(model_dir, device="CPU"):
+    try:
+        import openvino_genai
+    except ImportError:
+        logger.error("Failed to import openvino_genai package. Please install it.")
+        exit(-1)
+    logger.info("Using OpenVINO GenAI API")
+    return GenAIModelWrapper(
+        openvino_genai.Text2ImagePipeline(model_dir, device),
+        model_dir,
+        "text-to-image"
+    )
+
+
 def load_text2image_model(
     model_type, model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False
 ):
@@ -104,25 +125,28 @@ def load_text2image_model(
     else:
         ov_options = None
 
-    if use_hf:
-        return DiffusionPipeline.from_pretrained(model_id, trust_remote_code=True)
+    if use_genai:
+        model = load_text2image_genai_pipeline(model_id, device)
+    elif use_hf:
+        model = DiffusionPipeline.from_pretrained(model_id, trust_remote_code=True)
+    else:
+        TEXT2IMAGEPipeline = TEXT2IMAGE_TASK2CLASS[model_type]
 
-    TEXT2IMAGEPipeline = TEXT2IMAGE_TASK2CLASS[model_type]
+        try:
+            model = TEXT2IMAGEPipeline.from_pretrained(
+                model_id, trust_remote_code=True, device=device, ov_config=ov_options
+            )
+        except ValueError:
+            config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+            model = TEXT2IMAGEPipeline.from_pretrained(
+                model_id,
+                config=config,
+                trust_remote_code=True,
+                use_cache=True,
+                device=device,
+                ov_config=ov_options,
+            )
 
-    try:
-        model = TEXT2IMAGEPipeline.from_pretrained(
-            model_id, trust_remote_code=True, device=device, ov_config=ov_options
-        )
-    except ValueError:
-        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
-        model = TEXT2IMAGEPipeline.from_pretrained(
-            model_id,
-            config=config,
-            trust_remote_code=True,
-            use_cache=True,
-            device=device,
-            ov_config=ov_options,
-        )
     return model
 
 
@@ -278,6 +302,18 @@ def parse_args():
         action="store_true",
         help="Use LLMPipeline from transformers library to instantiate the model.",
     )
+    parser.add_argument(
+        "--image-size",
+        type=int,
+        default=512,
+        help="Text-to-image specific parameter that defines the image resolution.",
+    )
+    parser.add_argument(
+        "--num-inference-steps",
+        type=int,
+        default=4,
+        help="Text-to-image specific parameter that defines the number of denoising steps.",
+    )
 
     return parser.parse_args()
 
@@ -340,6 +376,18 @@ def genai_gen_answer(model, tokenizer, question, max_new_tokens, skip_question):
     return out
 
 
+def genai_gen_image(model, prompt, num_inference_steps, generator=None):
+    image_tensor = model.generate(
+        prompt,
+        width=model.resolution[0],
+        height=model.resolution[1],
+        num_inference_steps=num_inference_steps,
+        random_generator=generator
+    )
+    image = Image.fromarray(image_tensor.data[0])
+    return image
+
+
 def get_evaluator(base_model, args):
     # config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
     # task = TasksManager.infer_task_from_model(config._name_or_path)
@@ -368,6 +416,10 @@ def get_evaluator(base_model, args):
             gt_data=args.gt_data,
             test_data=prompts,
             num_samples=args.num_samples,
+            resolution=(args.image_size, args.image_size),
+            num_inference_steps=args.num_inference_steps,
+            gen_image_fn=genai_gen_image if args.genai else None,
+            is_genai=args.genai
         )
     else:
         raise ValueError(f"Unsupported task: {task}")
@@ -446,7 +498,7 @@ def main():
         args.genai,
     )
     all_metrics_per_question, all_metrics = evaluator.score(
-        target_model, genai_gen_answer if args.genai else None
+        target_model, evaluator.get_generation_fn() if args.genai else None
    )
    logger.info("Metrics for model: %s", args.target_model)
    logger.info(all_metrics)
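Putting the pieces together, the text-to-image + --genai case now flows roughly as follows. A simplified sketch, not part of the commit (the model directories and gt_data CSV are placeholders, and treating wwb.py as an importable module is an assumption; the names themselves come from the diff above):

    from whowhatbench.text2image_evaluator import Text2ImageEvaluator
    from whowhatbench.wwb import load_text2image_model, genai_gen_image

    # Both models load as openvino_genai.Text2ImagePipeline wrappers
    base = load_text2image_model("text-to-image", "base_model_dir", use_genai=True)
    evaluator = Text2ImageEvaluator(
        base_model=base,
        gt_data="gt.csv",               # reference images land next to this CSV
        resolution=(512, 512),
        num_inference_steps=4,
        gen_image_fn=genai_gen_image,   # stored; returned by get_generation_fn()
        is_genai=True,
    )
    target = load_text2image_model("text-to-image", "target_model_dir", use_genai=True)
    per_prompt, metrics = evaluator.score(target, evaluator.get_generation_fn())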
