
Commit ee8cf38

New QWEN 2 VLM (#3247)
1 parent e2e7270 commit ee8cf38

File tree: 7 files changed (+232, -8 lines)

setup.cfg (+4)

@@ -204,6 +204,10 @@ vlm =
     # For metrics
     pycocoevalcap~=1.2
 
+    # For Qwen2
+    transformers~=4.45.2
+    qwen-vl-utils~=0.0.8
+
 ibm-enterprise-scenarios =
     openpyxl~=3.1

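The two new requirements back the Qwen2 vision-language client added later in this commit. As a quick, purely illustrative check that the pinned packages resolve (the module names below are the same ones the new client imports):

import transformers
from qwen_vl_utils import process_vision_info  # provided by the qwen-vl-utils package

# setup.cfg pins transformers~=4.45.2, i.e. any compatible 4.45.x release.
print(transformers.__version__)
print(process_vision_info.__name__)  # helper used by the Qwen2 VLM client below
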
src/helm/benchmark/run_spec_factory.py (+7)

@@ -138,6 +138,13 @@ def alter_run_spec(run_spec: RunSpec) -> RunSpec:
     ):
         run_spec = singleton(IncreaseMaxTokensRunExpander(value=1).expand(run_spec))
 
+    if model.name == "openai/o1-2024-12-17":
+        # From https://platform.openai.com/docs/guides/reasoning,
+        # "OpenAI recommends reserving at least 25,000 tokens for reasoning and outputs when you start
+        # experimenting with these models. As you become familiar with the number of reasoning tokens your
+        # prompts require, you can adjust this buffer accordingly."
+        run_spec = singleton(IncreaseMaxTokensRunExpander(value=25_000).expand(run_spec))
+
     # IDEFICS special handling
     if IDEFICS_MODEL_TAG in model.tags:
         if IDEFICS_INSTRUCT_MODEL_TAG in model.tags:
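
For intuition, the expander call above simply reserves a fixed 25,000-token buffer on top of whatever the run spec already requested; a minimal arithmetic sketch (the requested value is hypothetical, only the 25,000 buffer comes from the change above):

requested_max_tokens = 1_000   # hypothetical value a scenario might request
reasoning_buffer = 25_000      # buffer recommended by OpenAI's reasoning guide
effective_max_tokens = requested_max_tokens + reasoning_buffer
print(effective_max_tokens)    # 26000: room for hidden reasoning tokens plus the visible answer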

src/helm/benchmark/static/schema_vhelm.yaml (+2, -2)

@@ -726,7 +726,7 @@ run_groups:
       - accuracy
       - general_information
     environment:
-      main_name: exact_match
+      main_name: quasi_prefix_exact_match
       main_split: test
     taxonomy:
       task: short-answer question answering
@@ -902,7 +902,7 @@ run_groups:
       - accuracy
       - general_information
     environment:
-      main_name: exact_match
+      main_name: quasi_prefix_exact_match
       main_split: test
     taxonomy:
       task: short-answer question answering

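Switching main_name from exact_match to quasi_prefix_exact_match makes the headline metric tolerant of verbose answers. A rough sketch of what a quasi-prefix exact-match check typically computes (an approximation for illustration, not HELM's exact normalization rules):

import re

def normalize(text: str) -> str:
    # Lowercase, drop punctuation and articles, and collapse whitespace (SQuAD-style normalization).
    text = text.lower()
    text = re.sub(r"[^\w\s]", " ", text)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def quasi_prefix_exact_match(prediction: str, reference: str) -> float:
    # Credit the prediction when its normalized form starts with the normalized reference.
    return 1.0 if normalize(prediction).startswith(normalize(reference)) else 0.0

print(quasi_prefix_exact_match("Paris, France.", "Paris"))     # 1.0: trailing text does not hurt
print(quasi_prefix_exact_match("The city of Paris", "Paris"))  # 0.0: the answer must come first
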
src/helm/clients/openai_client.py (+14, -6)

@@ -28,6 +28,7 @@ class OpenAIClient(CachingClient):
 
     # Error OpenAI throws when the image in the prompt violates their content policy
     INAPPROPRIATE_IMAGE_ERROR: str = "Your input image may contain content that is not allowed by our safety system"
+    INAPPROPRIATE_PROMPT_ERROR: str = "Invalid prompt: your prompt was flagged"
 
     # Set the finish reason to this if the prompt violates OpenAI's content policy
     CONTENT_POLICY_VIOLATED_FINISH_REASON: str = (
@@ -171,11 +172,6 @@ def _make_chat_request(self, request: Request) -> RequestResult:
             "frequency_penalty": request.frequency_penalty,
         }
 
-        # OpenAI's vision API doesn't allow None values for stop.
-        # Fails with "body -> stop: none is not an allowed value" if None is passed.
-        if is_vlm(request.model) and raw_request["stop"] is None:
-            raw_request.pop("stop")
-
         # Special handling for o1 models.
         # Refer to the "Reasoning models" documentation further discussion of o1 model limitations:
         # https://platform.openai.com/docs/guides/reasoning
@@ -191,6 +187,18 @@
             if raw_request["stop"] is None:
                 raw_request.pop("stop")
 
+        if request.model_engine == "o1-2024-12-17":
+            # Avoid error:
+            # "Error code: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is
+            # not supported with this model.", 'type': 'invalid_request_error', 'param': 'temperature',
+            # 'code': 'unsupported_parameter'}}"
+            raw_request.pop("temperature", None)
+        elif is_vlm(request.model):
+            # Avoid error:
+            # "Invalid type for 'stop': expected an unsupported value, but got null instead."
+            if raw_request["stop"] is None:
+                raw_request.pop("stop")
+
         # Special handling for gpt-4o-audio-preview
         # See: https://platform.openai.com/docs/guides/audio
         if request.model_engine.startswith("gpt-4o-audio-preview"):
@@ -208,7 +216,7 @@ def do_it() -> Dict[str, Any]:
             cache_key = self._get_cache_key(raw_request, request)
             response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
         except openai.OpenAIError as e:
-            if self.INAPPROPRIATE_IMAGE_ERROR in str(e):
+            if self.INAPPROPRIATE_IMAGE_ERROR in str(e) or self.INAPPROPRIATE_PROMPT_ERROR in str(e):
                 hlog(f"Failed safety check: {str(request)}")
                 empty_completion = GeneratedOutput(
                     text="",
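
Condensed, the new request-sanitization behavior in this diff amounts to the following standalone sketch (an illustration of the logic above, not the client code itself):

from typing import Any, Dict

def sanitize_raw_request(raw_request: Dict[str, Any], model_engine: str, model_is_vlm: bool) -> Dict[str, Any]:
    # o1-2024-12-17 rejects the 'temperature' parameter outright, so drop it if present.
    if model_engine == "o1-2024-12-17":
        raw_request.pop("temperature", None)
    # Vision models reject an explicit null 'stop'; omit the key rather than sending None.
    elif model_is_vlm and raw_request.get("stop") is None:
        raw_request.pop("stop", None)
    return raw_request
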
src/helm/clients/vision_language/qwen2_vlm_client.py (new file, +175)

@@ -0,0 +1,175 @@
from threading import Lock
from typing import Any, Dict, List, Optional
from dataclasses import dataclass

from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

from helm.common.cache import CacheConfig
from helm.common.gpu_utils import get_torch_device_name
from helm.common.hierarchical_logger import hlog, htrack_block
from helm.common.media_object import TEXT_TYPE
from helm.common.request import Request, RequestResult, GeneratedOutput, Token
from helm.common.request import wrap_request_time
from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt


@dataclass(frozen=True)
class LoadedQwen2ModelProcessor:
    model: Qwen2VLForConditionalGeneration
    processor: AutoProcessor


_models_lock: Lock = Lock()
_models: Dict[str, Optional[LoadedQwen2ModelProcessor]] = {
    "Qwen/Qwen2-VL-7B-Instruct": None,
    "Qwen/Qwen2-VL-72B-Instruct": None,
}


class Qwen2VLMClient(CachingClient):
    def __init__(self, cache_config: CacheConfig):
        super().__init__(cache_config=cache_config)
        self._device: str = get_torch_device_name()

    def _get_model_name(self, helm_model_name: str) -> str:
        if helm_model_name == "qwen2-vl-7b-instruct":
            return "Qwen/Qwen2-VL-7B-Instruct"
        elif helm_model_name == "qwen2-vl-72b-instruct":
            return "Qwen/Qwen2-VL-72B-Instruct"
        else:
            raise ValueError(f"Unhandled model name: {helm_model_name}")

    def _get_model(self, helm_model_name: str) -> LoadedQwen2ModelProcessor:
        global _models_lock
        global _models

        model_name = self._get_model_name(helm_model_name)

        with _models_lock:
            loaded = _models[model_name]
            if loaded is None:
                hlog(f"Loading model {model_name} and caching in memory...")
                # https://huggingface.co/docs/transformers/model_doc/qwen2_vl#flash-attention-2-to-speed-up-generation
                model = Qwen2VLForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    attn_implementation="flash_attention_2",
                ).eval()
                processor = AutoProcessor.from_pretrained(model_name)
                loaded = LoadedQwen2ModelProcessor(model=model, processor=processor)
                _models[model_name] = loaded

        return loaded

    def make_request(self, request: Request) -> RequestResult:
        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
        loaded = self._get_model(request.model_engine)
        model = loaded.model
        processor = loaded.processor

        # Build Qwen2 messages
        # We assume all media objects go into a single "user" message:
        # messages = [
        #     {
        #         "role": "user",
        #         "content": [
        #             {"type": "image", "image": "file:///path/to/image1.jpg"},
        #             {"type": "image", "image": "file:///path/to/image2.jpg"},
        #             {"type": "text", "text": "Describe these images."}
        #         ]
        #     }
        # ]
        message_content = []
        for media_object in request.multimodal_prompt.media_objects:
            if media_object.is_type("image") and media_object.location:
                message_content.append({"type": "image", "image": media_object.location})
            elif media_object.is_type(TEXT_TYPE):
                if media_object.text is None:
                    raise ValueError("MediaObject of text type has missing text field value")
                message_content.append({"type": "text", "text": media_object.text})
            else:
                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")

        messages = [{"role": "user", "content": message_content}]

        # Prepare text and vision inputs
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self._device)

        generation_args = {
            "max_new_tokens": request.max_tokens,
        }

        completions: List[GeneratedOutput] = []
        request_time: float = 0
        request_datetime: Optional[int] = None
        all_cached: bool = True

        with htrack_block(f"Generating for prompt: {text}"):
            for completion_index in range(request.num_completions):
                try:

                    def do_it() -> Dict[str, Any]:
                        generated_ids = model.generate(**inputs, **generation_args)
                        # Remove the input prefix from outputs
                        generated_ids_trimmed = [
                            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                        ]
                        output_text = processor.batch_decode(
                            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
                        )
                        # There's only one batch element
                        completion = output_text[0]
                        # For simplicity, we split tokens by whitespace.
                        # A more accurate tokenization would require a tokenizer for Qwen2, if desired.
                        tokens = completion.split()
                        return {"output": (completion, tokens)}

                    cache_key = CachingClient.make_cache_key(
                        raw_request={
                            "completion_index": completion_index,
                            "model": request.model,
                            "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
                            **generation_args,
                        },
                        request=request,
                    )
                    result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
                except RuntimeError as model_error:
                    return RequestResult(
                        success=False, cached=False, error=str(model_error), completions=[], embedding=[]
                    )

                text_out, tokens = result["output"]
                completions.append(
                    GeneratedOutput(
                        text=text_out,
                        logprob=0,
                        tokens=[Token(text=str(token), logprob=0) for token in tokens],
                    )
                )
                hlog(f"Generated: {text_out}")

                request_time += result["request_time"]
                request_datetime = request_datetime or result.get("request_datetime")
                all_cached = all_cached and cached

        return RequestResult(
            success=True,
            cached=all_cached,
            request_time=request_time,
            request_datetime=request_datetime,
            completions=completions,
            embedding=[],
        )

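A hedged end-to-end usage sketch for the new client. The constructor arguments below (BlackHoleCacheConfig, MediaObject's content_type/location/text fields, and the Request fields) are assumptions about the surrounding HELM APIs based on what make_request() reads; treat this as an outline of the calling convention rather than a verified example:

from helm.common.cache import BlackHoleCacheConfig
from helm.common.media_object import MediaObject, MultimediaObject
from helm.common.request import Request
from helm.clients.vision_language.qwen2_vlm_client import Qwen2VLMClient

# One image plus one text instruction, mirroring the single-user-message layout above.
prompt = MultimediaObject(media_objects=[
    MediaObject(content_type="image/jpeg", location="/path/to/image.jpg"),
    MediaObject(content_type="text/plain", text="Describe this image."),
])
client = Qwen2VLMClient(cache_config=BlackHoleCacheConfig())  # assumed no-op cache config
request = Request(
    model="qwen/qwen2-vl-7b-instruct",
    model_deployment="huggingface/qwen2-vl-7b-instruct",
    multimodal_prompt=prompt,
    max_tokens=128,
    num_completions=1,
)
result = client.make_request(request)
if result.success:
    print(result.completions[0].text)
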
src/helm/config/model_deployments.yaml (+14)

@@ -2733,6 +2733,20 @@ model_deployments:
     client_spec:
       class_name: "helm.clients.vision_language.qwen_vlm_client.QwenVLMClient"
 
+  - name: huggingface/qwen2-vl-7b-instruct
+    model_name: qwen/qwen2-vl-7b-instruct
+    tokenizer_name: qwen/qwen-vl-chat
+    max_sequence_length: 8191
+    client_spec:
+      class_name: "helm.clients.vision_language.qwen2_vlm_client.Qwen2VLMClient"
+
+  - name: huggingface/qwen2-vl-72b-instruct
+    model_name: qwen/qwen2-vl-72b-instruct
+    tokenizer_name: qwen/qwen-vl-chat
+    max_sequence_length: 8191
+    client_spec:
+      class_name: "helm.clients.vision_language.qwen2_vlm_client.Qwen2VLMClient"
+
   - name: huggingface/qwen-audio-chat
     model_name: qwen/qwen-audio-chat
     tokenizer_name: qwen/qwen-audio-chat

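Each client_spec.class_name above is a dotted path that HELM resolves to a client class at run time; a generic sketch of that kind of lookup (illustrative only, not HELM's actual object-spec machinery):

import importlib

def resolve_class(class_name: str):
    # Split "package.module.ClassName" into a module path and an attribute name.
    module_path, _, attribute = class_name.rpartition(".")
    return getattr(importlib.import_module(module_path), attribute)

client_cls = resolve_class("helm.clients.vision_language.qwen2_vlm_client.Qwen2VLMClient")
print(client_cls.__name__)  # Qwen2VLMClient
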
src/helm/config/model_metadata.yaml (+16)

@@ -2827,6 +2827,22 @@ models:
     release_date: 2023-08-24
     tags: [VISION_LANGUAGE_MODEL_TAG, FULL_FUNCTIONALITY_VLM_TAG]
 
+  - name: qwen/qwen2-vl-7b-instruct
+    display_name: Qwen2-VL Instruct (7B)
+    description: The second generation of Qwen2-VL models ([paper](https://arxiv.org/abs/2409.12191)).
+    creator_organization_name: Alibaba Group
+    access: open
+    release_date: 2024-08-29
+    tags: [VISION_LANGUAGE_MODEL_TAG, FULL_FUNCTIONALITY_VLM_TAG]
+
+  - name: qwen/qwen2-vl-72b-instruct
+    display_name: Qwen2-VL Instruct (72B)
+    description: The second generation of Qwen2-VL models ([paper](https://arxiv.org/abs/2409.12191)).
+    creator_organization_name: Alibaba Group
+    access: open
+    release_date: 2024-08-29
+    tags: [VISION_LANGUAGE_MODEL_TAG, FULL_FUNCTIONALITY_VLM_TAG]
+
   - name: qwen/qwen-audio-chat
     display_name: Qwen-Audio Chat
     description: Auditory multimodal version of the Qwen large language model series ([paper](https://arxiv.org/abs/2311.07919)).