
Commit 6bf5fbc

Expose InferRequestWrapper class so it can be imported from elsewhere (#533)

* Expose InferRequestWrapper class so it can be imported from elsewhere
* Fix

1 parent 87b36db · commit 6bf5fbc

File tree

1 file changed: +39 −35 lines changed


optimum/intel/openvino/quantization.py

@@ -77,6 +77,44 @@ def batch_size(self):
         return batch_size
 
 
+class InferRequestWrapper:
+    def __init__(self, request, data_cache=None):
+        self.request = request
+        if data_cache is None:
+            data_cache = []
+        self.data_cache = data_cache
+
+    def __call__(self, *args, **kwargs):
+        self.data_cache.append(*args)
+        return self.request(*args, **kwargs)
+
+    def infer(self, inputs: Any = None, share_inputs: bool = False):
+        self.data_cache.append(inputs)
+        return self.request.infer(inputs, share_inputs)
+
+    def start_async(
+        self,
+        inputs: Any = None,
+        userdata: Any = None,
+        share_inputs: bool = False,
+        *,
+        shared_memory: Any = None,
+    ):
+        self.data_cache.append(inputs)
+        self.request.infer(inputs, share_inputs, share_outputs=True)
+
+    def wait(self):
+        pass
+
+    def get_tensor(self, name: str):
+        return Tensor(self.request.results[name])
+
+    def __getattr__(self, attr):
+        if attr in self.__dict__:
+            return getattr(self, attr)
+        return getattr(self.request, attr)
+
+
 class OVQuantizer(OptimumQuantizer):
     """
     Handle the NNCF quantization process.
@@ -297,41 +335,7 @@ def _quantize_ovcausallm(
         subset_size = kwargs.get("subset_size", 300)
         data_cache = []
 
-        class InferRequestWrapper:
-            def __init__(self, request):
-                self.request = request
-
-            def __call__(self, *args, **kwargs):
-                data_cache.append(*args)
-                return self.request(*args, **kwargs)
-
-            def infer(self, inputs: Any = None, share_inputs: bool = False):
-                data_cache.append(inputs)
-                return self.request.infer(inputs, share_inputs)
-
-            def start_async(
-                self,
-                inputs: Any = None,
-                userdata: Any = None,
-                share_inputs: bool = False,
-                *,
-                shared_memory: Any = None,
-            ):
-                data_cache.append(inputs)
-                self.request.infer(inputs, share_inputs, share_outputs=True)
-
-            def wait(self):
-                pass
-
-            def get_tensor(self, name: str):
-                return Tensor(self.request.results[name])
-
-            def __getattr__(self, attr):
-                if attr in self.__dict__:
-                    return getattr(self, attr)
-                return getattr(self.request, attr)
-
-        self.model.request = InferRequestWrapper(self.model.request)
+        self.model.request = InferRequestWrapper(self.model.request, data_cache)
         for _, data in enumerate(calibration_dataloader):
             self.model.generate(**data, max_new_tokens=1)
             if len(data_cache) >= subset_size:
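
With this change, `InferRequestWrapper` lives at module level in `optimum/intel/openvino/quantization.py` instead of being a closure-local class inside `_quantize_ovcausallm`, so it can be imported and reused elsewhere. A minimal sketch of what that enables, assuming an exported `OVModelForCausalLM` whose compiled `request` attribute the wrapper replaces (the model ID and the commented-out calibration loop are illustrative, not part of this commit):

```python
from optimum.intel import OVModelForCausalLM
from optimum.intel.openvino.quantization import InferRequestWrapper

# Illustrative model ID; any causal LM exported to OpenVINO would do.
model = OVModelForCausalLM.from_pretrained("gpt2", export=True)

# Pass an external list so the captured inputs stay visible to the caller;
# with data_cache=None the wrapper would create and keep its own list.
data_cache = []
model.request = InferRequestWrapper(model.request, data_cache)

# Each generate() call now records the inputs of every infer()/start_async()
# issued under the hood, as _quantize_ovcausallm does for calibration, e.g.:
# model.generate(**tokenized_batch, max_new_tokens=1)

print(f"captured {len(data_cache)} calibration samples")
```

Accepting the cache as a constructor argument, rather than closing over a local `data_cache` list as the old nested class did, is what makes the class usable at module scope.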
