Skip to content

Commit 5017d06

Browse files
AdamLoulyAdam LoulyJingyaHuang
authored
Refactoring FSDP. (#1586)
* refactor fsdp * add trainer * remove hidden layers * update dockerfile --------- Co-authored-by: Adam Louly <adamlouly@microsoft.com@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net> Co-authored-by: JingyaHuang <huang_jingya@outlook.com>
1 parent 1a807fc commit 5017d06

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

examples/onnxruntime/training/docker/Dockerfile-ort1.16.1-cu118 examples/onnxruntime/training/docker/Dockerfile-ort1.16.3-cu118

+4-1
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,15 @@ RUN $PYTHON_EXE -m pip install onnx ninja
6565
RUN $PYTHON_EXE -m pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} -f https://download.pytorch.org/whl/${TORCH_CUDA_VERSION}
6666

6767
# ORT Module
68-
RUN $PYTHON_EXE -m pip install onnxruntime-training==1.16.1 -f https://download.onnxruntime.ai/onnxruntime_stable_cu118.html
68+
RUN $PYTHON_EXE -m pip install onnxruntime-training==1.16.3 -f https://download.onnxruntime.ai/onnxruntime_stable_cu118.html
6969
RUN $PYTHON_EXE -m pip install torch-ort
7070
ENV TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX"
7171
RUN $PYTHON_EXE -m pip install --upgrade protobuf==3.20.2
7272
RUN $PYTHON_EXE -m torch_ort.configure
7373

74+
# https://github.com/vllm-project/vllm/issues/1726
75+
RUN pip uninstall nvidia-nccl-cu12 -y
76+
7477
WORKDIR .
7578

7679
CMD ["/bin/bash"]

optimum/onnxruntime/trainer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def _inner_training_loop(
455455
else:
456456
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
457457

458-
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled
458+
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
459459

460460
# Wrap the model with `ORTModule`
461461
logger.info("Wrap ORTModule for ONNX Runtime training.")
@@ -883,7 +883,7 @@ def _wrap_model(self, model, training=True, dataloader=None):
883883
return model
884884

885885
# Distributed training using PyTorch FSDP
886-
if self.fsdp is not None:
886+
if self.is_fsdp_xla_enabled:
887887
try:
888888
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
889889
from torch_xla.distributed.fsdp import checkpoint_module

tests/onnxruntime/docker/Dockerfile_onnxruntime_trainer

+4-1
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,15 @@ RUN $PYTHON_EXE -m pip install onnx ninja
6565
RUN $PYTHON_EXE -m pip install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION} -f https://download.pytorch.org/whl/${TORCH_CUDA_VERSION}
6666

6767
# ORT Module
68-
RUN $PYTHON_EXE -m pip install onnxruntime-training==1.16.1 -f https://download.onnxruntime.ai/onnxruntime_stable_cu118.html
68+
RUN $PYTHON_EXE -m pip install onnxruntime-training==1.16.3 -f https://download.onnxruntime.ai/onnxruntime_stable_cu118.html
6969
RUN $PYTHON_EXE -m pip install torch-ort
7070
ENV TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX"
7171
RUN $PYTHON_EXE -m pip install --upgrade protobuf==3.20.2
7272
RUN $PYTHON_EXE -m torch_ort.configure
7373

74+
# https://github.com/vllm-project/vllm/issues/1726
75+
RUN pip uninstall nvidia-nccl-cu12 -y
76+
7477
# Install Optimum
7578
COPY . /workspace/optimum
7679
RUN pip install /workspace/optimum[tests]

0 commit comments

Comments
 (0)