Skip to content

Commit 1bfed9e

Browse files
Tyler Titsworthpre-commit-ci[bot]
Tyler Titsworth
andauthored
Intel Extension for OpenXLA Containers (#385)
Signed-off-by: tylertitsworth <tyler.titsworth@intel.com> Signed-off-by: Tyler Titsworth <tyler.titsworth@intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2f85bd1 commit 1bfed9e

12 files changed

+361
-1
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ logs/
1414
models-perf/
1515
output/
1616
site
17+
test-runner-summary-output.json
1718
venv/

docs/roadmap.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
- Granite Rapids Support
1212
- CLS Support
13-
- Intel Developer Cloud Support
13+
- Intel Tiber Developer Cloud Support
1414
- AI Tools 2024.3/2025.0 Support
1515

1616
## Q4'24

docs/scripts/hook.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def create_support_matrix():
3434
compose_to_csv("pytorch", "serving")
3535
compose_to_csv("tensorflow", None)
3636
compose_to_csv("classical-ml", None)
37+
compose_to_csv("jax", None)
3738

3839
# get_repo(models)
3940
compose_to_csv("preset/data-analytics", "data_analytics")

docs/scripts/readmes.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
readmes = [
1919
"classical-ml/README.md",
20+
"jax/README.md",
2021
"preset/README.md",
2122
"python/README.md",
2223
"pytorch/README.md",

jax/.actions.json

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"PACKAGE_OPTION": ["idp", "pip"],
3+
"experimental": [true],
4+
"runner_label": ["PVC"]
5+
}

jax/Dockerfile

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
#
16+
#
17+
# This file was assembled from multiple pieces, whose use is documented
18+
# throughout. Please refer to the TensorFlow dockerfiles documentation
19+
# for more information.
20+
21+
ARG REGISTRY
22+
ARG REPO
23+
ARG GITHUB_RUN_NUMBER
24+
ARG BASE_IMAGE_NAME
25+
ARG BASE_IMAGE_TAG
26+
ARG PACKAGE_OPTION=pip
27+
ARG PYTHON_VERSION
28+
ARG PYTHON_BASE=${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER}-${BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${PACKAGE_OPTION}-py${PYTHON_VERSION}-base
29+
ARG TORCHSERVE_BASE=${PYTHON_BASE}
30+
FROM ${PYTHON_BASE} AS xpu-base
31+
32+
RUN apt-get update && \
33+
apt-get install -y --no-install-recommends --fix-missing \
34+
apt-utils \
35+
build-essential \
36+
clinfo \
37+
git \
38+
gnupg2 \
39+
gpg-agent \
40+
rsync \
41+
unzip && \
42+
apt-get clean && \
43+
rm -rf /var/lib/apt/lists/*
44+
45+
RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \
46+
gpg --dearmor --yes --output /usr/share/keyrings/intel-graphics.gpg
47+
RUN echo "deb [arch=amd64 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu jammy unified" | \
48+
tee /etc/apt/sources.list.d/intel-gpu-jammy.list
49+
50+
ARG ICD_VER
51+
ARG LEVEL_ZERO_GPU_VER
52+
ARG LEVEL_ZERO_VER
53+
ARG LEVEL_ZERO_DEV_VER
54+
55+
RUN apt-get update && \
56+
apt-get install -y --no-install-recommends --fix-missing \
57+
intel-opencl-icd=${ICD_VER} \
58+
intel-level-zero-gpu=${LEVEL_ZERO_GPU_VER} \
59+
libze1=${LEVEL_ZERO_VER} \
60+
libze-dev=${LEVEL_ZERO_DEV_VER} && \
61+
rm -rf /var/lib/apt/lists/*
62+
63+
RUN no_proxy="" NO_PROXY="" wget --progress=dot:giga -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
64+
| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
65+
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" \
66+
| tee /etc/apt/sources.list.d/oneAPI.list
67+
68+
ARG DPCPP_VER
69+
ARG MKL_VER
70+
ARG CCL_VER
71+
72+
RUN apt-get update && \
73+
apt-get install -y --no-install-recommends --fix-missing \
74+
intel-oneapi-runtime-dpcpp-cpp=${DPCPP_VER} \
75+
intel-oneapi-runtime-mkl=${MKL_VER} \
76+
intel-oneapi-runtime-ccl=${CCL_VER} && \
77+
rm -rf /var/lib/apt/lists/*
78+
79+
RUN rm -rf /etc/apt/sources.list.d/intel-gpu-jammy.list /etc/apt/sources.list.d/oneAPI.list
80+
81+
ENV OCL_ICD_VENDORS=/etc/OpenCL/vendors
82+
83+
FROM xpu-base AS jax-base
84+
85+
WORKDIR /
86+
COPY requirements.txt .
87+
88+
RUN python -m pip install --no-cache-dir -r requirements.txt && \
89+
rm -rf requirements.txt
90+
91+
FROM jax-base AS jupyter
92+
93+
WORKDIR /jupyter
94+
COPY jupyter-requirements.txt .
95+
96+
RUN python -m pip install --no-cache-dir -r jupyter-requirements.txt && \
97+
rm -rf jupyter-requirements.txt
98+
99+
RUN mkdir -p /jupyter/ && chmod -R a+rwx /jupyter/
100+
RUN mkdir /.local && chmod a+rwx /.local
101+
102+
EXPOSE 8888
103+
104+
CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/jupyter --port 8888 --ip 0.0.0.0 --no-browser --allow-root --ServerApp.token= --ServerApp.password= --ServerApp.allow_origin=* --ServerApp.base_url=$NB_PREFIX"]

jax/README.md

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Intel® Optimized OpenXLA\*
2+
3+
Transformable numerical computing at scale combined with [Intel® Extension for OpenXLA\*], which includes a PJRT plugin implementation to seamlessly runs [JAX\*] models on Intel GPUs.
4+
5+
## Images
6+
7+
The images below include [JAX\*] and [Intel® Extension for OpenXLA\*].
8+
9+
| Tag(s) | [JAX\*] | [Intel® Extension for OpenXLA\*] | [Flax] | Dockerfile |
10+
| -------------------------- | --------- | ----------------- | -------- | --------------- |
11+
| `0.4.0-pip-base`, `latest` | [v0.4.32] | [v0.4.0-jax] | [v0.9.0] | [v0.4.0] |
12+
13+
The images below additionally include [Jupyter Notebook](https://jupyter.org/) server:
14+
15+
| Tag(s) | [JAX\*] | [Intel® Extension for OpenXLA\*] | [Flax] | Dockerfile |
16+
| ------------------- | --------- | ----------------- | -------- | --------------- |
17+
| `0.4.0-pip-jupyter` | [v0.4.32] | [v0.4.0-jax] | [v0.9.0] | [v0.4.0] |
18+
19+
### Run the Jupyter Container
20+
21+
```bash
22+
docker run -it --rm \
23+
-p 8888:8888 \
24+
--net=host \
25+
-v $PWD/workspace:/workspace \
26+
-w /workspace \
27+
intel/intel-optimized-xla:0.4.0-pip-jupyter
28+
```
29+
30+
After running the command above, copy the URL (something like `http://127.0.0.1:$PORT/?token=***`) into your browser to access the notebook server.
31+
32+
## Images with Intel® Distribution for Python*
33+
34+
The images below include [Intel® Distribution for Python*]:
35+
36+
| Tag(s) | [JAX\*] | [Intel® Extension for OpenXLA\*] | [Flax] | Dockerfile |
37+
| ---------------- | --------- | ----------------- | -------- | --------------- |
38+
| `0.4.0-idp-base` | [v0.4.32] | [v0.4.0-jax] | [v0.9.0] | [v0.4.0] |
39+
40+
The images below additionally include [Jupyter Notebook](https://jupyter.org/) server:
41+
42+
| Tag(s) | [JAX\*] | [Intel® Extension for OpenXLA\*] | [Flax] | Dockerfile |
43+
| ------------------- | --------- | ----------------- | -------- | --------------- |
44+
| `0.4.0-idp-jupyter` | [v0.4.32] | [v0.4.0-jax] | [v0.9.0] | [v0.4.0] |
45+
46+
## Build from Source
47+
48+
To build the images from source, clone the [AI Containers](https://github.com/intel/ai-containers) repository, follow the main `README.md` file to setup your environment, and run the following command:
49+
50+
```bash
51+
cd jax
52+
docker compose build jax-base
53+
docker compose run -it jax-base
54+
```
55+
56+
You can find the list of services below for each container in the group:
57+
58+
| Service Name | Description |
59+
| ------------ | ----------------------------------------------- |
60+
| `jax-base` | Base image with [Intel® Extension for OpenXLA\*] |
61+
| `jupyter` | Adds Jupyter Notebook server |
62+
63+
## License
64+
65+
View the [License](https://github.com/intel/ai-containers/blob/main/LICENSE) for the [Intel® Distribution for Python].
66+
67+
The images below also contain other software which may be under other licenses (such as Pytorch*, Jupyter*, Bash, etc. from the base).
68+
69+
It is the image user's responsibility to ensure that any use of The images below comply with any relevant licenses for all software contained within.
70+
71+
\* Other names and brands may be claimed as the property of others.
72+
73+
<!--Below are links used in these document. They are not rendered: -->
74+
75+
[Intel® Distribution for Python*]: https://www.intel.com/content/www/us/en/developer/tools/oneapi/distribution-for-python.html#gs.9bos9m
76+
[Intel® Extension for OpenXLA\*]: https://github.com/intel/intel-extension-for-openxla
77+
[JAX\*]: https://github.com/google/jax
78+
[Flax]: https://github.com/google/flax
79+
80+
[v0.4.32]: https://github.com/google/jax/releases/tag/jax-v0.4.32
81+
82+
[v0.4.0-jax]: https://github.com/intel/intel-extension-for-openxla/releases/tag/0.4.0
83+
84+
[v0.9.0]: https://github.com/google/Flax/releases/tag/v0.9.0
85+
86+
[v0.4.0]: https://github.com/intel/ai-containers/blob/v0.4.0/jax/Dockerfile

jax/docker-compose.yaml

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright (c) 2024 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
include:
16+
- path:
17+
- ../python/docker-compose.yaml
18+
services:
19+
jax-base:
20+
build:
21+
args:
22+
http_proxy: ${http_proxy}
23+
https_proxy: ${https_proxy}
24+
no_proxy: ""
25+
BASE_IMAGE_NAME: ${BASE_IMAGE_NAME:-ubuntu}
26+
BASE_IMAGE_TAG: ${BASE_IMAGE_TAG:-22.04}
27+
CCL_VER: ${CCL_VER:-2021.13.1-31}
28+
DPCPP_VER: ${DPCPP_VER:-2024.2.1-1079}
29+
GITHUB_RUN_NUMBER: ${GITHUB_RUN_NUMBER:-0}
30+
ICD_VER: ${ICD_VER:-24.22.29735.27-914~22.04}
31+
LEVEL_ZERO_DEV_VER: ${LEVEL_ZERO_DEV_VER:-1.17.6-914~22.04}
32+
LEVEL_ZERO_GPU_VER: ${LEVEL_ZERO_GPU_VER:-1.3.29735.27-914~22.04}
33+
LEVEL_ZERO_VER: ${LEVEL_ZERO_VER:-1.17.6-914~22.04}
34+
MINIFORGE_VERSION: ${MINIFORGE_VERSION:-Linux-x86_64}
35+
MKL_VER: ${MKL_VER:-2024.2.1-103}
36+
NO_PROXY: ''
37+
PACKAGE_OPTION: ${PACKAGE_OPTION:-pip}
38+
PYTHON_VERSION: ${PYTHON_VERSION:-3.10}
39+
REGISTRY: ${REGISTRY}
40+
REPO: ${REPO}
41+
context: .
42+
labels:
43+
dependency.python: ${PYTHON_VERSION:-3.10}
44+
dependency.apt.build-essential: true
45+
dependency.apt.clinfo: true
46+
dependency.apt.git: true
47+
dependency.apt.gnupg2: true
48+
dependency.apt.gpg-agent: true
49+
dependency.apt.intel-level-zero-gpu: ${LEVEL_ZERO_GPU_VER:-1.3.29735.27-914~22.04}
50+
dependency.apt.intel-oneapi-runtime-ccl: ${CCL_VER:-2021.13.1-31}
51+
dependency.apt.intel-oneapi-runtime-dpcpp-cpp: ${DPCPP_VER:-2024.2.1-1079}
52+
dependency.apt.intel-oneapi-runtime-mkl: ${MKL_VER:-2024.2.1-103}
53+
dependency.apt.intel-opencl-icd: ${ICD_VER:-23.43.27642.40-803~22.04}
54+
dependency.apt.level-zero: ${LEVEL_ZERO_VER:-1.17.6-914~22.04}
55+
dependency.apt.level-zero-dev: ${LEVEL_ZERO_DEV_VER:-1.17.6-914~22.04}
56+
dependency.apt.rsync: true
57+
dependency.apt.unzip: true
58+
dependency.idp.pip: false
59+
dependency.python.pip: requirements.txt
60+
docs: jax
61+
org.opencontainers.base.name: "intel/python:3.10-core"
62+
org.opencontainers.image.name: "intel/intel-optimized-xla"
63+
org.opencontainers.image.title: "Intel® Optimized XLA Base Image"
64+
org.opencontainers.image.version: ${INTEL_XLA_VERSION:-v0.4.0}-${PACKAGE_OPTION:-pip}-base
65+
target: jax-base
66+
command: >
67+
bash -c "python -c 'import jax; print(\"Jax Version:\", jax.__version__)'"
68+
depends_on:
69+
- ${PACKAGE_OPTION:-pip}
70+
image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-${PACKAGE_OPTION:-pip}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
71+
pull_policy: always
72+
jupyter:
73+
build:
74+
labels:
75+
dependency.python.pip: jupyter-requirements.txt
76+
org.opencontainers.base.name: "intel/intel-optimized-xla:${INTEL_XLA_VERSION:-v0.4.0}-base"
77+
org.opencontainers.image.title: "Intel® Optimized XLA Jupyter Base Image"
78+
org.opencontainers.image.version: ${INTEL_XLA_VERSION:-v0.4.0}-jupyter
79+
target: jupyter
80+
command: >
81+
bash -c "python -m jupyter --version"
82+
environment:
83+
http_proxy: ${http_proxy}
84+
https_proxy: ${https_proxy}
85+
extends: jax-base
86+
image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-${PACKAGE_OPTION:-pip}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-jupyter
87+
network_mode: host

jax/jupyter-requirements.txt

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
jupyterlab==4.2.4
2+
jupyterhub==5.1.0
3+
notebook==7.2.1
4+
jupyter-server-proxy>=4.1.2

jax/requirements.txt

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
flax==0.8.2
2+
intel-extension-for-openxla==0.4.0
3+
jax==0.4.26
4+
jaxlib==0.4.26
5+
cython==3.0.11

jax/tests/example.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (c) 2024 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pylint: skip-file
16+
17+
import jax
18+
import jax.numpy as jnp
19+
20+
print("jax.local_devices(): ", jax.local_devices())
21+
22+
23+
@jax.jit
24+
def lax_conv():
25+
key = jax.random.PRNGKey(0)
26+
lhs = jax.random.uniform(key, (2, 1, 9, 9), jnp.float32)
27+
rhs = jax.random.uniform(key, (1, 1, 4, 4), jnp.float32)
28+
side = jax.random.uniform(key, (1, 1, 1, 1), jnp.float32)
29+
out = jax.lax.conv_with_general_padding(
30+
lhs, rhs, (1, 1), ((0, 0), (0, 0)), (1, 1), (1, 1)
31+
)
32+
out = jax.nn.relu(out)
33+
out = jnp.multiply(out, side)
34+
return out
35+
36+
37+
print(lax_conv())

jax/tests/tests.yaml

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) 2024 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
---
16+
jax-import-${PACKAGE_OPTION:-pip}:
17+
img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-${PACKAGE_OPTION:-pip}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
18+
cmd: python -c 'import jax; print("Jax Version:", jax.__version__); print(jax.devices())'
19+
device: ["/dev/dri"]
20+
jax-import-jupyter-${PACKAGE_OPTION:-pip}:
21+
img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-${PACKAGE_OPTION:-pip}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-jupyter
22+
cmd: sh -c "python -m jupyter --version"
23+
jax-xpu-example-${PACKAGE_OPTION:-pip}:
24+
img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-${PACKAGE_OPTION:-pip}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
25+
cmd: python /tests/example.py
26+
device: ["/dev/dri"]
27+
volumes:
28+
- src: $PWD/jax/tests
29+
dst: /tests

0 commit comments

Comments
 (0)