@@ -43,25 +43,43 @@ def __init__(
43
43
self .text_emb_model = text_embeds_model
44
44
self .request = None
45
45
self .text_emb_request = None
46
+ self ._device = device
46
47
compile_only = kwargs .get ("compile_only" , False )
47
48
if compile_only :
48
49
self .text_emb_request = self .text_emb_model
49
50
self .request = self .model .create_infer_request ()
50
51
51
52
super ().__init__ (
52
- model , config , device , dynamic_shapes , ov_config , model_save_dir , quantization_config , ** kwargs
53
+ model = model ,
54
+ config = config ,
55
+ device = device ,
56
+ dynamic_shapes = dynamic_shapes ,
57
+ ov_config = ov_config ,
58
+ model_save_dir = model_save_dir ,
59
+ quantization_config = quantization_config ,
60
+ ** kwargs ,
53
61
)
54
62
55
63
def compile (self ):
56
64
if self .request is None :
57
- logger .info (f"Compiling the Language model to { self ._device } ..." )
58
- self .request = core .compile_model (self .model , self ._device , self .ov_config ).create_infer_request ()
65
+ if self ._compile_only :
66
+ self .request = self .model .create_infer_request ()
67
+ else :
68
+ logger .info (f"Compiling the Language model to { self ._device } ..." )
69
+ self .request = self ._compile_model (
70
+ self .model , self ._device , self .ov_config , self .model_save_dir
71
+ ).create_infer_request ()
59
72
self ._compile_text_emb ()
60
73
61
74
def _compile_text_emb (self ):
62
75
if self .text_emb_request is None :
63
- logger .info (f"Compiling the Text embeddings model to { self ._device } ..." )
64
- self .text_emb_request = core .compile_model (self .text_emb_model , self ._device , self .ov_config )
76
+ if self ._compile_only :
77
+ self .text_emb_request = self .text_emb_model
78
+ else :
79
+ logger .info (f"Compiling the Text embeddings model to { self ._device } ..." )
80
+ self .text_emb_request = self ._compile_model (
81
+ self .text_emb_model , self ._device , self .ov_config , self .model_save_dir
82
+ )
65
83
66
84
def clear_requests (self ):
67
85
if self ._compile_only :
@@ -258,14 +276,14 @@ def __init__(
258
276
self ._openvino_config = OVConfig (quantization_config = quantization_config )
259
277
self ._set_ov_config_parameters ()
260
278
self .language_model = OVModelWithEmbedForCausalLM (
261
- self .lm_model ,
262
- self .text_embdings_model ,
279
+ model = self .lm_model ,
280
+ text_embeds_model = self .text_embdings_model ,
263
281
config = config ,
264
- deivce = device ,
282
+ device = self . _device ,
265
283
ov_config = ov_config ,
266
284
model_save_dir = model_save_dir ,
267
285
quantization_config = quantization_config ,
268
- compile = not self ._compile_only ,
286
+ compile = self ._compile_only ,
269
287
compile_only = self ._compile_only ,
270
288
)
271
289
self .vision_embeddings = OVVisionEmbedding (self .vision_embeddings_model , self )
0 commit comments