From e7c3ce5229672f3921777fdb3d9016662de34305 Mon Sep 17 00:00:00 2001
From: tylertitsworth <tyler.titsworth@intel.com>
Date: Wed, 28 Aug 2024 15:27:53 -0700
Subject: [PATCH 01/12] init jax

Signed-off-by: tylertitsworth <tyler.titsworth@intel.com>
---
 jax/.actions.json            |   4 ++
 jax/Dockerfile               |  68 ++++++++++++++++++++
 jax/README.md                | 117 +++++++++++++++++++++++++++++++++++
 jax/docker-compose.yaml      |  53 ++++++++++++++++
 jax/jupyter-requirements.txt |   4 ++
 jax/requirements.txt         |   4 ++
 jax/tests/example.py         |  37 +++++++++++
 jax/tests/tests.yaml         |  28 +++++++++
 8 files changed, 315 insertions(+)
 create mode 100644 jax/.actions.json
 create mode 100644 jax/Dockerfile
 create mode 100644 jax/README.md
 create mode 100644 jax/docker-compose.yaml
 create mode 100644 jax/jupyter-requirements.txt
 create mode 100644 jax/requirements.txt
 create mode 100644 jax/tests/example.py
 create mode 100644 jax/tests/tests.yaml

diff --git a/jax/.actions.json b/jax/.actions.json
new file mode 100644
index 00000000..dee58d43
--- /dev/null
+++ b/jax/.actions.json
@@ -0,0 +1,4 @@
+{
+    "experimental": [true],
+    "runner_label": ["XEON"]
+}
diff --git a/jax/Dockerfile b/jax/Dockerfile
new file mode 100644
index 00000000..998dbb92
--- /dev/null
+++ b/jax/Dockerfile
@@ -0,0 +1,68 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+#
+# This file was assembled from multiple pieces, whose use is documented
+# throughout. Please refer to the TensorFlow dockerfiles documentation
+# for more information.
+
+ARG BASE_IMAGE_NAME=${BASE_IMAGE_NAME}
+ARG BASE_IMAGE_TAG=${BASE_IMAGE_TAG}
+FROM ${BASE_IMAGE_NAME}:${BASE_IMAGE_TAG} AS xpu-base
+
+RUN wget -q -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
+    echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
+    chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
+    rm /etc/apt/sources.list.d/intel-graphics.list && \
+    wget -q -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \
+    echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \
+    chmod 644 /usr/share/keyrings/intel-graphics.gpg
+
+RUN apt-get update -y && \
+    apt-get install -y --no-install-recommends --fix-missing \
+    git \
+    # libsndfile1 \
+    # lsb-release \
+    numactl \
+    python3 \
+    python3-dev \
+    python3-pip
+
+RUN ln -sf "$(which python3)" /usr/local/bin/python && \
+    ln -sf "$(which python3)" /usr/bin/python
+
+FROM xpu-base AS jax-base
+
+WORKDIR /
+COPY requirements.txt .
+
+RUN python -m pip install --no-cache-dir \
+    --ignore-installed -r requirements.txt && \
+    rm -rf requirements.txt
+
+FROM jax-base AS jupyter
+
+WORKDIR /jupyter
+COPY jupyter-requirements.txt .
+
+RUN python -m pip install --no-cache-dir -r jupyter-requirements.txt && \
+    rm -rf jupyter-requirements.txt
+
+RUN mkdir -p /jupyter/ && chmod -R a+rwx /jupyter/
+RUN mkdir /.local && chmod a+rwx /.local
+
+EXPOSE 8888
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/jupyter --port 8888 --ip 0.0.0.0 --no-browser --allow-root --ServerApp.token= --ServerApp.password= --ServerApp.allow_origin=* --ServerApp.base_url=$NB_PREFIX"]
diff --git a/jax/README.md b/jax/README.md
new file mode 100644
index 00000000..97cc50c2
--- /dev/null
+++ b/jax/README.md
@@ -0,0 +1,117 @@
+# Intel® Optimized ML
+
+[Intel® Extension for Scikit-learn*] enhances the performance of [Scikit-learn*] by accelerating the training and inference of machine learning models on Intel® hardware.
+
+[XGBoost*] is an optimized distributed gradient boosting library designed to be highly efficient, flexible and portable.
+
+## Images
+
+The images below include [Intel® Extension for Scikit-learn*] and [XGBoost*].
+
+| Tag(s)                                            | Intel SKLearn  | Scikit-learn | XGBoost  | Dockerfile      |
+| ------------------------------------------------- | -------------- | ------------ | -------- | --------------- |
+| `2024.6.0-pip-base`, `latest`                     | [v2024.6.0]    | [v1.5.0]     | [v2.1.0] | [v0.4.0]        |
+| `2024.5.0-pip-base`                               | [v2024.5.0]    | [v1.5.0]     | [v2.1.0] | [v0.4.0]        |
+| `2024.3.0-pip-base`                               | [v2024.3.0]    | [v1.4.2]     | [v2.0.3] | [v0.4.0-Beta]   |
+| `2024.2.0-xgboost-2.0.3-pip-base`                 | [v2024.2.0]    | [v1.4.1]     | [v2.0.3] | [v0.4.0-Beta]   |
+| `scikit-learning-2024.0.0-xgboost-2.0.2-pip-base` | [v2024.0.0]    | [v1.3.2]     | [v2.0.2] | [v0.3.4]        |
+
+The images below additionally include [Jupyter Notebook](https://jupyter.org/) server:
+
+| Tag(s)                                               | Intel SKLearn  | Scikit-learn | XGBoost  | Dockerfile      |
+| ---------------------------------------------------- | -------------- | ------------ | -------- | --------------- |
+| `2024.6.0-pip-jupyter`                               | [v2024.6.0]    | [v1.5.1]     | [v2.1.1] | [v0.4.0]        |
+| `2024.5.0-pip-jupyter`                               | [v2024.5.0]    | [v1.5.0]     | [v2.1.0] | [v0.4.0]        |
+| `2024.3.0-pip-jupyter`                               | [v2024.3.0]    | [v1.4.2]     | [v2.0.3] | [v0.4.0-Beta]   |
+| `2024.2.0-xgboost-2.0.3-pip-jupyter`                 | [v2024.2.0]    | [v1.4.1]     | [v2.0.3] | [v0.4.0-Beta]   |
+| `scikit-learning-2024.0.0-xgboost-2.0.2-pip-jupyter` | [v2024.0.0]    | [v1.3.2]     | [v2.0.2] | [v0.3.4]        |
+
+### Run the Jupyter Container
+
+```bash
+docker run -it --rm \
+    -p 8888:8888 \
+    --net=host \
+    -v $PWD/workspace:/workspace \
+    -w /workspace \
+    intel/intel-optimized-ml:2024.2.0-xgboost-2.0.3-pip-jupyter
+```
+
+After running the command above, copy the URL (something like `http://127.0.0.1:$PORT/?token=***`) into your browser to access the notebook server.
+
+## Images with Intel® Distribution for Python*
+
+The images below include [Intel® Distribution for Python*]:
+
+| Tag(s)                                            | Intel SKLearn  | Scikit-learn | XGBoost  | Dockerfile      |
+| ------------------------------------------------- | -------------- | ------------ | -------- | --------------- |
+| `2024.6.0-idp-base`                               | [v2024.6.0]    | [v1.5.1]     | [v2.1.1] | [v0.4.0]        |
+| `2024.5.0-idp-base`                               | [v2024.5.0]    | [v1.5.0]     | [v2.1.0] | [v0.4.0]        |
+| `2024.3.0-idp-base`                               | [v2024.3.0]    | [v1.4.1]     | [v2.1.0] | [v0.4.0]        |
+| `2024.2.0-xgboost-2.0.3-idp-base`                 | [v2024.2.0]    | [v1.4.1]     | [v2.0.3] | [v0.4.0-Beta]   |
+| `scikit-learning-2024.0.0-xgboost-2.0.2-idp-base` | [v2024.0.0]    | [v1.3.2]     | [v2.0.2] | [v0.3.4]        |
+
+The images below additionally include [Jupyter Notebook](https://jupyter.org/) server:
+
+| Tag(s)                                               | Intel SKLearn  | Scikit-learn | XGBoost  | Dockerfile      |
+| ---------------------------------------------------- | -------------- | ------------ | -------- | --------------- |
+| `2024.6.0-idp-jupyter`                               | [v2024.6.0]    | [v1.5.1]     | [v2.1.1] | [v0.4.0]        |
+| `2024.5.0-idp-jupyter`                               | [v2024.5.0]    | [v1.5.0]     | [v2.1.0] | [v0.4.0]        |
+| `2024.3.0-idp-jupyter`                               | [v2024.3.0]    | [v1.4.0]     | [v2.1.0] | [v0.4.0]        |
+| `2024.2.0-xgboost-2.0.3-idp-jupyter`                 | [v2024.2.0]    | [v1.4.1]     | [v2.0.3] | [v0.4.0-Beta]   |
+| `scikit-learning-2024.0.0-xgboost-2.0.2-idp-jupyter` | [v2024.0.0]    | [v1.3.2]     | [v2.0.2] | [v0.3.4]        |
+
+## Build from Source
+
+To build the images from source, clone the [AI Containers](https://github.com/intel/ai-containers) repository, follow the main `README.md` file to setup your environment, and run the following command:
+
+```bash
+cd classical-ml
+docker compose build ml-base
+docker compose run ml-base
+```
+
+You can find the list of services below for each container in the group:
+
+| Service Name | Description                                                         |
+| ------------ | ------------------------------------------------------------------- |
+| `ml-base`    | Base image with [Intel® Extension for Scikit-learn*] and [XGBoost*] |
+| `jupyter`    | Adds Jupyter Notebook server                                        |
+
+## License
+
+View the [License](https://github.com/intel/ai-containers/blob/main/LICENSE) for the [Intel® Distribution for Python].
+
+The images below also contain other software which may be under other licenses (such as Pytorch*, Jupyter*, Bash, etc. from the base).
+
+It is the image user's responsibility to ensure that any use of The images below comply with any relevant licenses for all software contained within.
+
+\* Other names and brands may be claimed as the property of others.
+
+<!--Below are links used in these document. They are not rendered: -->
+
+[Intel® Extension for Scikit-learn*]: https://www.intel.com/content/www/us/en/developer/tools/oneapi/scikit-learn.html
+[Intel® Distribution for Python]: https://www.intel.com/content/www/us/en/developer/tools/oneapi/distribution-for-python.html#gs.9bos9m
+[Scikit-learn*]: https://scikit-learn.org/stable/
+[XGBoost*]: https://github.com/dmlc/xgboost
+
+[v2024.6.0]: https://github.com/intel/scikit-learn-intelex/releases/tag/2024.6.0
+[v2024.5.0]: https://github.com/intel/scikit-learn-intelex/releases/tag/2024.5.0
+[v2024.3.0]: https://github.com/intel/scikit-learn-intelex/releases/tag/2024.3.0
+[v2024.2.0]: https://github.com/intel/scikit-learn-intelex/releases/tag/2024.2.0
+[v2024.0.0]: https://github.com/intel/scikit-learn-intelex/releases/tag/2024.0.0
+
+[v1.5.1]: https://github.com/scikit-learn/scikit-learn/releases/tag/1.5.1
+[v1.5.0]: https://github.com/scikit-learn/scikit-learn/releases/tag/1.5.0
+[v1.4.2]: https://github.com/scikit-learn/scikit-learn/releases/tag/1.4.2
+[v1.4.1]: https://github.com/scikit-learn/scikit-learn/releases/tag/1.4.1
+[v1.3.2]: https://github.com/scikit-learn/scikit-learn/releases/tag/1.3.2
+
+[v2.1.1]: https://github.com/dmlc/xgboost/releases/tag/v2.1.1
+[v2.1.0]: https://github.com/dmlc/xgboost/releases/tag/v2.1.0
+[v2.0.3]: https://github.com/dmlc/xgboost/releases/tag/v2.0.3
+[v2.0.2]: https://github.com/dmlc/xgboost/releases/tag/v2.0.2
+
+[v0.4.0]: https://github.com/intel/ai-containers/blob/v0.4.0/classical-ml/Dockerfile
+[v0.4.0-Beta]: https://github.com/intel/ai-containers/blob/v0.4.0-Beta/classical-ml/Dockerfile
+[v0.3.4]: https://github.com/intel/ai-containers/blob/v0.3.4/classical-ml/Dockerfile
diff --git a/jax/docker-compose.yaml b/jax/docker-compose.yaml
new file mode 100644
index 00000000..7544bf5d
--- /dev/null
+++ b/jax/docker-compose.yaml
@@ -0,0 +1,53 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+services:
+  jax-base:
+    build:
+      args:
+        http_proxy: ${http_proxy}
+        https_proxy: ${https_proxy}
+        no_proxy: ""
+        BASE_IMAGE_NAME: ${BASE_IMAGE_NAME:-intel/oneapi-basekit}
+        BASE_IMAGE_TAG: ${BASE_IMAGE_TAG:-2024.2.0-devel-ubuntu22.04}
+      context: .
+      labels:
+        dependency.python: ${PYTHON_VERSION:-3.10}
+        dependency.python.pip: requirements.txt
+        docs: jax
+        org.opencontainers.base.name: "intel/python:3.10-core"
+        org.opencontainers.image.name: "intel/intel-optimized-xla"
+        org.opencontainers.image.title: "Intel® Optimized XLA Base Image"
+        org.opencontainers.image.version: ${INTEL_XLA_VERSION:-v0.4.0}-base
+      target: jax-base
+    command: >
+      bash -c "python -c 'import jax; print(\"Jax Version:\", jax.__version__)'"
+    image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
+    pull_policy: always
+  jupyter:
+    build:
+      labels:
+        dependency.python.pip: jupyter-requirements.txt
+        org.opencontainers.base.name: "intel/intel-optimized-xla:${INTEL_XLA_VERSION:-v0.4.0}-base"
+        org.opencontainers.image.title: "Intel® Optimized XLA Jupyter Base Image"
+        org.opencontainers.image.version: ${INTEL_XLA_VERSION:-v0.4.0}-jupyter
+      target: jupyter
+    command: >
+      bash -c "python -m jupyter --version"
+    environment:
+      http_proxy: ${http_proxy}
+      https_proxy: ${https_proxy}
+    extends: jax-base
+    image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-jupyter
+    network_mode: host
diff --git a/jax/jupyter-requirements.txt b/jax/jupyter-requirements.txt
new file mode 100644
index 00000000..2cae0f91
--- /dev/null
+++ b/jax/jupyter-requirements.txt
@@ -0,0 +1,4 @@
+jupyterlab==4.2.4
+jupyterhub==5.1.0
+notebook==7.2.1
+jupyter-server-proxy>=4.1.2
diff --git a/jax/requirements.txt b/jax/requirements.txt
new file mode 100644
index 00000000..968728fd
--- /dev/null
+++ b/jax/requirements.txt
@@ -0,0 +1,4 @@
+flax==0.8.2
+intel-extension-for-openxla==0.4.0
+jax==0.4.26
+jaxlib==0.4.26
diff --git a/jax/tests/example.py b/jax/tests/example.py
new file mode 100644
index 00000000..9227d066
--- /dev/null
+++ b/jax/tests/example.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+
+import jax
+import jax.numpy as jnp
+
+print("jax.local_devices(): ", jax.local_devices())
+
+
+@jax.jit
+def lax_conv():
+    key = jax.random.PRNGKey(0)
+    lhs = jax.random.uniform(key, (2, 1, 9, 9), jnp.float32)
+    rhs = jax.random.uniform(key, (1, 1, 4, 4), jnp.float32)
+    side = jax.random.uniform(key, (1, 1, 1, 1), jnp.float32)
+    out = jax.lax.conv_with_general_padding(
+        lhs, rhs, (1, 1), ((0, 0), (0, 0)), (1, 1), (1, 1)
+    )
+    out = jax.nn.relu(out)
+    out = jnp.multiply(out, side)
+    return out
+
+
+print(lax_conv())
diff --git a/jax/tests/tests.yaml b/jax/tests/tests.yaml
new file mode 100644
index 00000000..65e620e6
--- /dev/null
+++ b/jax/tests/tests.yaml
@@ -0,0 +1,28 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+---
+jax-import:
+  img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
+  cmd: python -c 'import jax; print(\"Jax Version:\", jax.__version__)'
+jax-import-jupyter:
+  img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-jupyter
+  cmd: sh -c "python -m jupyter --version"
+jax-xpu-example:
+  img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
+  cmd: python /tests/example.py
+  device: ["/dev/dri"]
+  volumes:
+    - src: $PWD/jax/tests
+      dst: /tests

From 994c7d764b9b7eb9723990f459e7bdc946cb3f80 Mon Sep 17 00:00:00 2001
From: tylertitsworth <tyler.titsworth@intel.com>
Date: Wed, 28 Aug 2024 15:29:01 -0700
Subject: [PATCH 02/12] init jax

Signed-off-by: tylertitsworth <tyler.titsworth@intel.com>
---
 docs/scripts/hook.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/scripts/hook.py b/docs/scripts/hook.py
index 3b862bdf..2f0c96ec 100644
--- a/docs/scripts/hook.py
+++ b/docs/scripts/hook.py
@@ -34,6 +34,7 @@ def create_support_matrix():
     compose_to_csv("pytorch", "serving")
     compose_to_csv("tensorflow", None)
     compose_to_csv("classical-ml", None)
+    compose_to_csv("jax", None)
 
     # get_repo(models)
     compose_to_csv("preset/data-analytics", "data_analytics")

From 233117e38985c2e93c4ca3c060523e46f68a70d2 Mon Sep 17 00:00:00 2001
From: tylertitsworth <tyler.titsworth@intel.com>
Date: Wed, 28 Aug 2024 15:34:33 -0700
Subject: [PATCH 03/12] update tests

Signed-off-by: tylertitsworth <tyler.titsworth@intel.com>
---
 jax/.actions.json    | 2 +-
 jax/tests/tests.yaml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/jax/.actions.json b/jax/.actions.json
index dee58d43..45608421 100644
--- a/jax/.actions.json
+++ b/jax/.actions.json
@@ -1,4 +1,4 @@
 {
     "experimental": [true],
-    "runner_label": ["XEON"]
+    "runner_label": ["PVC"]
 }
diff --git a/jax/tests/tests.yaml b/jax/tests/tests.yaml
index 65e620e6..03afa463 100644
--- a/jax/tests/tests.yaml
+++ b/jax/tests/tests.yaml
@@ -15,7 +15,7 @@
 ---
 jax-import:
   img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
-  cmd: python -c 'import jax; print(\"Jax Version:\", jax.__version__)'
+  cmd: python -c 'import jax; print("Jax Version:", jax.__version__)'
 jax-import-jupyter:
   img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-jupyter
   cmd: sh -c "python -m jupyter --version"

From 58c274b0e930fd9de4d74922104864c58f2fa81f Mon Sep 17 00:00:00 2001
From: tylertitsworth <tyler.titsworth@intel.com>
Date: Thu, 12 Sep 2024 15:10:03 -0700
Subject: [PATCH 04/12] add multinode img

Signed-off-by: tylertitsworth <tyler.titsworth@intel.com>
---
 jax/Dockerfile                      | 55 +++++++++++++++++++++++++++++
 jax/docker-compose.yaml             | 18 ++++++++++
 jax/multinode/dockerd-entrypoint.sh | 21 +++++++++++
 jax/multinode/generate_ssh_keys.sh  | 28 +++++++++++++++
 jax/multinode/requirements.txt      |  3 ++
 jax/multinode/ssh_config            |  4 +++
 jax/multinode/sshd_config           | 12 +++++++
 jax/requirements.txt                |  1 +
 jax/tests/multinode-ex.py           | 37 +++++++++++++++++++
 jax/tests/tests.yaml                |  7 ++++
 10 files changed, 186 insertions(+)
 create mode 100755 jax/multinode/dockerd-entrypoint.sh
 create mode 100755 jax/multinode/generate_ssh_keys.sh
 create mode 100644 jax/multinode/requirements.txt
 create mode 100644 jax/multinode/ssh_config
 create mode 100644 jax/multinode/sshd_config
 create mode 100644 jax/tests/multinode-ex.py

diff --git a/jax/Dockerfile b/jax/Dockerfile
index 998dbb92..a7d8b92b 100644
--- a/jax/Dockerfile
+++ b/jax/Dockerfile
@@ -66,3 +66,58 @@ RUN mkdir /.local && chmod a+rwx /.local
 EXPOSE 8888
 
 CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/jupyter --port 8888 --ip 0.0.0.0 --no-browser --allow-root --ServerApp.token= --ServerApp.password= --ServerApp.allow_origin=* --ServerApp.base_url=$NB_PREFIX"]
+
+FROM jax-base AS jax-multinode
+
+RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
+    # python3-dev \
+    gcc \
+    # g++ \
+    # libgl1-mesa-glx \
+    # libglib2.0-0 \
+    libopenmpi-dev \
+    numactl \
+    virtualenv
+
+ENV SIGOPT_PROJECT=. \
+    MPI4JAX_USE_SYCL_MPI=1 \
+    MPI4PY_BUILD_BACKEND=scikit-build-core
+
+WORKDIR /
+COPY multinode/requirements.txt requirements.txt
+
+
+RUN python -m pip install --no-cache-dir -r requirements.txt && \
+    rm -rf requirements.txt
+
+ENV LD_LIBRARY_PATH="/lib/x86_64-linux-gnu:${LD_LIBRARY_PATH}"
+
+RUN apt-get install -y --no-install-recommends --fix-missing \
+    openssh-client \
+    openssh-server && \
+    rm /etc/ssh/ssh_host_*_key \
+    /etc/ssh/ssh_host_*_key.pub && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN mkdir -p /var/run/sshd
+
+ARG PYTHON_VERSION
+
+COPY multinode/generate_ssh_keys.sh /generate_ssh_keys.sh
+
+# modify generate_ssh_keys to be a helper script
+# print how to use helper script on bash startup
+# Avoids loop for further execution of the startup file
+ARG PACKAGE_OPTION=pip
+ARG PYPATH="/usr/local/lib/python${PYTHON_VERSION}/dist-packages"
+RUN cat '/generate_ssh_keys.sh' >> ~/.startup && \
+    rm -rf /generate_ssh_keys.sh
+
+COPY multinode/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh
+COPY multinode/sshd_config /etc/ssh/sshd_config
+COPY multinode/ssh_config /etc/ssh/ssh_config
+
+RUN mkdir -p /licensing
+
+ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"]
diff --git a/jax/docker-compose.yaml b/jax/docker-compose.yaml
index 7544bf5d..1cec9fab 100644
--- a/jax/docker-compose.yaml
+++ b/jax/docker-compose.yaml
@@ -51,3 +51,21 @@ services:
     extends: jax-base
     image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-jupyter
     network_mode: host
+  multinode:
+    build:
+      labels:
+        # dependency.apt.gcc: true
+        # dependency.apt.libgl1-mesa-glx: true
+        # dependency.apt.libglib2: true
+        dependency.apt.python3-dev: true
+        dependency.pip.apt.virtualenv: true
+        dependency.python.pip: multinode/requirements.txt
+        org.opencontainers.base.name: "intel/intel-optimized-xla:${INTEL_XLA_VERSION:-v0.4.0}-base"
+        org.opencontainers.image.title: "Intel® Optimized XLA MultiNode Image"
+        org.opencontainers.image.version: ${INTEL_XLA_VERSION:-v0.4.0}-multinode
+      target: jax-multinode
+    command: >
+      bash -c "mpirun --version && python -c 'from mpi4py import MPI; import mpi4jax'"
+    extends: jax-base
+    image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-inc-${INC_VERSION:-3.0}
+    shm_size: 2gb
diff --git a/jax/multinode/dockerd-entrypoint.sh b/jax/multinode/dockerd-entrypoint.sh
new file mode 100755
index 00000000..ba13c0f9
--- /dev/null
+++ b/jax/multinode/dockerd-entrypoint.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e
+set -a
+# shellcheck disable=SC1091
+source "$HOME/.startup"
+set +a
+"$@"
diff --git a/jax/multinode/generate_ssh_keys.sh b/jax/multinode/generate_ssh_keys.sh
new file mode 100755
index 00000000..0ee61398
--- /dev/null
+++ b/jax/multinode/generate_ssh_keys.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+function gen_single_key() {
+	ALG_NAME=$1
+	if [[ ! -f /etc/ssh/ssh_host_${ALG_NAME}_key ]]; then
+		ssh-keygen -q -N "" -t "${ALG_NAME}" -f "/etc/ssh/ssh_host_${ALG_NAME}_key"
+	fi
+}
+
+gen_single_key dsa
+gen_single_key rsa
+gen_single_key ecdsa
+gen_single_key ed25519
diff --git a/jax/multinode/requirements.txt b/jax/multinode/requirements.txt
new file mode 100644
index 00000000..47bd5f70
--- /dev/null
+++ b/jax/multinode/requirements.txt
@@ -0,0 +1,3 @@
+neural-compressor==3.0
+# mpi4py>=3.1.0
+mpi4jax>=0.5.3
diff --git a/jax/multinode/ssh_config b/jax/multinode/ssh_config
new file mode 100644
index 00000000..9ac73017
--- /dev/null
+++ b/jax/multinode/ssh_config
@@ -0,0 +1,4 @@
+Host *
+    Port 3022
+    IdentityFile ~/.ssh/id_rsa
+    StrictHostKeyChecking no
diff --git a/jax/multinode/sshd_config b/jax/multinode/sshd_config
new file mode 100644
index 00000000..4796a48a
--- /dev/null
+++ b/jax/multinode/sshd_config
@@ -0,0 +1,12 @@
+HostKey /etc/ssh/ssh_host_dsa_key
+HostKey /etc/ssh/ssh_host_rsa_key
+HostKey /etc/ssh/ssh_host_ecdsa_key
+HostKey /etc/ssh/ssh_host_ed25519_key
+AuthorizedKeysFile /etc/ssh/authorized_keys
+## Enable DEBUG log. You can ignore this but this may help you debug any issue while enabling SSHD for the first time
+LogLevel DEBUG3
+Port 3022
+UsePAM yes
+Subsystem       sftp    /usr/lib/openssh/sftp-server
+# https://ubuntu.com/security/CVE-2024-6387
+LoginGraceTime 0
diff --git a/jax/requirements.txt b/jax/requirements.txt
index 968728fd..09d7cb7f 100644
--- a/jax/requirements.txt
+++ b/jax/requirements.txt
@@ -2,3 +2,4 @@ flax==0.8.2
 intel-extension-for-openxla==0.4.0
 jax==0.4.26
 jaxlib==0.4.26
+cython==3.0.11
diff --git a/jax/tests/multinode-ex.py b/jax/tests/multinode-ex.py
new file mode 100644
index 00000000..2a05d913
--- /dev/null
+++ b/jax/tests/multinode-ex.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: skip-file
+
+import jax
+import jax.numpy as jnp
+import mpi4jax
+from mpi4py import MPI
+
+comm = MPI.COMM_WORLD
+rank = comm.Get_rank()
+
+
+@jax.jit
+def foo(arr):
+    arr = arr + rank
+    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
+    return arr_sum
+
+
+a = jnp.zeros((3, 3))
+result = foo(a)
+
+if rank == 0:
+    print(result)
diff --git a/jax/tests/tests.yaml b/jax/tests/tests.yaml
index 03afa463..8607cff0 100644
--- a/jax/tests/tests.yaml
+++ b/jax/tests/tests.yaml
@@ -26,3 +26,10 @@ jax-xpu-example:
   volumes:
     - src: $PWD/jax/tests
       dst: /tests
+jax-multinode-example:
+  img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-inc-${INC_VERSION:-3.0}
+  cmd: python /tests/multinode-ex.py
+  device: ["/dev/dri"]
+  volumes:
+    - src: $PWD/jax/tests
+      dst: /tests

From 3621b155ceebc90abac128c4453589d613355d63 Mon Sep 17 00:00:00 2001
From: tylertitsworth <tyler.titsworth@intel.com>
Date: Fri, 13 Sep 2024 12:48:53 -0700
Subject: [PATCH 05/12] remove multinode

Signed-off-by: tylertitsworth <tyler.titsworth@intel.com>
---
 .gitignore                          |   1 +
 jax/.actions.json                   |   1 +
 jax/Dockerfile                      | 137 ++++++++++++----------------
 jax/docker-compose.yaml             |  39 ++++----
 jax/multinode/dockerd-entrypoint.sh |  21 -----
 jax/multinode/generate_ssh_keys.sh  |  28 ------
 jax/multinode/requirements.txt      |   3 -
 jax/multinode/ssh_config            |   4 -
 jax/multinode/sshd_config           |  12 ---
 jax/tests/multinode-ex.py           |  37 --------
 jax/tests/tests.yaml                |  16 +---
 11 files changed, 85 insertions(+), 214 deletions(-)
 delete mode 100755 jax/multinode/dockerd-entrypoint.sh
 delete mode 100755 jax/multinode/generate_ssh_keys.sh
 delete mode 100644 jax/multinode/requirements.txt
 delete mode 100644 jax/multinode/ssh_config
 delete mode 100644 jax/multinode/sshd_config
 delete mode 100644 jax/tests/multinode-ex.py

diff --git a/.gitignore b/.gitignore
index f3e1d08f..d6f27a92 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,4 +14,5 @@ logs/
 models-perf/
 output/
 site
+test-runner-summary-output.json
 venv/
diff --git a/jax/.actions.json b/jax/.actions.json
index 45608421..36e21ad8 100644
--- a/jax/.actions.json
+++ b/jax/.actions.json
@@ -1,4 +1,5 @@
 {
+    "PACKAGE_OPTION": ["idp", "pip"],
     "experimental": [true],
     "runner_label": ["PVC"]
 }
diff --git a/jax/Dockerfile b/jax/Dockerfile
index a7d8b92b..0d0118fa 100644
--- a/jax/Dockerfile
+++ b/jax/Dockerfile
@@ -18,38 +18,74 @@
 # throughout. Please refer to the TensorFlow dockerfiles documentation
 # for more information.
 
-ARG BASE_IMAGE_NAME=${BASE_IMAGE_NAME}
-ARG BASE_IMAGE_TAG=${BASE_IMAGE_TAG}
-FROM ${BASE_IMAGE_NAME}:${BASE_IMAGE_TAG} AS xpu-base
-
-RUN wget -q -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
-    echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
-    chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
-    rm /etc/apt/sources.list.d/intel-graphics.list && \
-    wget -q -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \
-    echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \
-    chmod 644 /usr/share/keyrings/intel-graphics.gpg
-
-RUN apt-get update -y && \
+ARG REGISTRY
+ARG REPO
+ARG GITHUB_RUN_NUMBER
+ARG BASE_IMAGE_NAME
+ARG BASE_IMAGE_TAG
+ARG PACKAGE_OPTION=pip
+ARG PYTHON_VERSION
+ARG PYTHON_BASE=${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER}-${BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${PACKAGE_OPTION}-py${PYTHON_VERSION}-base
+ARG TORCHSERVE_BASE=${PYTHON_BASE}
+FROM ${PYTHON_BASE} AS xpu-base
+
+RUN apt-get update && \
     apt-get install -y --no-install-recommends --fix-missing \
+    apt-utils \
+    build-essential \
+    clinfo \
     git \
-    # libsndfile1 \
-    # lsb-release \
-    numactl \
-    python3 \
-    python3-dev \
-    python3-pip
+    gnupg2 \
+    gpg-agent \
+    rsync \
+    unzip && \
+    apt-get clean && \
+    rm -rf  /var/lib/apt/lists/*
+
+RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \
+    gpg --dearmor --yes --output /usr/share/keyrings/intel-graphics.gpg
+RUN echo "deb [arch=amd64 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu jammy unified" | \
+    tee /etc/apt/sources.list.d/intel-gpu-jammy.list
+
+ARG ICD_VER
+ARG LEVEL_ZERO_GPU_VER
+ARG LEVEL_ZERO_VER
+ARG LEVEL_ZERO_DEV_VER
+
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends --fix-missing \
+    intel-opencl-icd=${ICD_VER} \
+    intel-level-zero-gpu=${LEVEL_ZERO_GPU_VER} \
+    libze1=${LEVEL_ZERO_VER} \
+    libze-dev=${LEVEL_ZERO_DEV_VER} && \
+    rm -rf  /var/lib/apt/lists/*
+
+RUN no_proxy="" NO_PROXY="" wget --progress=dot:giga -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
+    | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
+    echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" \
+    | tee /etc/apt/sources.list.d/oneAPI.list
+
+ARG DPCPP_VER
+ARG MKL_VER
+ARG CCL_VER
+
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends --fix-missing \
+    intel-oneapi-runtime-dpcpp-cpp=${DPCPP_VER} \
+    intel-oneapi-runtime-mkl=${MKL_VER} \
+    intel-oneapi-runtime-ccl=${CCL_VER} && \
+    rm -rf  /var/lib/apt/lists/*
+
+RUN rm -rf /etc/apt/sources.list.d/intel-gpu-jammy.list /etc/apt/sources.list.d/oneAPI.list
 
-RUN ln -sf "$(which python3)" /usr/local/bin/python && \
-    ln -sf "$(which python3)" /usr/bin/python
+ENV OCL_ICD_VENDORS=/etc/OpenCL/vendors
 
 FROM xpu-base AS jax-base
 
 WORKDIR /
 COPY requirements.txt .
 
-RUN python -m pip install --no-cache-dir \
-    --ignore-installed -r requirements.txt && \
+RUN python -m pip install --no-cache-dir -r requirements.txt && \
     rm -rf requirements.txt
 
 FROM jax-base AS jupyter
@@ -66,58 +102,3 @@ RUN mkdir /.local && chmod a+rwx /.local
 EXPOSE 8888
 
 CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/jupyter --port 8888 --ip 0.0.0.0 --no-browser --allow-root --ServerApp.token= --ServerApp.password= --ServerApp.allow_origin=* --ServerApp.base_url=$NB_PREFIX"]
-
-FROM jax-base AS jax-multinode
-
-RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
-    # python3-dev \
-    gcc \
-    # g++ \
-    # libgl1-mesa-glx \
-    # libglib2.0-0 \
-    libopenmpi-dev \
-    numactl \
-    virtualenv
-
-ENV SIGOPT_PROJECT=. \
-    MPI4JAX_USE_SYCL_MPI=1 \
-    MPI4PY_BUILD_BACKEND=scikit-build-core
-
-WORKDIR /
-COPY multinode/requirements.txt requirements.txt
-
-
-RUN python -m pip install --no-cache-dir -r requirements.txt && \
-    rm -rf requirements.txt
-
-ENV LD_LIBRARY_PATH="/lib/x86_64-linux-gnu:${LD_LIBRARY_PATH}"
-
-RUN apt-get install -y --no-install-recommends --fix-missing \
-    openssh-client \
-    openssh-server && \
-    rm /etc/ssh/ssh_host_*_key \
-    /etc/ssh/ssh_host_*_key.pub && \
-    apt-get clean && \
-    rm -rf /var/lib/apt/lists/*
-
-RUN mkdir -p /var/run/sshd
-
-ARG PYTHON_VERSION
-
-COPY multinode/generate_ssh_keys.sh /generate_ssh_keys.sh
-
-# modify generate_ssh_keys to be a helper script
-# print how to use helper script on bash startup
-# Avoids loop for further execution of the startup file
-ARG PACKAGE_OPTION=pip
-ARG PYPATH="/usr/local/lib/python${PYTHON_VERSION}/dist-packages"
-RUN cat '/generate_ssh_keys.sh' >> ~/.startup && \
-    rm -rf /generate_ssh_keys.sh
-
-COPY multinode/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh
-COPY multinode/sshd_config /etc/ssh/sshd_config
-COPY multinode/ssh_config /etc/ssh/ssh_config
-
-RUN mkdir -p /licensing
-
-ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"]
diff --git a/jax/docker-compose.yaml b/jax/docker-compose.yaml
index 1cec9fab..8b740a6b 100644
--- a/jax/docker-compose.yaml
+++ b/jax/docker-compose.yaml
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+include:
+  - path:
+      - ../python/docker-compose.yaml
 services:
   jax-base:
     build:
@@ -19,8 +22,22 @@ services:
         http_proxy: ${http_proxy}
         https_proxy: ${https_proxy}
         no_proxy: ""
-        BASE_IMAGE_NAME: ${BASE_IMAGE_NAME:-intel/oneapi-basekit}
-        BASE_IMAGE_TAG: ${BASE_IMAGE_TAG:-2024.2.0-devel-ubuntu22.04}
+        BASE_IMAGE_NAME: ${BASE_IMAGE_NAME:-ubuntu}
+        BASE_IMAGE_TAG: ${BASE_IMAGE_TAG:-22.04}
+        GITHUB_RUN_NUMBER: ${GITHUB_RUN_NUMBER:-0}
+        MINIFORGE_VERSION: ${MINIFORGE_VERSION:-Linux-x86_64}
+        NO_PROXY: ''
+        PACKAGE_OPTION: ${PACKAGE_OPTION:-pip}
+        PYTHON_VERSION: ${PYTHON_VERSION:-3.10}
+        REGISTRY: ${REGISTRY}
+        REPO: ${REPO}
+        CCL_VER: ${CCL_VER:-2021.13.1-31}
+        DPCPP_VER: ${DPCPP_VER:-2024.2.1-1079}
+        ICD_VER: ${ICD_VER:-24.22.29735.27-914~22.04}
+        MKL_VER: ${MKL_VER:-2024.2.1-103}
+        LEVEL_ZERO_DEV_VER: ${LEVEL_ZERO_DEV_VER:-1.17.6-914~22.04}
+        LEVEL_ZERO_GPU_VER: ${LEVEL_ZERO_GPU_VER:-1.3.29735.27-914~22.04}
+        LEVEL_ZERO_VER: ${LEVEL_ZERO_VER:-1.17.6-914~22.04}
       context: .
       labels:
         dependency.python: ${PYTHON_VERSION:-3.10}
@@ -51,21 +68,3 @@ services:
     extends: jax-base
     image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-jupyter
     network_mode: host
-  multinode:
-    build:
-      labels:
-        # dependency.apt.gcc: true
-        # dependency.apt.libgl1-mesa-glx: true
-        # dependency.apt.libglib2: true
-        dependency.apt.python3-dev: true
-        dependency.pip.apt.virtualenv: true
-        dependency.python.pip: multinode/requirements.txt
-        org.opencontainers.base.name: "intel/intel-optimized-xla:${INTEL_XLA_VERSION:-v0.4.0}-base"
-        org.opencontainers.image.title: "Intel® Optimized XLA MultiNode Image"
-        org.opencontainers.image.version: ${INTEL_XLA_VERSION:-v0.4.0}-multinode
-      target: jax-multinode
-    command: >
-      bash -c "mpirun --version && python -c 'from mpi4py import MPI; import mpi4jax'"
-    extends: jax-base
-    image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-inc-${INC_VERSION:-3.0}
-    shm_size: 2gb
diff --git a/jax/multinode/dockerd-entrypoint.sh b/jax/multinode/dockerd-entrypoint.sh
deleted file mode 100755
index ba13c0f9..00000000
--- a/jax/multinode/dockerd-entrypoint.sh
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/bin/bash
-# Copyright (c) 2024 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -e
-set -a
-# shellcheck disable=SC1091
-source "$HOME/.startup"
-set +a
-"$@"
diff --git a/jax/multinode/generate_ssh_keys.sh b/jax/multinode/generate_ssh_keys.sh
deleted file mode 100755
index 0ee61398..00000000
--- a/jax/multinode/generate_ssh_keys.sh
+++ /dev/null
@@ -1,28 +0,0 @@
-#!/usr/bin/env bash
-# Copyright (c) 2023 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# SPDX-License-Identifier: Apache-2.0
-
-function gen_single_key() {
-	ALG_NAME=$1
-	if [[ ! -f /etc/ssh/ssh_host_${ALG_NAME}_key ]]; then
-		ssh-keygen -q -N "" -t "${ALG_NAME}" -f "/etc/ssh/ssh_host_${ALG_NAME}_key"
-	fi
-}
-
-gen_single_key dsa
-gen_single_key rsa
-gen_single_key ecdsa
-gen_single_key ed25519
diff --git a/jax/multinode/requirements.txt b/jax/multinode/requirements.txt
deleted file mode 100644
index 47bd5f70..00000000
--- a/jax/multinode/requirements.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-neural-compressor==3.0
-# mpi4py>=3.1.0
-mpi4jax>=0.5.3
diff --git a/jax/multinode/ssh_config b/jax/multinode/ssh_config
deleted file mode 100644
index 9ac73017..00000000
--- a/jax/multinode/ssh_config
+++ /dev/null
@@ -1,4 +0,0 @@
-Host *
-    Port 3022
-    IdentityFile ~/.ssh/id_rsa
-    StrictHostKeyChecking no
diff --git a/jax/multinode/sshd_config b/jax/multinode/sshd_config
deleted file mode 100644
index 4796a48a..00000000
--- a/jax/multinode/sshd_config
+++ /dev/null
@@ -1,12 +0,0 @@
-HostKey /etc/ssh/ssh_host_dsa_key
-HostKey /etc/ssh/ssh_host_rsa_key
-HostKey /etc/ssh/ssh_host_ecdsa_key
-HostKey /etc/ssh/ssh_host_ed25519_key
-AuthorizedKeysFile /etc/ssh/authorized_keys
-## Enable DEBUG log. You can ignore this but this may help you debug any issue while enabling SSHD for the first time
-LogLevel DEBUG3
-Port 3022
-UsePAM yes
-Subsystem       sftp    /usr/lib/openssh/sftp-server
-# https://ubuntu.com/security/CVE-2024-6387
-LoginGraceTime 0
diff --git a/jax/tests/multinode-ex.py b/jax/tests/multinode-ex.py
deleted file mode 100644
index 2a05d913..00000000
--- a/jax/tests/multinode-ex.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright (c) 2024 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# pylint: skip-file
-
-import jax
-import jax.numpy as jnp
-import mpi4jax
-from mpi4py import MPI
-
-comm = MPI.COMM_WORLD
-rank = comm.Get_rank()
-
-
-@jax.jit
-def foo(arr):
-    arr = arr + rank
-    arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
-    return arr_sum
-
-
-a = jnp.zeros((3, 3))
-result = foo(a)
-
-if rank == 0:
-    print(result)
diff --git a/jax/tests/tests.yaml b/jax/tests/tests.yaml
index 8607cff0..5ce0a73e 100644
--- a/jax/tests/tests.yaml
+++ b/jax/tests/tests.yaml
@@ -13,23 +13,17 @@
 # limitations under the License.
 
 ---
-jax-import:
+jax-import-${PACKAGE_OPTION:-pip}:
   img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
-  cmd: python -c 'import jax; print("Jax Version:", jax.__version__)'
-jax-import-jupyter:
+  cmd: python -c 'import jax; print("Jax Version:", jax.__version__); print(jax.devices())'
+  device: ["/dev/dri"]
+jax-import-jupyter-${PACKAGE_OPTION:-pip}:
   img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-jupyter
   cmd: sh -c "python -m jupyter --version"
-jax-xpu-example:
+jax-xpu-example-${PACKAGE_OPTION:-pip}:
   img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
   cmd: python /tests/example.py
   device: ["/dev/dri"]
   volumes:
     - src: $PWD/jax/tests
       dst: /tests
-jax-multinode-example:
-  img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-inc-${INC_VERSION:-3.0}
-  cmd: python /tests/multinode-ex.py
-  device: ["/dev/dri"]
-  volumes:
-    - src: $PWD/jax/tests
-      dst: /tests

From 6f3c5be5ef6fac7a283169c9a5a3b8b7357c0766 Mon Sep 17 00:00:00 2001
From: tylertitsworth <tyler.titsworth@intel.com>
Date: Fri, 13 Sep 2024 13:17:35 -0700
Subject: [PATCH 06/12] update docs

Signed-off-by: tylertitsworth <tyler.titsworth@intel.com>
---
 docs/roadmap.md         |   2 +-
 docs/scripts/readmes.py |   1 +
 jax/README.md           | 101 ++++++++++++++--------------------------
 jax/docker-compose.yaml |  35 ++++++++++----
 4 files changed, 63 insertions(+), 76 deletions(-)

diff --git a/docs/roadmap.md b/docs/roadmap.md
index 8e22e7c8..018808b6 100644
--- a/docs/roadmap.md
+++ b/docs/roadmap.md
@@ -10,7 +10,7 @@
 
 - Granite Rapids Support
 - CLS Support
-- Intel Developer Cloud Support
+- Intel Tiber Developer Cloud Support
 - AI Tools 2024.3/2025.0 Support
 
 ## Q4'24
diff --git a/docs/scripts/readmes.py b/docs/scripts/readmes.py
index 3e7d5e09..806b2018 100644
--- a/docs/scripts/readmes.py
+++ b/docs/scripts/readmes.py
@@ -17,6 +17,7 @@
 
 readmes = [
     "classical-ml/README.md",
+    "jax/README.md"
     "preset/README.md",
     "python/README.md",
     "pytorch/README.md",
diff --git a/jax/README.md b/jax/README.md
index 97cc50c2..9dd6540e 100644
--- a/jax/README.md
+++ b/jax/README.md
@@ -1,30 +1,21 @@
-# Intel® Optimized ML
+# Intel® Optimized OpenXLA\*
 
-[Intel® Extension for Scikit-learn*] enhances the performance of [Scikit-learn*] by accelerating the training and inference of machine learning models on Intel® hardware.
-
-[XGBoost*] is an optimized distributed gradient boosting library designed to be highly efficient, flexible and portable.
+Transformable numerical computing at scale combined with [Intel® Extension for OpenXLA\*], which includes a PJRT plugin implementation to seamlessly runs [JAX\*] models on Intel GPUs.
 
 ## Images
 
-The images below include [Intel® Extension for Scikit-learn*] and [XGBoost*].
+The images below include [JAX\*] and [Intel® Extension for OpenXLA\*].
 
-| Tag(s)                                            | Intel SKLearn  | Scikit-learn | XGBoost  | Dockerfile      |
-| ------------------------------------------------- | -------------- | ------------ | -------- | --------------- |
-| `2024.6.0-pip-base`, `latest`                     | [v2024.6.0]    | [v1.5.0]     | [v2.1.0] | [v0.4.0]        |
-| `2024.5.0-pip-base`                               | [v2024.5.0]    | [v1.5.0]     | [v2.1.0] | [v0.4.0]        |
-| `2024.3.0-pip-base`                               | [v2024.3.0]    | [v1.4.2]     | [v2.0.3] | [v0.4.0-Beta]   |
-| `2024.2.0-xgboost-2.0.3-pip-base`                 | [v2024.2.0]    | [v1.4.1]     | [v2.0.3] | [v0.4.0-Beta]   |
-| `scikit-learning-2024.0.0-xgboost-2.0.2-pip-base` | [v2024.0.0]    | [v1.3.2]     | [v2.0.2] | [v0.3.4]        |
+| Tag(s)                     | [JAX\*]   | [Intel OpenXLA\*] | [Flax]   | Dockerfile      |
+| -------------------------- | --------- | ----------------- | -------- | --------------- |
+| `0.4.0-pip-base`, `latest` | [v0.4.32] | [v0.4.0-jax]      | [v0.9.0] | [v0.4.0]        |
 
 The images below additionally include [Jupyter Notebook](https://jupyter.org/) server:
 
-| Tag(s)                                               | Intel SKLearn  | Scikit-learn | XGBoost  | Dockerfile      |
-| ---------------------------------------------------- | -------------- | ------------ | -------- | --------------- |
-| `2024.6.0-pip-jupyter`                               | [v2024.6.0]    | [v1.5.1]     | [v2.1.1] | [v0.4.0]        |
-| `2024.5.0-pip-jupyter`                               | [v2024.5.0]    | [v1.5.0]     | [v2.1.0] | [v0.4.0]        |
-| `2024.3.0-pip-jupyter`                               | [v2024.3.0]    | [v1.4.2]     | [v2.0.3] | [v0.4.0-Beta]   |
-| `2024.2.0-xgboost-2.0.3-pip-jupyter`                 | [v2024.2.0]    | [v1.4.1]     | [v2.0.3] | [v0.4.0-Beta]   |
-| `scikit-learning-2024.0.0-xgboost-2.0.2-pip-jupyter` | [v2024.0.0]    | [v1.3.2]     | [v2.0.2] | [v0.3.4]        |
+| Tag(s)              | [JAX\*]   | [Intel OpenXLA\*] | [Flax]   | Dockerfile      |
+| ------------------- | --------- | ----------------- | -------- | --------------- |
+| `0.4.0-pip-jupyter` | [v0.4.32] | [v0.4.0-jax]      | [v0.9.0] | [v0.4.0]        |
+
 
 ### Run the Jupyter Container
 
@@ -34,7 +25,7 @@ docker run -it --rm \
     --net=host \
     -v $PWD/workspace:/workspace \
     -w /workspace \
-    intel/intel-optimized-ml:2024.2.0-xgboost-2.0.3-pip-jupyter
+    intel/intel-optimized-xla:0.4.0-pip-jupyter
 ```
 
 After running the command above, copy the URL (something like `http://127.0.0.1:$PORT/?token=***`) into your browser to access the notebook server.
@@ -43,40 +34,33 @@ After running the command above, copy the URL (something like `http://127.0.0.1:
 
 The images below include [Intel® Distribution for Python*]:
 
-| Tag(s)                                            | Intel SKLearn  | Scikit-learn | XGBoost  | Dockerfile      |
-| ------------------------------------------------- | -------------- | ------------ | -------- | --------------- |
-| `2024.6.0-idp-base`                               | [v2024.6.0]    | [v1.5.1]     | [v2.1.1] | [v0.4.0]        |
-| `2024.5.0-idp-base`                               | [v2024.5.0]    | [v1.5.0]     | [v2.1.0] | [v0.4.0]        |
-| `2024.3.0-idp-base`                               | [v2024.3.0]    | [v1.4.1]     | [v2.1.0] | [v0.4.0]        |
-| `2024.2.0-xgboost-2.0.3-idp-base`                 | [v2024.2.0]    | [v1.4.1]     | [v2.0.3] | [v0.4.0-Beta]   |
-| `scikit-learning-2024.0.0-xgboost-2.0.2-idp-base` | [v2024.0.0]    | [v1.3.2]     | [v2.0.2] | [v0.3.4]        |
+| Tag(s)           | [JAX\*]   | [Intel OpenXLA\*] | [Flax]   | Dockerfile      |
+| ---------------- | --------- | ----------------- | -------- | --------------- |
+| `0.4.0-idp-base` | [v0.4.32] | [v0.4.0-jax]      | [v0.9.0] | [v0.4.0]        |
+
 
 The images below additionally include [Jupyter Notebook](https://jupyter.org/) server:
 
-| Tag(s)                                               | Intel SKLearn  | Scikit-learn | XGBoost  | Dockerfile      |
-| ---------------------------------------------------- | -------------- | ------------ | -------- | --------------- |
-| `2024.6.0-idp-jupyter`                               | [v2024.6.0]    | [v1.5.1]     | [v2.1.1] | [v0.4.0]        |
-| `2024.5.0-idp-jupyter`                               | [v2024.5.0]    | [v1.5.0]     | [v2.1.0] | [v0.4.0]        |
-| `2024.3.0-idp-jupyter`                               | [v2024.3.0]    | [v1.4.0]     | [v2.1.0] | [v0.4.0]        |
-| `2024.2.0-xgboost-2.0.3-idp-jupyter`                 | [v2024.2.0]    | [v1.4.1]     | [v2.0.3] | [v0.4.0-Beta]   |
-| `scikit-learning-2024.0.0-xgboost-2.0.2-idp-jupyter` | [v2024.0.0]    | [v1.3.2]     | [v2.0.2] | [v0.3.4]        |
+| Tag(s)              | [JAX\*]   | [Intel OpenXLA\*] | [Flax]   | Dockerfile      |
+| ------------------- | --------- | ----------------- | -------- | --------------- |
+| `0.4.0-idp-jupyter` | [v0.4.32] | [v0.4.0-jax]      | [v0.9.0] | [v0.4.0]        |
 
 ## Build from Source
 
 To build the images from source, clone the [AI Containers](https://github.com/intel/ai-containers) repository, follow the main `README.md` file to setup your environment, and run the following command:
 
 ```bash
-cd classical-ml
-docker compose build ml-base
-docker compose run ml-base
+cd jax
+docker compose build jax-base
+docker compose run -it jax-base
 ```
 
 You can find the list of services below for each container in the group:
 
-| Service Name | Description                                                         |
-| ------------ | ------------------------------------------------------------------- |
-| `ml-base`    | Base image with [Intel® Extension for Scikit-learn*] and [XGBoost*] |
-| `jupyter`    | Adds Jupyter Notebook server                                        |
+| Service Name | Description                                     |
+| ------------ | ----------------------------------------------- |
+| `jax-base`   | Base image with [Intel® Extension for OpenXLA\*] |
+| `jupyter`    | Adds Jupyter Notebook server                    |
 
 ## License
 
@@ -90,28 +74,15 @@ It is the image user's responsibility to ensure that any use of The images below
 
 <!--Below are links used in these document. They are not rendered: -->
 
-[Intel® Extension for Scikit-learn*]: https://www.intel.com/content/www/us/en/developer/tools/oneapi/scikit-learn.html
 [Intel® Distribution for Python]: https://www.intel.com/content/www/us/en/developer/tools/oneapi/distribution-for-python.html#gs.9bos9m
-[Scikit-learn*]: https://scikit-learn.org/stable/
-[XGBoost*]: https://github.com/dmlc/xgboost
-
-[v2024.6.0]: https://github.com/intel/scikit-learn-intelex/releases/tag/2024.6.0
-[v2024.5.0]: https://github.com/intel/scikit-learn-intelex/releases/tag/2024.5.0
-[v2024.3.0]: https://github.com/intel/scikit-learn-intelex/releases/tag/2024.3.0
-[v2024.2.0]: https://github.com/intel/scikit-learn-intelex/releases/tag/2024.2.0
-[v2024.0.0]: https://github.com/intel/scikit-learn-intelex/releases/tag/2024.0.0
-
-[v1.5.1]: https://github.com/scikit-learn/scikit-learn/releases/tag/1.5.1
-[v1.5.0]: https://github.com/scikit-learn/scikit-learn/releases/tag/1.5.0
-[v1.4.2]: https://github.com/scikit-learn/scikit-learn/releases/tag/1.4.2
-[v1.4.1]: https://github.com/scikit-learn/scikit-learn/releases/tag/1.4.1
-[v1.3.2]: https://github.com/scikit-learn/scikit-learn/releases/tag/1.3.2
-
-[v2.1.1]: https://github.com/dmlc/xgboost/releases/tag/v2.1.1
-[v2.1.0]: https://github.com/dmlc/xgboost/releases/tag/v2.1.0
-[v2.0.3]: https://github.com/dmlc/xgboost/releases/tag/v2.0.3
-[v2.0.2]: https://github.com/dmlc/xgboost/releases/tag/v2.0.2
-
-[v0.4.0]: https://github.com/intel/ai-containers/blob/v0.4.0/classical-ml/Dockerfile
-[v0.4.0-Beta]: https://github.com/intel/ai-containers/blob/v0.4.0-Beta/classical-ml/Dockerfile
-[v0.3.4]: https://github.com/intel/ai-containers/blob/v0.3.4/classical-ml/Dockerfile
+[Intel® Extension for OpenXLA\*]: https://github.com/intel/intel-extension-for-openxla
+[JAX\*]: https://github.com/google/jax
+[Flax]: https://github.com/google/flax
+
+[v0.4.32]: https://github.com/google/jax/releases/tag/jax-v0.4.32
+
+[v0.4.0-jax]: https://github.com/intel/intel-extension-for-openxla/releases/tag/0.4.0
+
+[v0.9.0]: https://github.com/google/Flax/releases/tag/v0.9.0
+
+[v0.4.0]: https://github.com/intel/ai-containers/blob/v0.4.0/jax/Dockerfile
diff --git a/jax/docker-compose.yaml b/jax/docker-compose.yaml
index 8b740a6b..5590e1da 100644
--- a/jax/docker-compose.yaml
+++ b/jax/docker-compose.yaml
@@ -24,33 +24,48 @@ services:
         no_proxy: ""
         BASE_IMAGE_NAME: ${BASE_IMAGE_NAME:-ubuntu}
         BASE_IMAGE_TAG: ${BASE_IMAGE_TAG:-22.04}
+        CCL_VER: ${CCL_VER:-2021.13.1-31}
+        DPCPP_VER: ${DPCPP_VER:-2024.2.1-1079}
         GITHUB_RUN_NUMBER: ${GITHUB_RUN_NUMBER:-0}
+        ICD_VER: ${ICD_VER:-24.22.29735.27-914~22.04}
+        LEVEL_ZERO_DEV_VER: ${LEVEL_ZERO_DEV_VER:-1.17.6-914~22.04}
+        LEVEL_ZERO_GPU_VER: ${LEVEL_ZERO_GPU_VER:-1.3.29735.27-914~22.04}
+        LEVEL_ZERO_VER: ${LEVEL_ZERO_VER:-1.17.6-914~22.04}
         MINIFORGE_VERSION: ${MINIFORGE_VERSION:-Linux-x86_64}
+        MKL_VER: ${MKL_VER:-2024.2.1-103}
         NO_PROXY: ''
         PACKAGE_OPTION: ${PACKAGE_OPTION:-pip}
         PYTHON_VERSION: ${PYTHON_VERSION:-3.10}
         REGISTRY: ${REGISTRY}
         REPO: ${REPO}
-        CCL_VER: ${CCL_VER:-2021.13.1-31}
-        DPCPP_VER: ${DPCPP_VER:-2024.2.1-1079}
-        ICD_VER: ${ICD_VER:-24.22.29735.27-914~22.04}
-        MKL_VER: ${MKL_VER:-2024.2.1-103}
-        LEVEL_ZERO_DEV_VER: ${LEVEL_ZERO_DEV_VER:-1.17.6-914~22.04}
-        LEVEL_ZERO_GPU_VER: ${LEVEL_ZERO_GPU_VER:-1.3.29735.27-914~22.04}
-        LEVEL_ZERO_VER: ${LEVEL_ZERO_VER:-1.17.6-914~22.04}
       context: .
       labels:
         dependency.python: ${PYTHON_VERSION:-3.10}
+        dependency.apt.build-essential: true
+        dependency.apt.clinfo: true
+        dependency.apt.git: true
+        dependency.apt.gnupg2: true
+        dependency.apt.gpg-agent: true
+        dependency.apt.intel-level-zero-gpu: ${LEVEL_ZERO_GPU_VER:-1.3.29735.27-914~22.04}
+        dependency.apt.intel-oneapi-runtime-ccl: ${CCL_VER:-2021.13.1-31}
+        dependency.apt.intel-oneapi-runtime-dpcpp-cpp: ${DPCPP_VER:-2024.2.1-1079}
+        dependency.apt.intel-oneapi-runtime-mkl: ${MKL_VER:-2024.2.1-103}
+        dependency.apt.intel-opencl-icd: ${ICD_VER:-23.43.27642.40-803~22.04}
+        dependency.apt.level-zero: ${LEVEL_ZERO_VER:-1.17.6-914~22.04}
+        dependency.apt.level-zero-dev: ${LEVEL_ZERO_DEV_VER:-1.17.6-914~22.04}
+        dependency.apt.rsync: true
+        dependency.apt.unzip: true
+        dependency.idp.pip: false
         dependency.python.pip: requirements.txt
         docs: jax
         org.opencontainers.base.name: "intel/python:3.10-core"
         org.opencontainers.image.name: "intel/intel-optimized-xla"
         org.opencontainers.image.title: "Intel® Optimized XLA Base Image"
-        org.opencontainers.image.version: ${INTEL_XLA_VERSION:-v0.4.0}-base
+        org.opencontainers.image.version: ${INTEL_XLA_VERSION:-v0.4.0}-${PACKAGE_OPTION:-pip}-base
       target: jax-base
     command: >
       bash -c "python -c 'import jax; print(\"Jax Version:\", jax.__version__)'"
-    image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
+    image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-${PACKAGE_OPTION:-pip}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
     pull_policy: always
   jupyter:
     build:
@@ -66,5 +81,5 @@ services:
       http_proxy: ${http_proxy}
       https_proxy: ${https_proxy}
     extends: jax-base
-    image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-jupyter
+    image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-${PACKAGE_OPTION:-pip}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-jupyter
     network_mode: host

From 822b4aa2083a5bfc1305d25a19006f57a23b9914 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 13 Sep 2024 20:18:31 +0000
Subject: [PATCH 07/12] [pre-commit.ci] auto fixes from pre-commit.com hooks

---
 docs/scripts/readmes.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/docs/scripts/readmes.py b/docs/scripts/readmes.py
index 806b2018..fc587264 100644
--- a/docs/scripts/readmes.py
+++ b/docs/scripts/readmes.py
@@ -17,8 +17,7 @@
 
 readmes = [
     "classical-ml/README.md",
-    "jax/README.md"
-    "preset/README.md",
+    "jax/README.md" "preset/README.md",
     "python/README.md",
     "pytorch/README.md",
     "tensorflow/README.md",

From 73f88a5ae3a4aa0c3586aa8ceab25b56da148432 Mon Sep 17 00:00:00 2001
From: tylertitsworth <tyler.titsworth@intel.com>
Date: Fri, 13 Sep 2024 13:27:59 -0700
Subject: [PATCH 08/12] fix readme list typo

Signed-off-by: tylertitsworth <tyler.titsworth@intel.com>
---
 docs/scripts/readmes.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/docs/scripts/readmes.py b/docs/scripts/readmes.py
index fc587264..8eb2553b 100644
--- a/docs/scripts/readmes.py
+++ b/docs/scripts/readmes.py
@@ -17,7 +17,8 @@
 
 readmes = [
     "classical-ml/README.md",
-    "jax/README.md" "preset/README.md",
+    "jax/README.md",
+    "preset/README.md",
     "python/README.md",
     "pytorch/README.md",
     "tensorflow/README.md",

From 9ef6f752434d1dc09de909b8209da59e3b7fc868 Mon Sep 17 00:00:00 2001
From: tylertitsworth <tyler.titsworth@intel.com>
Date: Fri, 13 Sep 2024 13:29:14 -0700
Subject: [PATCH 09/12] fix markdown lint

Signed-off-by: tylertitsworth <tyler.titsworth@intel.com>
---
 jax/README.md | 2 --
 1 file changed, 2 deletions(-)

diff --git a/jax/README.md b/jax/README.md
index 9dd6540e..f8407f43 100644
--- a/jax/README.md
+++ b/jax/README.md
@@ -16,7 +16,6 @@ The images below additionally include [Jupyter Notebook](https://jupyter.org/) s
 | ------------------- | --------- | ----------------- | -------- | --------------- |
 | `0.4.0-pip-jupyter` | [v0.4.32] | [v0.4.0-jax]      | [v0.9.0] | [v0.4.0]        |
 
-
 ### Run the Jupyter Container
 
 ```bash
@@ -38,7 +37,6 @@ The images below include [Intel® Distribution for Python*]:
 | ---------------- | --------- | ----------------- | -------- | --------------- |
 | `0.4.0-idp-base` | [v0.4.32] | [v0.4.0-jax]      | [v0.9.0] | [v0.4.0]        |
 
-
 The images below additionally include [Jupyter Notebook](https://jupyter.org/) server:
 
 | Tag(s)              | [JAX\*]   | [Intel OpenXLA\*] | [Flax]   | Dockerfile      |

From 0d5da35883e8f66849dc8e3f49487342075905b3 Mon Sep 17 00:00:00 2001
From: tylertitsworth <tyler.titsworth@intel.com>
Date: Fri, 13 Sep 2024 13:36:37 -0700
Subject: [PATCH 10/12] fix link typo

Signed-off-by: tylertitsworth <tyler.titsworth@intel.com>
---
 jax/README.md | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/jax/README.md b/jax/README.md
index f8407f43..67ea81ed 100644
--- a/jax/README.md
+++ b/jax/README.md
@@ -6,13 +6,13 @@ Transformable numerical computing at scale combined with [Intel® Extension for
 
 The images below include [JAX\*] and [Intel® Extension for OpenXLA\*].
 
-| Tag(s)                     | [JAX\*]   | [Intel OpenXLA\*] | [Flax]   | Dockerfile      |
+| Tag(s)                     | [JAX\*]   | [Intel® Extension for OpenXLA\*] | [Flax]   | Dockerfile      |
 | -------------------------- | --------- | ----------------- | -------- | --------------- |
 | `0.4.0-pip-base`, `latest` | [v0.4.32] | [v0.4.0-jax]      | [v0.9.0] | [v0.4.0]        |
 
 The images below additionally include [Jupyter Notebook](https://jupyter.org/) server:
 
-| Tag(s)              | [JAX\*]   | [Intel OpenXLA\*] | [Flax]   | Dockerfile      |
+| Tag(s)              | [JAX\*]   | [Intel® Extension for OpenXLA\*] | [Flax]   | Dockerfile      |
 | ------------------- | --------- | ----------------- | -------- | --------------- |
 | `0.4.0-pip-jupyter` | [v0.4.32] | [v0.4.0-jax]      | [v0.9.0] | [v0.4.0]        |
 
@@ -33,13 +33,13 @@ After running the command above, copy the URL (something like `http://127.0.0.1:
 
 The images below include [Intel® Distribution for Python*]:
 
-| Tag(s)           | [JAX\*]   | [Intel OpenXLA\*] | [Flax]   | Dockerfile      |
+| Tag(s)           | [JAX\*]   | [Intel® Extension for OpenXLA\*] | [Flax]   | Dockerfile      |
 | ---------------- | --------- | ----------------- | -------- | --------------- |
 | `0.4.0-idp-base` | [v0.4.32] | [v0.4.0-jax]      | [v0.9.0] | [v0.4.0]        |
 
 The images below additionally include [Jupyter Notebook](https://jupyter.org/) server:
 
-| Tag(s)              | [JAX\*]   | [Intel OpenXLA\*] | [Flax]   | Dockerfile      |
+| Tag(s)              | [JAX\*]   | [Intel® Extension for OpenXLA\*] | [Flax]   | Dockerfile      |
 | ------------------- | --------- | ----------------- | -------- | --------------- |
 | `0.4.0-idp-jupyter` | [v0.4.32] | [v0.4.0-jax]      | [v0.9.0] | [v0.4.0]        |
 
@@ -72,7 +72,7 @@ It is the image user's responsibility to ensure that any use of The images below
 
 <!--Below are links used in these document. They are not rendered: -->
 
-[Intel® Distribution for Python]: https://www.intel.com/content/www/us/en/developer/tools/oneapi/distribution-for-python.html#gs.9bos9m
+[Intel® Distribution for Python*]: https://www.intel.com/content/www/us/en/developer/tools/oneapi/distribution-for-python.html#gs.9bos9m
 [Intel® Extension for OpenXLA\*]: https://github.com/intel/intel-extension-for-openxla
 [JAX\*]: https://github.com/google/jax
 [Flax]: https://github.com/google/flax

From a8e5f0592911e52548305f69078c3e4fb4a3420e Mon Sep 17 00:00:00 2001
From: tylertitsworth <tyler.titsworth@intel.com>
Date: Fri, 13 Sep 2024 14:41:28 -0700
Subject: [PATCH 11/12] add depends_on step

Signed-off-by: tylertitsworth <tyler.titsworth@intel.com>
---
 jax/docker-compose.yaml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/jax/docker-compose.yaml b/jax/docker-compose.yaml
index 5590e1da..e2c47d63 100644
--- a/jax/docker-compose.yaml
+++ b/jax/docker-compose.yaml
@@ -65,6 +65,8 @@ services:
       target: jax-base
     command: >
       bash -c "python -c 'import jax; print(\"Jax Version:\", jax.__version__)'"
+    depends_on:
+      - ${PACKAGE_OPTION:-pip}
     image: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-${PACKAGE_OPTION:-pip}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
     pull_policy: always
   jupyter:

From 1298d5bafd50c86614f50e1d44f255c8d484bbab Mon Sep 17 00:00:00 2001
From: Tyler Titsworth <tyler.titsworth@intel.com>
Date: Wed, 18 Sep 2024 13:42:41 -0700
Subject: [PATCH 12/12] Update tests.yaml

Signed-off-by: Tyler Titsworth <tyler.titsworth@intel.com>
---
 jax/tests/tests.yaml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/jax/tests/tests.yaml b/jax/tests/tests.yaml
index 5ce0a73e..419dbf3f 100644
--- a/jax/tests/tests.yaml
+++ b/jax/tests/tests.yaml
@@ -14,14 +14,14 @@
 
 ---
 jax-import-${PACKAGE_OPTION:-pip}:
-  img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
+  img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-${PACKAGE_OPTION:-pip}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
   cmd: python -c 'import jax; print("Jax Version:", jax.__version__); print(jax.devices())'
   device: ["/dev/dri"]
 jax-import-jupyter-${PACKAGE_OPTION:-pip}:
-  img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-jupyter
+  img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-${PACKAGE_OPTION:-pip}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-jupyter
   cmd: sh -c "python -m jupyter --version"
 jax-xpu-example-${PACKAGE_OPTION:-pip}:
-  img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
+  img: ${REGISTRY}/${REPO}:b-${GITHUB_RUN_NUMBER:-0}-${BASE_IMAGE_NAME:-ubuntu}-${BASE_IMAGE_TAG:-22.04}-${PACKAGE_OPTION:-pip}-py${PYTHON_VERSION:-3.10}-xla-${INTEL_XLA_VERSION:-v0.4.0}-base
   cmd: python /tests/example.py
   device: ["/dev/dri"]
   volumes: