Skip to content

Commit 23088ba

Browse files
committed Feb 7, 2024
Switched to pre-built PagedAttentionExtension from openvinotoolkit/openvino_contrib#867. Minimized debug output.

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed
 

‎vllm/worker/model_runner.py

+6 -5
Original file line number | Diff line number | Diff line change
@@ -196,26 +196,26 @@ def wrapper(module, target_op, *args, **kwargs):
196196
extension=[
197197
ModuleExtension(
198198
PagedAttention,
199-
extension=lambda module: 'PagedAttentionPlaceholder',
199+
extension=lambda module: 'PagedAttentionExtension',
200200
replacer=lambda module, *args, **kwargs: args[0],
201201
wrapper=wrapper
202-
)
202+
),
203+
'libuser_ov_extensions.so'
203204
]
204205
)
205206

206-
for input_name, input_data, input_tensor in zip(input_names, flatten_input, ov_model.inputs):
207+
for input_data, input_tensor in zip(flatten_input, ov_model.inputs):
207208
if input_tensor.element_type.is_dynamic():
208209
input_tensor.get_node().set_element_type(ov_dtype_maping[input_data.dtype])
209210
if input_tensor.partial_shape.rank.is_dynamic:
210211
input_tensor.get_node().set_partial_shape(ov.PartialShape([-1]*input_data.ndim))
211-
#input_tensor.get_tensor().set_names({input_name})
212212

213213
for out_name, out in zip(output_names, ov_model.outputs):
214214
out.get_tensor().set_names({out_name})
215215
ov_model.validate_nodes_and_infer_types()
216216
#ov.save_model(ov_model, "vllm_openvino_model.xml")
217217
print('>>>>>>>>>>>>> OV MODEL CONVERTED')
218-
print(ov_model)
218+
#print(ov_model)
219219
ov_compiled = ov.compile_model(ov_model)
220220

221221
from functools import partial
@@ -243,6 +243,7 @@ def wrapper(*args, **kwargs):
243243
inputs.append(input_metadata.block_tables)
244244
#for input in inputs:
245245
# print(f'{input.dtype} wiht shape {input.shape}' if isinstance(input, torch.Tensor) else type(input))
246+
#print('input_metadata.slot_mapping:', input_metadata.slot_mapping)
246247
result = ov_compiled(inputs, share_outputs=False)
247248
#print(f'result: {type(result)}')
248249
return torch.from_numpy(result[0])

0 commit comments

Comments
 (0)
Please sign in to comment.