Commit eecf70f

Added support of HF and GenAI models into CLI (openvinotoolkit#887)
1 parent de77f96 commit eecf70f

File tree

6 files changed: +135 -35 lines changed

.github/workflows/llm_bench-python.yml

+6-4
@@ -40,7 +40,8 @@ jobs:
           python -m pip install --upgrade pip
           python -m pip install flake8 pytest black
           GIT_CLONE_PROTECTION_ACTIVE=false pip install -r ${{ env.LLM_BENCH_PYPATH }}/requirements.txt
-          pip install openvino-nightly
+          python -m pip install -U --pre openvino openvino-tokenizers openvino-genai --extra-index-url
+          https://storage.openvinotoolkit.org/simple/wheels/nightly
           GIT_CLONE_PROTECTION_ACTIVE=false pip install -r ${{ env.WWB_PATH }}/requirements.txt
           GIT_CLONE_PROTECTION_ACTIVE=false pip install ${{ env.WWB_PATH }}
@@ -73,7 +74,7 @@ jobs:
           python ./llm_bench/python/benchmark.py -m ./ov_models/tiny-sd/pytorch/dldt/FP16/ -pf ./llm_bench/python/prompts/stable-diffusion.jsonl -d cpu -n 1
       - name: WWB Tests
         run: |
-          python -m pytest ./llm_bench/python/who_what_benchmark/tests
+          python -m pytest llm_bench/python/who_what_benchmark/tests
   stateful:
     runs-on: ubuntu-20.04
     steps:
@@ -85,12 +86,13 @@ jobs:
         run: |
           GIT_CLONE_PROTECTION_ACTIVE=false python -m pip install -r llm_bench/python/requirements.txt
           python -m pip uninstall --yes openvino
-          python -m pip install openvino-nightly
+          python -m pip install -U --pre openvino openvino-tokenizers openvino-genai --extra-index-url
+          https://storage.openvinotoolkit.org/simple/wheels/nightly
           python llm_bench/python/convert.py --model_id TinyLlama/TinyLlama-1.1B-Chat-v1.0 --output_dir . --stateful
           grep beam_idx pytorch/dldt/FP32/openvino_model.xml
       - name: WWB Tests
         run: |
           GIT_CLONE_PROTECTION_ACTIVE=false pip install -r llm_bench/python/who_what_benchmark/requirements.txt
           GIT_CLONE_PROTECTION_ACTIVE=false pip install llm_bench/python/who_what_benchmark/
           pip install pytest
-          python -m pytest llm_bench/python/who_what_benchmark/tests
+          python -m pytest llm_bench/python/who_what_benchmark/tests
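Both jobs now install pre-release openvino, openvino-tokenizers and openvino-genai wheels from the nightly index instead of the old openvino-nightly package. A quick way to confirm all three landed in the environment is sketched below (not part of the commit; it only uses importlib.metadata, so it works whatever version the nightly index happens to serve):

```python
# Minimal sanity check that the three nightly wheels installed together.
# Distribution names match the pip install line in the workflow above.
from importlib.metadata import version, PackageNotFoundError

for dist in ("openvino", "openvino-tokenizers", "openvino-genai"):
    try:
        print(f"{dist}: {version(dist)}")
    except PackageNotFoundError:
        print(f"{dist}: NOT INSTALLED")
```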

llm_bench/python/who_what_benchmark/README.md

+17-8
@@ -55,27 +55,36 @@ metrics_per_prompt, metrics = evaluator.score(optimized_model, test_data=prompts
 ```sh
 wwb --help
 
-# run ground truth generation for uncompressed model on the first 32 samples from squad dataset
-# ground truth will be saved in llama_2_7b_squad_gt.csv file
+# Run ground truth generation for uncompressed model on the first 32 samples from squad dataset
+# Ground truth will be saved in llama_2_7b_squad_gt.csv file
 wwb --base-model meta-llama/Llama-2-7b-chat-hf --gt-data llama_2_7b_squad_gt.csv --dataset squad --split validation[:32] --dataset-field question
 
-# run comparison with compressed model on the first 32 samples from squad dataset
+# Run comparison with compressed model on the first 32 samples from squad dataset
 wwb --target-model /home/user/models/Llama_2_7b_chat_hf_int8 --gt-data llama_2_7b_squad_gt.csv --dataset squad --split validation[:32] --dataset-field question
 
-# output will be like this
+# Output will be like this
 #   similarity        FDT        SDT  FDT norm  SDT norm
 # 0   0.972823  67.296296  20.592593  0.735127  0.151505
 
-# run ground truth generation for uncompressed model on internal set of questions
-# ground truth will be saved in llama_2_7b_squad_gt.csv file
+# Run ground truth generation for uncompressed model on internal set of questions
+# Ground truth will be saved in llama_2_7b_wwb_gt.csv file
 wwb --base-model meta-llama/Llama-2-7b-chat-hf --gt-data llama_2_7b_wwb_gt.csv
 
-# run comparison with compressed model on internal set of questions
+# Run comparison with compressed model on internal set of questions
 wwb --target-model /home/user/models/Llama_2_7b_chat_hf_int8 --gt-data llama_2_7b_wwb_gt.csv
 
-## Control the number of samples and use verbose mode to see the difference in the results
+# Use --num-samples to control the number of samples
 wwb --base-model meta-llama/Llama-2-7b-chat-hf --gt-data llama_2_7b_wwb_gt.csv --num-samples 10
+
+# Use -v for verbose mode to see the difference in the results
 wwb --target-model /home/user/models/Llama_2_7b_chat_hf_int8 --gt-data llama_2_7b_wwb_gt.csv --num-samples 10 -v
+
+# Use --hf AutoModelForCausalLM to instantiate the model from model_id/folder
+wwb --base-model meta-llama/Llama-2-7b-chat-hf --gt-data llama_2_7b_wwb_gt.csv --hf
+
+# Use --language parameter to control the language of prompts
+# Autodetection works for basic Chinese models
+wwb --base-model meta-llama/Llama-2-7b-chat-hf --gt-data llama_2_7b_wwb_gt.csv --hf
 ```
 
 ### Supported metrics
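The same --hf comparison can also be driven from Python instead of the CLI. The sketch below is an illustration under assumptions: the whowhatbench import path and keyword arguments are inferred from wwb.py elsewhere in this commit, and the int8 model path is the placeholder used in the examples above.

```python
# Rough programmatic equivalent of the `wwb --hf ...` flow shown above.
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.intel.openvino import OVModelForCausalLM
from whowhatbench import Evaluator  # assumed import path for this package

model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Ground truth from the original HF Transformers model (what --hf selects on the CLI)
base_model = AutoModelForCausalLM.from_pretrained(model_id)
evaluator = Evaluator(base_model=base_model, tokenizer=tokenizer, num_samples=10)

# Score an OpenVINO-optimized copy against that ground truth (path is illustrative)
target_model = OVModelForCausalLM.from_pretrained("/home/user/models/Llama_2_7b_chat_hf_int8")
metrics_per_prompt, metrics = evaluator.score(target_model)
print(metrics)
```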

llm_bench/python/who_what_benchmark/requirements.txt

+4-2
@@ -1,8 +1,10 @@
 transformers>=4.35.2
 sentence-transformers>=2.2.2
-openvino>=2023.3.0
-openvino-telemetry>=2023.2.1
+openvino>=2024.3.0
+openvino-telemetry>=2024.3.0
 optimum-intel>=1.14
+openvino-tokenizers>=2024.3.0
+openvino-genai>=2024.3.0
 pandas>=2.0.3
 numpy>=1.23.5
 tqdm>=4.66.1

llm_bench/python/who_what_benchmark/tests/test_cli.py

+40-6
@@ -32,17 +32,21 @@ def run_wwb(args):
 
 
 def setup_module():
+    from optimum.exporters.openvino.convert import export_tokenizer
+
     logger.info("Create models")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     base_model = OVModelForCausalLM.from_pretrained(model_id)
     base_model.save_pretrained(base_model_path)
     tokenizer.save_pretrained(base_model_path)
+    export_tokenizer(tokenizer, base_model_path)
 
     target_model = OVModelForCausalLM.from_pretrained(
         model_id, quantization_config=OVWeightQuantizationConfig(bits=8)
     )
     target_model.save_pretrained(target_model_path)
     tokenizer.save_pretrained(target_model_path)
+    export_tokenizer(tokenizer, target_model_path)
 
 
 def teardown_module():
@@ -57,9 +61,10 @@ def test_target_model():
         "--num-samples", "2",
         "--device", "CPU"
     ])
+
     assert result.returncode == 0
-    assert "Metrics for model" in result.stdout
-    assert "## Reference text" not in result.stdout
+    assert "Metrics for model" in result.stderr
+    assert "## Reference text" not in result.stderr
 
 
 @pytest.fixture
@@ -76,8 +81,6 @@ def test_gt_data():
         "--num-samples", "2",
         "--device", "CPU"
     ])
-    import time
-    time.sleep(1)
     data = pd.read_csv(temp_file_name)
     os.remove(temp_file_name)
 
@@ -95,7 +98,7 @@ def test_output_directory():
         "--output", temp_dir
     ])
     assert result.returncode == 0
-    assert "Metrics for model" in result.stdout
+    assert "Metrics for model" in result.stderr
     assert os.path.exists(os.path.join(temp_dir, "metrics_per_qustion.csv"))
     assert os.path.exists(os.path.join(temp_dir, "metrics.csv"))
 
@@ -109,7 +112,7 @@ def test_verbose():
         "--verbose"
     ])
     assert result.returncode == 0
-    assert "## Reference text" in result.stdout
+    assert "## Diff " in result.stderr
 
 
 def test_language_autodetect():
@@ -127,3 +130,34 @@ def test_language_autodetect():
 
     assert result.returncode == 0
     assert "马克" in data["questions"].values[0]
+
+
+def test_hf_model():
+    with tempfile.NamedTemporaryFile(suffix=".csv") as tmpfile:
+        temp_file_name = tmpfile.name
+
+    result = run_wwb([
+        "--base-model", model_id,
+        "--gt-data", temp_file_name,
+        "--num-samples", "2",
+        "--device", "CPU",
+        "--hf"
+    ])
+    data = pd.read_csv(temp_file_name)
+    os.remove(temp_file_name)
+
+    assert result.returncode == 0
+    assert len(data["questions"].values) == 2
+
+
+def test_genai_model():
+    result = run_wwb([
+        "--base-model", base_model_path,
+        "--target-model", target_model_path,
+        "--num-samples", "2",
+        "--device", "CPU",
+        "--genai"
+    ])
+    assert result.returncode == 0
+    assert "Metrics for model" in result.stderr
+    assert "## Reference text" not in result.stderr

llm_bench/python/who_what_benchmark/whowhatbench/evaluator.py

+3-2
@@ -96,7 +96,8 @@ def __init__(
         max_new_tokens=128,
         crop_question=True,
         num_samples=None,
-        language=None
+        language=None,
+        gen_answer_fn=None,
     ) -> None:
         assert (
             base_model is not None or gt_data is not None
@@ -116,7 +117,7 @@ def __init__(
             self.language = autodetect_language(base_model)
 
         if base_model:
-            self.gt_data = self._generate_data(base_model)
+            self.gt_data = self._generate_data(base_model, gen_answer_fn)
         else:
             self.gt_data = pd.read_csv(gt_data, keep_default_na=False)
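The new gen_answer_fn hook lets a caller replace the evaluator's default generation path when building ground-truth answers; wwb.py uses it to plug in the GenAI pipeline. A hypothetical Transformers-based callback with the same signature as genai_gen_answer might look like this (the whowhatbench import path and the TinyLlama id are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from whowhatbench import Evaluator  # assumed import path for this package

# Hypothetical custom answer generator; the signature mirrors genai_gen_answer()
# in wwb.py: model, tokenizer, prompt, token budget and the crop-question flag.
def my_gen_answer(model, tokenizer, question, max_new_tokens, skip_question):
    inputs = tokenizer(question, return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # any causal LM works here
evaluator = Evaluator(
    base_model=AutoModelForCausalLM.from_pretrained(model_id),
    tokenizer=AutoTokenizer.from_pretrained(model_id),
    num_samples=2,
    gen_answer_fn=my_gen_answer,  # overrides the default generation path
)
```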

llm_bench/python/who_what_benchmark/whowhatbench/wwb.py

+65-13
@@ -1,25 +1,61 @@
 import argparse
 import difflib
 import os
-
 import json
 import pandas as pd
+import logging
 from datasets import load_dataset
 from optimum.exporters import TasksManager
 from optimum.intel.openvino import OVModelForCausalLM
 from optimum.utils import NormalizedConfigManager, NormalizedTextConfig
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
 
 from . import Evaluator
 
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 TasksManager._SUPPORTED_MODEL_TYPE["stablelm-epoch"] = TasksManager._SUPPORTED_MODEL_TYPE["llama"]
 NormalizedConfigManager._conf["stablelm-epoch"] = NormalizedTextConfig.with_args(
     num_layers="num_hidden_layers",
     num_attention_heads="num_attention_heads",
 )
 
 
-def load_model(model_id, device="CPU", ov_config=None):
+class GenAIModelWrapper():
+    """
+    A helper class to store additional attributes for GenAI models
+    """
+    def __init__(self, model, model_dir):
+        self.model = model
+        self.config = AutoConfig.from_pretrained(model_dir)
+
+    def __getattr__(self, attr):
+        if attr in self.__dict__:
+            return getattr(self, attr)
+        else:
+            return getattr(self.model, attr)
+
+
+def load_genai_pipeline(model_dir, device="CPU"):
+    try:
+        import openvino_genai
+    except ImportError:
+        logger.error("Failed to import openvino_genai package. Please install it.")
+        exit(-1)
+    logger.info("Using OpenVINO GenAI API")
+    return GenAIModelWrapper(openvino_genai.LLMPipeline(model_dir, device), model_dir)
+
+
+def load_model(model_id, device="CPU", ov_config=None, use_hf=False, use_genai=False):
+    if use_hf:
+        logger.info("Using HF Transformers API")
+        return AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, device_map=device.lower())
+
+    if use_genai:
+        return load_genai_pipeline(model_id, device)
+
     if ov_config:
         with open(ov_config) as f:
             ov_options = json.load(f)
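GenAIModelWrapper gives an openvino_genai.LLMPipeline the HF-style config attribute that the evaluator expects while forwarding every other attribute to the wrapped pipeline. The delegation pattern on its own, stripped of OpenVINO specifics, is sketched below (illustrative classes only):

```python
# Stripped-down version of the GenAIModelWrapper delegation pattern:
# attributes set on the wrapper (like `config`) are served locally, anything
# else falls through to the wrapped object via __getattr__, which Python only
# calls when normal attribute lookup fails.
class Wrapper:
    def __init__(self, inner, config):
        self.inner = inner
        self.config = config

    def __getattr__(self, attr):
        return getattr(self.inner, attr)


class Inner:
    def generate(self, prompt, max_new_tokens=16):
        return f"echo: {prompt}"


w = Wrapper(Inner(), config={"model_type": "demo"})
print(w.config)             # served by the wrapper itself
print(w.generate("hello"))  # delegated to the wrapped object
```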
@@ -157,6 +193,16 @@ def parse_args():
         default=None,
         help="Used to select default prompts based on the primary model language, e.g. 'en', 'ch'.",
     )
+    parser.add_argument(
+        "--hf",
+        action="store_true",
+        help="Use AutoModelForCausalLM from transformers library to instantiate the model.",
+    )
+    parser.add_argument(
+        "--genai",
+        action="store_true",
+        help="Use LLMPipeline from openvino_genai library to instantiate the model.",
+    )
 
     return parser.parse_args()

@@ -211,6 +257,11 @@ def diff_strings(a: str, b: str, *, use_loguru_colors: bool = False) -> str:
     return "".join(output)
 
 
+def genai_gen_answer(model, tokenizer, question, max_new_tokens, skip_question):
+    out = model.generate(question, max_new_tokens=max_new_tokens)
+    return out
+
+
 def main():
     args = parse_args()
     check_args(args)
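genai_gen_answer works because LLMPipeline.generate() takes the prompt string directly and returns text, so no tokenizer round-trip is needed. Used standalone, the same two calls from this diff look roughly like this (the model directory is a placeholder and must already contain an exported OpenVINO model plus its tokenizer/detokenizer):

```python
# Standalone illustration of the generate() call that genai_gen_answer wraps.
# LLMPipeline(model_dir, device) and generate(prompt, max_new_tokens=...) are
# the calls used in this diff; "ov_model_dir" is a placeholder path.
import openvino_genai

pipe = openvino_genai.LLMPipeline("ov_model_dir", "CPU")
answer = pipe.generate("Who is Mark Twain?", max_new_tokens=128)
print(answer)
```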
@@ -228,24 +279,25 @@ def main():
             language=args.language,
         )
     else:
-        base_model = load_model(args.base_model, args.device, args.ov_config)
+        base_model = load_model(args.base_model, args.device, args.ov_config, args.hf, args.genai)
         evaluator = Evaluator(
             base_model=base_model,
             test_data=prompts,
             tokenizer=tokenizer,
             similarity_model_id=args.text_encoder,
             num_samples=args.num_samples,
             language=args.language,
+            gen_answer_fn=genai_gen_answer if args.genai else None
         )
         if args.gt_data:
             evaluator.dump_gt(args.gt_data)
         del base_model
 
     if args.target_model:
-        target_model = load_model(args.target_model, args.device, args.ov_config)
-        all_metrics_per_question, all_metrics = evaluator.score(target_model)
-        print("Metrics for model: ", args.target_model)
-        print(all_metrics)
+        target_model = load_model(args.target_model, args.device, args.ov_config, args.hf, args.genai)
+        all_metrics_per_question, all_metrics = evaluator.score(target_model, genai_gen_answer if args.genai else None)
+        logger.info("Metrics for model: %s", args.target_model)
+        logger.info(all_metrics)
 
     if args.output:
         if not os.path.exists(args.output):
@@ -269,11 +321,11 @@ def main():
             actual_text += l2 + "\n"
             diff += diff_strings(l1, l2) + "\n"
 
-        print("--------------------------------------------------------------------------------------")
-        print("## Reference text {}:\n".format(i + 1), ref_text)
-        print("## Actual text {}:\n".format(i + 1), actual_text)
-        print("## Diff {}: ".format(i + 1))
-        print(diff)
+        logger.info("--------------------------------------------------------------------------------------")
+        logger.info("## Reference text %d:\n%s", i + 1, ref_text)
+        logger.info("## Actual text %d:\n%s", i + 1, actual_text)
+        logger.info("## Diff %d: ", i + 1)
+        logger.info(diff)
 
 
 if __name__ == "__main__":
