
Commit d16ca55

Add IPEX inference_mode context manager to enable optimization on Intel platforms. (#125)
* Add IPEX dependency
* Initial support for IPEX inference mode.
* Added docker image
* Force float32 for now with kernel selection.
* Implement default fallback in case of Exception for optimized model.
* Move IPEX to optional dependency
* Simplify the usage of inference_mode by forcing usage of oneDNN
* Enable the use of AMP for bfloat16
* Added documentation.
* Style.
* Making sure we are not importing ipex if not available.
1 parent bcf3c33 commit d16ca55
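A minimal usage sketch of what this commit enables, inferred from the diff below; the checkpoint name is only an illustration, any CPU pipeline should behave the same way:

import torch
from transformers import pipeline
from optimum.intel import inference_mode

pipe = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")

# inference_mode patches the pipeline's underlying model with an IPEX-optimized
# copy, falling back to the original model if the optimized one raises.
with inference_mode(pipe, dtype=torch.float32) as optimized_pipe:
    print(optimized_pipe("This movie was surprisingly good!"))

Passing dtype=torch.bfloat16 additionally enables CPU AMP, but requires avx512_bf16 support (4th generation Xeon Scalable, Sapphire Rapids).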

File tree

5 files changed, +180 -0 lines changed

docker/Dockerfile.intel

+70
@@ -0,0 +1,70 @@
# syntax = docker/dockerfile:1
# based on https://github.com/pytorch/pytorch/blob/master/Dockerfile
#
# NOTE: To build this you will need a docker version >= 19.03 and DOCKER_BUILDKIT=1
#
# If you do not use buildkit you are not going to have a good time
#
# For reference:
# https://docs.docker.com/develop/develop-images/build_enhancements/

ARG BASE_IMAGE=ubuntu:22.04
FROM ${BASE_IMAGE} AS dev-base
RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \
    apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
    ca-certificates \
    git \
    curl \
    vim \
    build-essential \
    ccache \
    libgoogle-perftools-dev \
    numactl \
    cmake \
    libjpeg-dev \
    libpng-dev \
    pybind11-dev \
    && rm -rf /var/lib/apt/lists/*
RUN /usr/sbin/update-ccache-symlinks
RUN mkdir /opt/ccache && ccache --set-config=cache_dir=/opt/ccache
ENV PATH /opt/conda/bin:$PATH

FROM dev-base AS conda
ARG PYTHON_VERSION=3.10
RUN curl -fsSL -v -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda install -y python=${PYTHON_VERSION} conda-build pyyaml numpy ipython mkl mkl-include ninja cython typing pybind11 Pillow && \
    /opt/conda/bin/conda clean -ya
FROM dev-base AS build
ARG IPEX_VERSION=v1.13.0
ARG PYTORCH_VERSION=v1.13.0
# torchvision 0.14.0 is the release paired with torch 1.13.0
ARG TORCHVISION_VERSION=0.14.0+cpu
ARG TORCHAUDIO_VERSION=0.13.0+cpu
COPY --from=conda /opt/conda /opt/conda
RUN --mount=type=cache,target=/opt/ccache \
    python -m pip install --no-cache-dir torch==${PYTORCH_VERSION}+cpu torchvision==${TORCHVISION_VERSION} torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html && \
    git clone https://github.com/intel/intel-extension-for-pytorch && \
    cd intel-extension-for-pytorch && \
    git checkout ${IPEX_VERSION} && \
    git submodule sync && \
    git submodule update --init --recursive && \
    python -m pip install --no-cache-dir -r requirements.txt && \
    python setup.py bdist_wheel && \
    python -m pip install --no-cache-dir dist/*.whl && \
    cd .. && rm -rf intel-extension-for-pytorch

FROM dev-base AS dev
COPY --from=build /opt/conda /opt/conda
ARG OMP_NUM_THREADS=1
ENV OMP_NUM_THREADS ${OMP_NUM_THREADS}
ARG KMP_BLOCKTIME=1
ENV KMP_BLOCKTIME ${KMP_BLOCKTIME}
ARG KMP_HW_SUBSET=1T
ENV KMP_HW_SUBSET ${KMP_HW_SUBSET}
ENV LD_PRELOAD "/opt/conda/lib/libiomp5.so /usr/lib/x86_64-linux-gnu/libtcmalloc.so"
# Match the Python version installed in the conda stage (3.10, not 3.8)
ENV LD_LIBRARY_PATH "/opt/conda/lib/python3.10/site-packages/lib/"
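A hedged sanity check (not part of the commit): inside the built container, both torch and IPEX should import cleanly and report matching versions.

import torch
import intel_extension_for_pytorch as ipex

# Expect 1.13.0-era versions for both if the build stages completed correctly.
print("torch:", torch.__version__)
print("ipex:", ipex.__version__)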

optimum/intel/__init__.py

+1
@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .ipex import inference_mode
from .version import __version__

optimum/intel/ipex/__init__.py

+1
@@ -0,0 +1 @@
from .inference import inference_mode

optimum/intel/ipex/inference.py

+107
@@ -0,0 +1,107 @@
from typing import Union

import torch
from torch import nn
from transformers import add_start_docstrings
from transformers.pipelines import Pipeline
from transformers.utils import is_ipex_available


IPEX_NOT_AVAILABLE_ERROR_MSG = (
    "Intel Extension for PyTorch (IPEX) was not found. "
    "Please make sure you have installed the package, or run "
    "`pip install intel_extension_for_pytorch`."
)

if is_ipex_available():
    import intel_extension_for_pytorch as ipex


class _ModelFallbackWrapper:

    __slots__ = ("_optimized", "_default")

    def __init__(self, optimized, default):
        self._optimized = optimized
        self._default = default

    def __call__(self, *args, **kwargs):
        try:
            return self._optimized(*args, **kwargs)
        except Exception:
            # Fall back to the original, non-optimized model if the optimized one fails.
            return self._default(*args, **kwargs)

    def __getattr__(self, item):
        if not item.startswith("__"):
            # Delegate regular attribute lookups to the wrapped default model.
            return getattr(self._default, item)
        else:
            raise AttributeError(item)


@add_start_docstrings(
    """
    inference_mode is an Intel-specific context manager, analogous to PyTorch's inference_mode, to use for inference
    workloads on Intel CPUs, especially Intel Xeon Scalable CPUs.
    """,
)
class inference_mode:
    __slots__ = ("_model", "_dtype", "_graph_mode", "_verbose", "_original")

    def __init__(self, model: Union[nn.Module, Pipeline], dtype: torch.dtype = torch.float32, verbose: bool = False):
        """
        Args:
            model (`torch.nn.Module` or `transformers.Pipeline`):
                The model or pipeline instance to optimize.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                The data type used to do the computation.
                Acceptable types are `torch.float32` (default) and `torch.bfloat16`.
                Please note `torch.bfloat16` requires the `avx512_bf16` instruction set, present on
                4th generation Intel Xeon Scalable CPUs (Sapphire Rapids).
            verbose (`bool`, *optional*, defaults to `False`):
                Enable IPEX verbose output to see the kernels and optimizations applied.
        """
        if not is_ipex_available():
            raise ImportError(IPEX_NOT_AVAILABLE_ERROR_MSG)

        self._model = model
        self._verbose = ipex.utils.verbose.VERBOSE_ON if verbose else ipex.utils.verbose.VERBOSE_OFF
        self._dtype = dtype
        self._graph_mode = False  # Let's keep it for future use when it doesn't hang anymore
        self._original = None

    def __enter__(self):
        with torch.inference_mode():
            with ipex.verbose(self._verbose):
                ipex.enable_onednn_fusion(True)
                if isinstance(self._model, Pipeline):
                    self._original = self._model.model

                    model = ipex.optimize(
                        self._model.model,
                        dtype=self._dtype,
                        graph_mode=self._graph_mode,
                        level="O1",
                        auto_kernel_selection=True,
                    )

                    # Enable automatic mixed precision (AMP) if we are going to target `bfloat16`
                    with torch.cpu.amp.autocast(enabled=(self._dtype == torch.bfloat16)):
                        # Patching model with the new one
                        self._model.model = _ModelFallbackWrapper(model, self._original)
                        return self._model
                else:
                    self._original = self._model
                    model = ipex.optimize(
                        self._model,
                        dtype=self._dtype,
                        graph_mode=self._graph_mode,
                        level="O1",
                        auto_kernel_selection=True,
                    )

                    # Enable automatic mixed precision (AMP) if we are going to target `bfloat16`
                    with torch.cpu.amp.autocast(enabled=(self._dtype == torch.bfloat16)):
                        return model

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._model = self._original
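For a plain nn.Module (the non-Pipeline branch above), the context manager returns the optimized module directly rather than a fallback wrapper. A minimal sketch, assuming IPEX is installed and using a toy model:

import torch
from torch import nn
from optimum.intel import inference_mode

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 2)).eval()

with inference_mode(model, dtype=torch.float32) as optimized:
    # `optimized` is the ipex.optimize()-processed module.
    out = optimized(torch.randn(1, 16))
print(out.shape)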

setup.py

+1
@@ -32,6 +32,7 @@
    "neural-compressor": "neural-compressor>=1.13.0",
    "openvino": ["openvino>=2022.2.0", "transformers>=4.20.0,<4.24.1"],
    "nncf": ["nncf"],
    "ipex": ["intel_extension_for_pytorch"],
    "quality": QUALITY_REQUIRES,
    "tests": TESTS_REQUIRE,
}
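With this extra in place, the IPEX dependency should be installable through the package's extras, presumably along the lines of `pip install optimum-intel[ipex]` (the exact published package name is an assumption, not stated in this diff).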
