@@ -77,6 +77,44 @@ def batch_size(self):
         return batch_size
 
 
+class InferRequestWrapper:
+    def __init__(self, request, data_cache=None):
+        self.request = request
+        if data_cache is None:
+            data_cache = []
+        self.data_cache = data_cache
+
+    def __call__(self, *args, **kwargs):
+        self.data_cache.append(*args)
+        return self.request(*args, **kwargs)
+
+    def infer(self, inputs: Any = None, share_inputs: bool = False):
+        self.data_cache.append(inputs)
+        return self.request.infer(inputs, share_inputs)
+
+    def start_async(
+        self,
+        inputs: Any = None,
+        userdata: Any = None,
+        share_inputs: bool = False,
+        *,
+        shared_memory: Any = None,
+    ):
+        self.data_cache.append(inputs)
+        self.request.infer(inputs, share_inputs, share_outputs=True)
+
+    def wait(self):
+        pass
+
+    def get_tensor(self, name: str):
+        return Tensor(self.request.results[name])
+
+    def __getattr__(self, attr):
+        if attr in self.__dict__:
+            return getattr(self, attr)
+        return getattr(self.request, attr)
+
+
 class OVQuantizer(OptimumQuantizer):
     """
     Handle the NNCF quantization process.
@@ -297,41 +335,7 @@ def _quantize_ovcausallm(
         subset_size = kwargs.get("subset_size", 300)
         data_cache = []
 
-        class InferRequestWrapper:
-            def __init__(self, request):
-                self.request = request
-
-            def __call__(self, *args, **kwargs):
-                data_cache.append(*args)
-                return self.request(*args, **kwargs)
-
-            def infer(self, inputs: Any = None, share_inputs: bool = False):
-                data_cache.append(inputs)
-                return self.request.infer(inputs, share_inputs)
-
-            def start_async(
-                self,
-                inputs: Any = None,
-                userdata: Any = None,
-                share_inputs: bool = False,
-                *,
-                shared_memory: Any = None,
-            ):
-                data_cache.append(inputs)
-                self.request.infer(inputs, share_inputs, share_outputs=True)
-
-            def wait(self):
-                pass
-
-            def get_tensor(self, name: str):
-                return Tensor(self.request.results[name])
-
-            def __getattr__(self, attr):
-                if attr in self.__dict__:
-                    return getattr(self, attr)
-                return getattr(self.request, attr)
-
-        self.model.request = InferRequestWrapper(self.model.request)
+        self.model.request = InferRequestWrapper(self.model.request, data_cache)
 
         for _, data in enumerate(calibration_dataloader):
             self.model.generate(**data, max_new_tokens=1)
             if len(data_cache) >= subset_size:
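
A minimal usage sketch of the relocated wrapper follows, mirroring the _quantize_ovcausallm flow in the second hunk. The names model (an OVModelForCausalLM-style object exposing its OpenVINO infer request as model.request) and calibration_dataloader are assumptions, as is the final unwrapping step; neither is part of this diff.

    # Hedged sketch: collect calibration inputs through the wrapper,
    # mirroring the _quantize_ovcausallm flow above. `model` and
    # `calibration_dataloader` are assumptions, not part of the diff.
    subset_size = 300
    data_cache = []

    # Every infer()/start_async() call on the wrapper records its inputs
    # into data_cache before delegating to the real infer request.
    model.request = InferRequestWrapper(model.request, data_cache)

    for data in calibration_dataloader:
        model.generate(**data, max_new_tokens=1)
        if len(data_cache) >= subset_size:
            break

    # Assumed unwrapping step: the wrapper keeps the original request on
    # `.request`, so this restores direct inference once enough samples
    # have been cached.
    model.request = model.request.request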