Commit d6a323a

Deduplicated model wrapper code for optimum and vllm modeling. Disabled model serialization.

1 parent a34722d
1 file changed: +47 −84 lines

vllm/worker/model_runner.py (+47 −84)
@@ -48,6 +48,47 @@ def flattenize_inputs(inputs):
     return flatten_inputs


+def ov_wrapper(self, *args, **kwargs):
+    #print('OV FORWARD WRAPPER')
+    #print(f'model class: {type(args[0])}')
+    #for i, input in enumerate(args[1:]):
+    #    print(f'[{i}]: {type(input)}')
+    #for key, value in kwargs.items():
+    #    print(f'{key}: {type(value)}')
+    #result = args[0]._openvino_patch_orig_forward(*args[1:], **kwargs)
+    input_metadata = kwargs['input_metadata']
+    #print(dir(input_metadata))
+    #print(input_metadata.is_prompt, input_metadata.slot_mapping, input_metadata.max_context_len, input_metadata.context_lens, input_metadata.block_tables)
+    def prepare_data(t):
+        t = np.array(t, copy=False)
+        #print(t.__array_interface__['data'][0])
+        assert t.flags["C_CONTIGUOUS"]
+        return t
+    flatten_kv_cache = flattenize_inputs(kwargs['kv_caches'])
+    #total_size = sum([torch.numel(t) for t in flatten_kv_cache])
+    #print(f'kv-cache total size: {total_size}')
+    flatten_kv_cache = [prepare_data(t) for t in flatten_kv_cache]
+    inputs = [
+        kwargs['input_ids'],
+        kwargs['positions'],
+        *flatten_kv_cache,
+        input_metadata.is_prompt, input_metadata.slot_mapping
+    ]
+    #print('slot_mapping:', input_metadata.slot_mapping)
+    if input_metadata.max_context_len is not None:
+        # available from the second iteration
+        inputs.append(input_metadata.max_context_len)
+        inputs.append(input_metadata.context_lens)
+        inputs.append(input_metadata.block_tables)
+    else:
+        inputs.append(np.array(0, dtype=np.int32))  # for optimum-based models this parameter can be used even on the first iteration
+    #for input in inputs:
+    #    print(f'{input.dtype} with shape {input.shape}' if isinstance(input, torch.Tensor) else type(input))
+    result = self.ov_request.infer(inputs, share_inputs=True, share_outputs=False)
+    #print(f'result: {type(result)}')
+    return torch.from_numpy(result[0])
+
+
 def patch_model_with_openvino(model, model_config, *model_args, **model_kwargs):
     if hasattr(model, '_openvino_patch_orig_forward'):
         return
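
Note on the new ov_wrapper: prepare_data relies on np.array(t, copy=False) returning a zero-copy view of each torch KV-cache tensor, which is why the C-contiguity assertion matters when the buffer is later shared with OpenVINO via share_inputs=True. A minimal standalone sketch of that conversion, using a toy tensor rather than the real KV cache:

import numpy as np
import torch

t = torch.zeros(2, 3, 4)            # toy stand-in for one KV-cache tensor
a = np.array(t, copy=False)         # zero-copy view over the same memory
assert a.flags["C_CONTIGUOUS"]      # a non-contiguous view would force a copy

a[0, 0, 0] = 1.0                    # write through the numpy view...
assert t[0, 0, 0].item() == 1.0     # ...and the torch tensor sees it: no copy was made
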
@@ -240,7 +281,7 @@ def wrapper(module, target_op, *args, **kwargs):
     for out_name, out in zip(output_names, ov_model.outputs):
         out.get_tensor().set_names({out_name})
     ov_model.validate_nodes_and_infer_types()
-    #ov.save_model(ov_model, "vllm_openvino_model.xml")
+    #ov.serialize(ov_model, "vllm_openvino_model.xml")
     print('MODEL IS CONVERTED')
     #print(ov_model)
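
The commented-out dump switches from ov.save_model to ov.serialize; both exist in the OpenVINO Python API and write the graph to an IR .xml/.bin pair (ov.save_model may additionally compress weights by default). A small self-contained sketch of the round trip, using a toy graph rather than the converted vLLM model:

import numpy as np
import openvino as ov
from openvino.runtime import opset8 as ops

# Build a trivial graph so the sketch runs on its own.
x = ops.parameter([1, 4], np.float32, name="x")
toy_model = ov.Model([ops.relu(x)], [x], "toy")

ov.serialize(toy_model, "toy_model.xml")            # writes toy_model.xml + toy_model.bin
reloaded = ov.Core().read_model("toy_model.xml")    # read the IR back in
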

@@ -249,50 +290,11 @@ def wrapper(module, target_op, *args, **kwargs):
     ov_config = {ov.properties.enable_profiling: True}
     # ov_config = {}
     ov_compiled = core.compile_model(ov_model, "CPU", config=ov_config)
-    ov_request = ov_compiled.create_infer_request()
+    model.ov_request = ov_compiled.create_infer_request()

     from functools import partial
-    def wrapper(*args, **kwargs):
-        #print('OV FORWARD WRAPPER')
-        #print(f'model class: {type(args[0])}')
-        #for i, input in enumerate(args[1:]):
-        #    print(f'[{i}]: {type(input)}')
-        #for key, value in kwargs.items():
-        #    print(f'{key}: {type(value)}')
-        #result = args[0]._openvino_patch_orig_forward(*args[1:], **kwargs)
-        input_metadata = kwargs['input_metadata']
-        #print(dir(input_metadata))
-        #print(input_metadata.is_prompt, input_metadata.slot_mapping, input_metadata.max_context_len, input_metadata.context_lens, input_metadata.block_tables)
-        def prepare_data(t):
-            t = np.array(t, copy=False)
-            #print(t.__array_interface__['data'][0])
-            assert t.flags["C_CONTIGUOUS"]
-            return t
-        flatten_kv_cache = flattenize_inputs(kwargs['kv_caches'])
-        #total_size = sum([torch.numel(t) for t in flatten_kv_cache])
-        #print(f'kv-cache total size: {total_size}')
-        flatten_kv_cache = [prepare_data(t) for t in flatten_kv_cache]
-        inputs = [
-            kwargs['input_ids'],
-            kwargs['positions'],
-            *flatten_kv_cache,
-            input_metadata.is_prompt, input_metadata.slot_mapping
-        ]
-        #print('slot_mapping:', input_metadata.slot_mapping)
-        if input_metadata.max_context_len is not None:
-            # available from the second iteration
-            inputs.append(input_metadata.max_context_len)
-            inputs.append(input_metadata.context_lens)
-            inputs.append(input_metadata.block_tables)
-        else:
-            inputs.append(np.array(0, dtype=np.int32))  # for optimum-based models this parameter can be used even on the first iteration
-        #for input in inputs:
-        #    print(f'{input.dtype} with shape {input.shape}' if isinstance(input, torch.Tensor) else type(input))
-        result = ov_request.infer(inputs, share_inputs=True, share_outputs=False)
-        #print(f'result: {type(result)}')
-        return torch.from_numpy(result[0])
     model._openvino_patch_orig_forward = model.forward
-    model.forward = partial(wrapper, model)
+    model.forward = partial(ov_wrapper, model)


 def patch_stateful_model(model):
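
The deduplication hinges on two changes visible in this hunk: the infer request is stored on the model object (model.ov_request) instead of being captured by a closure, and the shared module-level ov_wrapper is bound with functools.partial so it behaves like a method of whichever model it is attached to. A toy illustration of the pattern (names here are illustrative, not the vLLM API):

from functools import partial

def shared_forward(self, x):
    # Resolved per call from the bound object; one function can serve
    # both the optimum-based and the vLLM-native model objects.
    return self.ov_request(x)

class ToyModel:
    pass

m = ToyModel()
m.ov_request = lambda x: x * 2          # stands in for InferRequest.infer
m.forward = partial(shared_forward, m)  # same binding trick as ov_wrapper
assert m.forward(21) == 42
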
@@ -548,53 +550,14 @@ def load_model(self) -> None:
         from optimum.intel import OVModelForCausalLM
         self.model = OVModelForCausalLM.from_pretrained(self.model_config.model, export=True, compile=False, load_in_8bit=False)  # need stateful because it also enables SDPA
         patch_stateful_model(self.model.model)
-        ov.serialize(self.model.model, 'vllm_openvino_model.xml')
+        #ov.serialize(self.model.model, 'vllm_openvino_model.xml')
         core = ov.Core()
         ov_compiled = core.compile_model(self.model.model, "CPU")
-        ov_request = ov_compiled.create_infer_request()
+        self.model.ov_request = ov_compiled.create_infer_request()

         from functools import partial
-        def wrapper(*args, **kwargs):
-            #print('OV FORWARD WRAPPER')
-            #print(f'model class: {type(args[0])}')
-            #for i, input in enumerate(args[1:]):
-            #    print(f'[{i}]: {type(input)}')
-            #for key, value in kwargs.items():
-            #    print(f'{key}: {type(value)}')
-            #result = args[0]._openvino_patch_orig_forward(*args[1:], **kwargs)
-            input_metadata = kwargs['input_metadata']
-            #print(dir(input_metadata))
-            #print(input_metadata.is_prompt, input_metadata.slot_mapping, input_metadata.max_context_len, input_metadata.context_lens, input_metadata.block_tables)
-            def prepare_data(t):
-                t = np.array(t, copy=False)
-                #print(t.__array_interface__['data'][0])
-                assert t.flags["C_CONTIGUOUS"]
-                return t
-            flatten_kv_cache = flattenize_inputs(kwargs['kv_caches'])
-            #total_size = sum([torch.numel(t) for t in flatten_kv_cache])
-            #print(f'kv-cache total size: {total_size}')
-            flatten_kv_cache = [prepare_data(t) for t in flatten_kv_cache]
-            inputs = [
-                kwargs['input_ids'],
-                kwargs['positions'],
-                *flatten_kv_cache,
-                input_metadata.is_prompt, input_metadata.slot_mapping
-            ]
-            #print('slot_mapping:', input_metadata.slot_mapping)
-            if input_metadata.max_context_len is not None:
-                # available from the second iteration
-                inputs.append(input_metadata.max_context_len)
-                inputs.append(input_metadata.context_lens)
-                inputs.append(input_metadata.block_tables)
-            else:
-                inputs.append(np.array(0, dtype=np.int32))  # for optimum-based models this parameter can be used even on the first iteration
-            #for input in inputs:
-            #    print(f'{input.dtype} with shape {input.shape}' if isinstance(input, torch.Tensor) else type(input))
-            result = ov_request.infer(inputs, share_inputs=True, share_outputs=False)
-            #print(f'result: {type(result)}')
-            return torch.from_numpy(result[0])
         self.model._openvino_patch_orig_forward = self.model.forward
-        self.model.forward = partial(wrapper, self.model)
+        self.model.forward = partial(ov_wrapper, self.model)

         # self.vllm_model = get_model(self.model_config)
         # def sample_wrapper(*args, **kwargs):
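
For context, the optimum-intel entry point used in load_model exports the checkpoint to OpenVINO during loading, and compile=False defers compilation so patch_stateful_model can rewrite the graph first. A minimal sketch of that load path (the model id below is a placeholder; the real code passes self.model_config.model):

from optimum.intel import OVModelForCausalLM

model = OVModelForCausalLM.from_pretrained(
    "gpt2",              # placeholder; vLLM passes self.model_config.model here
    export=True,         # convert the PyTorch checkpoint to OpenVINO IR on the fly
    compile=False,       # defer compilation so the graph can still be patched
    load_in_8bit=False,  # keep full-precision weights
)
ov_model = model.model   # the underlying ov.Model, as used by patch_stateful_model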
