@@ -41,9 +41,6 @@ class Qwen2AudioLMClient(CachingClient):
     """
 
     END_OF_TEXT_TOKEN: str = "<|im_end|>"
-    # The official recommendation is to set the prefix length to 256
-    # https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct
-    PREFIX_TOKEN_LENGTH: int = 256
 
     def __init__(self, cache_config: CacheConfig):
         super().__init__(cache_config=cache_config)
@@ -84,11 +81,6 @@ def make_request(self, request: Request) -> RequestResult:
         model = loaded_model_processor.model
         tokenizer = loaded_model_processor.tokenizer
 
-        # Qwen2-Audio-Instruct counts input into the max_length, so we need to add the length of the prompt
-        generation_args = {
-            "max_length": request.max_tokens + self.PREFIX_TOKEN_LENGTH,
-        }
-
         input_query: List[Dict[str, Any]] = []
         query: List[Dict[str, str]] = []
         prompt_text: str = ""
@@ -142,10 +134,15 @@ def do_it() -> Dict[str, Any]:
                     return_tensors="pt",
                     padding=True,
                 )
+                input_length = inputs.input_ids.size(1)
+                # Qwen2-Audio-Instruct counts the input toward max_length,
+                # so we need to add the prompt length to the token budget
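+                # (an equivalent formulation in recent transformers versions is
+                # max_new_tokens=request.max_tokens, which excludes the prompt)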
                 inputs = inputs.to(self._device)
-                pred = model.generate(**inputs, **generation_args)
-                completion = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
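+                # generate returns prompt + completion ids; slicing off the first
+                # input_length columns keeps only the newly generated tokens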
+                pred = model.generate(**inputs, max_length=request.max_tokens + input_length)[:, input_length:]
 
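+                # skip_special_tokens=True strips markers such as <|im_end|>
+                # from the decoded completion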
+                completion = tokenizer.decode(
+                    pred.cpu()[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
+                )
                 # The processor of Qwen2-Audio-Instruct consists of an AutoTokenizer and a WhisperFeatureExtractor
                 tokens: List[str] = tokenizer.tokenizer.tokenize(completion)
                 return {"output": (completion, tokens)}
@@ -156,7 +153,7 @@ def do_it() -> Dict[str, Any]:
                     "completion_index": completion_index,
                     "model": request.model,
                     "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
-                    **generation_args,
+                    "max_tokens": request.max_tokens,
                 },
                 request=request,
             )
@@ -167,11 +164,7 @@ def do_it() -> Dict[str, Any]:
             )
 
             text, tokens = result["output"]
-
-            # Truncate the output text as the original Qwen includes the prompt in the output sequence
-            text = text[len(prompt_text):]
-            text = text.replace(self.END_OF_TEXT_TOKEN, "")
-            hlog(f"Truncated: {text}")
+            hlog(f"Generated: {text}")
 
             # Tokenize the generated text to get the list of tokens
             completions.append(