|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "1e973d1b-c6d0-48a5-a774-0f114101e81e", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# Getting started with PyTorch on Intel® Gaudi.\n", |
| 9 | + "\n", |
| 10 | + "This notebook helps you get started quickly with the Intel® Gaudi accelerator in this container. A simple MNIST model is trained on the Gaudi accelerator. You can tune some of the parameters below to change the training configuration. For more information, refer to the official documentation for the [Intel® Gaudi accelerator](https://docs.habana.ai/en/latest/index.html)." |
| 11 | + ] |
| 12 | + }, |
| 13 | + { |
| 14 | + "cell_type": "markdown", |
| 15 | + "id": "7eaacf55-bea2-43be-bb48-163848db1a30", |
| 16 | + "metadata": { |
| 17 | + "tags": [] |
| 18 | + }, |
| 19 | + "source": [ |
| 20 | + "### Setup modes for training\n", |
| 21 | + "\n", |
| 22 | + "1. lazy_mode: Set to True to enable lazy mode, or to False to disable it (the model is then wrapped with torch.compile using the hpu_backend backend).\n", |
| 23 | + "2. enable_amp: Set to True to enable Automatic Mixed Precision (bfloat16 autocast), or to False to disable it.\n", |
| 24 | + "3. epochs: Number of epochs for training\n", |
| 25 | + "4. lr: Learning rate for training\n", |
| 26 | + "5. batch_size: Number of samples in a batch\n", |
| 27 | + "6. milestones: Epochs at which the MultiStepLR scheduler multiplies the learning rate by 0.1." |
| 28 | + ] |
| 29 | + }, |
| 30 | + { |
| 31 | + "cell_type": "code", |
| 32 | + "execution_count": null, |
| 33 | + "id": "5e7cf831-6fe6-46ed-a6fd-f2651cc226af", |
| 34 | + "metadata": { |
| 35 | + "tags": [] |
| 36 | + }, |
| 37 | + "outputs": [], |
| 38 | + "source": [ |
| 39 | + "lazy_mode = False\n", |
| 40 | + "enable_amp = False\n", |
| 41 | + "epochs = 20\n", |
| 42 | + "batch_size = 128\n", |
| 43 | + "lr = 0.01\n", |
| 44 | + "milestones = [10,15]" |
| 45 | + ] |
| 46 | + }, |
| 47 | + { |
| 48 | + "cell_type": "code", |
| 49 | + "execution_count": null, |
| 50 | + "id": "cee8ad90-c52d-4a50-876f-ce0762cb1b62", |
| 51 | + "metadata": { |
| 52 | + "tags": [] |
| 53 | + }, |
| 54 | + "outputs": [], |
| 55 | + "source": [ |
| 56 | + "import os\n", |
| 57 | + "os.environ['HABANA_LOGS']='/opt/app-root/logs'\n", |
| 58 | + "if lazy_mode:\n", |
| 59 | + " os.environ['PT_HPU_LAZY_MODE'] = '1'\n", |
| 60 | + "else:\n", |
| 61 | + " os.environ['PT_HPU_LAZY_MODE'] = '0'" |
| 62 | + ] |
| 63 | + }, |
| 64 | + { |
| 65 | + "cell_type": "markdown", |
| 66 | + "id": "6eac33d0-2e64-4233-8b3f-40bb7217fef8", |
| 67 | + "metadata": { |
| 68 | + "tags": [] |
| 69 | + }, |
| 70 | + "source": [ |
| 71 | + "### Import packages" |
| 72 | + ] |
| 73 | + }, |
| 74 | + { |
| 75 | + "cell_type": "code", |
| 76 | + "execution_count": null, |
| 77 | + "id": "06ad44ff-9744-4d6f-af90-375e64717b59", |
| 78 | + "metadata": {}, |
| 79 | + "outputs": [], |
| 80 | + "source": [ |
| 81 | + "import torch\n", |
| 82 | + "import torch.nn as nn\n", |
| 83 | + "import torch.optim as optim\n", |
| 84 | + "import torch.nn.functional as F\n", |
| 85 | + "import torchvision\n", |
| 86 | + "import torchvision.transforms as transforms\n", |
| 87 | + "import os\n", |
| 88 | + "\n", |
| 89 | + "# Import Habana Torch Library\n", |
| 90 | + "import habana_frameworks.torch.core as htcore" |
| 91 | + ] |
| 92 | + }, |
| 93 | + { |
| 94 | + "cell_type": "markdown", |
| 95 | + "id": "062de7f3-4561-4af3-a9ed-2c4cfc918f2f", |
| 96 | + "metadata": {}, |
| 97 | + "source": [ |
| 98 | + "### Define Model" |
| 99 | + ] |
| 100 | + }, |
| 101 | + { |
| 102 | + "cell_type": "code", |
| 103 | + "execution_count": null, |
| 104 | + "id": "9df57abb-0b63-4e1c-9d9b-87e74964300e", |
| 105 | + "metadata": {}, |
| 106 | + "outputs": [], |
| 107 | + "source": [ |
| 108 | + "class SimpleModel(nn.Module):\n", |
| 109 | + " def __init__(self):\n", |
| 110 | + " super(SimpleModel, self).__init__()\n", |
| 111 | + "\n", |
| 112 | + " self.fc1 = nn.Linear(784, 256)\n", |
| 113 | + " self.fc2 = nn.Linear(256, 64)\n", |
| 114 | + " self.fc3 = nn.Linear(64, 10)\n", |
| 115 | + "\n", |
| 116 | + " def forward(self, x):\n", |
| 117 | + "\n", |
| 118 | + " out = x.view(-1,28*28)\n", |
| 119 | + " out = F.relu(self.fc1(out))\n", |
| 120 | + " out = F.relu(self.fc2(out))\n", |
| 121 | + " out = self.fc3(out)\n", |
| 122 | + "\n", |
| 123 | + " return out" |
| 124 | + ] |
| 125 | + }, |
| 126 | + { |
| 127 | + "cell_type": "markdown", |
| 128 | + "id": "d899885b-5b4d-4557-a90c-9d507875c2ee", |
| 129 | + "metadata": {}, |
| 130 | + "source": [ |
| 131 | + "### Define training routine" |
| 132 | + ] |
| 133 | + }, |
| 134 | + { |
| 135 | + "cell_type": "code", |
| 136 | + "execution_count": null, |
| 137 | + "id": "7b17e9aa-fa11-4870-a7d4-183b803177ab", |
| 138 | + "metadata": {}, |
| 139 | + "outputs": [], |
| 140 | + "source": [ |
| 141 | + "def train(net,criterion,optimizer,trainloader,device):\n", |
| 142 | + "\n", |
| 143 | + " net.train()\n", |
| 144 | + " if not lazy_mode:\n", |
| 145 | + " net = torch.compile(net,backend=\"hpu_backend\")\n", |
| 146 | + " train_loss = 0.0\n", |
| 147 | + " correct = 0\n", |
| 148 | + " total = 0\n", |
| 149 | + "\n", |
| 150 | + " for batch_idx, (data, targets) in enumerate(trainloader):\n", |
| 151 | + "\n", |
| 152 | + " data, targets = data.to(device), targets.to(device)\n", |
| 153 | + "\n", |
| 154 | + " optimizer.zero_grad()\n", |
| 155 | + " if enable_amp:\n", |
| 156 | + " with torch.autocast(device_type=\"hpu\", dtype=torch.bfloat16):\n", |
| 157 | + " outputs = net(data)\n", |
| 158 | + " loss = criterion(outputs, targets)\n", |
| 159 | + " else:\n", |
| 160 | + " outputs = net(data)\n", |
| 161 | + " loss = criterion(outputs, targets)\n", |
| 162 | + "\n", |
| 163 | + " loss.backward()\n", |
| 164 | + " \n", |
| 165 | + " # API call to trigger execution\n", |
| 166 | + " if lazy_mode:\n", |
| 167 | + " htcore.mark_step()\n", |
| 168 | + " \n", |
| 169 | + " optimizer.step()\n", |
| 170 | + "\n", |
| 171 | + " # API call to trigger execution\n", |
| 172 | + " if lazy_mode:\n", |
| 173 | + " htcore.mark_step()\n", |
| 174 | + "\n", |
| 175 | + " train_loss += loss.item()\n", |
| 176 | + " _, predicted = outputs.max(1)\n", |
| 177 | + " total += targets.size(0)\n", |
| 178 | + " correct += predicted.eq(targets).sum().item()\n", |
| 179 | + "\n", |
| 180 | + " train_loss = train_loss/(batch_idx+1)\n", |
| 181 | + " train_acc = 100.0*(correct/total)\n", |
| 182 | + " print(\"Training loss is {} and training accuracy is {}\".format(train_loss,train_acc))" |
| 183 | + ] |
| 184 | + }, |
| 185 | + { |
| 186 | + "cell_type": "markdown", |
| 187 | + "id": "b7a22d69-a91f-48e1-8fac-e1cfe68590b7", |
| 188 | + "metadata": {}, |
| 189 | + "source": [ |
| 190 | + "### Define testing routine" |
| 191 | + ] |
| 192 | + }, |
| 193 | + { |
| 194 | + "cell_type": "code", |
| 195 | + "execution_count": null, |
| 196 | + "id": "f9aa379b-b376-4623-9b5c-f778c3d90ce7", |
| 197 | + "metadata": {}, |
| 198 | + "outputs": [], |
| 199 | + "source": [ |
| 200 | + "def test(net,criterion,testloader,device):\n", |
| 201 | + "\n", |
| 202 | + " net.eval()\n", |
| 203 | + " test_loss = 0\n", |
| 204 | + " correct = 0\n", |
| 205 | + " total = 0\n", |
| 206 | + "\n", |
| 207 | + " with torch.no_grad():\n", |
| 208 | + "\n", |
| 209 | + " for batch_idx, (data, targets) in enumerate(testloader):\n", |
| 210 | + "\n", |
| 211 | + " data, targets = data.to(device), targets.to(device)\n", |
| 212 | + " \n", |
| 213 | + " if enable_amp:\n", |
| 214 | + " with torch.autocast(device_type=\"hpu\", dtype=torch.bfloat16):\n", |
| 215 | + " outputs = net(data)\n", |
| 216 | + " loss = criterion(outputs, targets)\n", |
| 217 | + " else:\n", |
| 218 | + " outputs = net(data)\n", |
| 219 | + " loss = criterion(outputs, targets)\n", |
| 220 | + "\n", |
| 221 | + "\n", |
| 222 | + " # API call to trigger execution\n", |
| 223 | + " if lazy_mode:\n", |
| 224 | + " htcore.mark_step()\n", |
| 225 | + "\n", |
| 226 | + " test_loss += loss.item()\n", |
| 227 | + " _, predicted = outputs.max(1)\n", |
| 228 | + " total += targets.size(0)\n", |
| 229 | + " correct += predicted.eq(targets).sum().item()\n", |
| 230 | + "\n", |
| 231 | + " test_loss = test_loss/(batch_idx+1)\n", |
| 232 | + " test_acc = 100.0*(correct/total)\n", |
| 233 | + " print(\"Testing loss is {} and testing accuracy is {}\".format(test_loss,test_acc))" |
| 234 | + ] |
| 235 | + }, |
| 236 | + { |
| 237 | + "cell_type": "markdown", |
| 238 | + "id": "22e76af9-e355-4299-b84d-f34c9a25e76d", |
| 239 | + "metadata": {}, |
| 240 | + "source": [ |
| 241 | + "### Run the main routine to train and test the model" |
| 242 | + ] |
| 243 | + }, |
| 244 | + { |
| 245 | + "cell_type": "code", |
| 246 | + "execution_count": null, |
| 247 | + "id": "1c8ddfb1-d4f7-44b2-aff0-f86f1db8c971", |
| 248 | + "metadata": {}, |
| 249 | + "outputs": [], |
| 250 | + "source": [ |
| 251 | + "load_path = './data'\n", |
| 252 | + "save_path = './checkpoints'\n", |
| 253 | + "\n", |
| 254 | + "if(not os.path.exists(save_path)):\n", |
| 255 | + " os.makedirs(save_path)\n", |
| 256 | + "\n", |
| 257 | + "# Target the Gaudi HPU device\n", |
| 258 | + "device = torch.device(\"hpu\")\n", |
| 259 | + "\n", |
| 260 | + "# Data\n", |
| 261 | + "transform = transforms.Compose([\n", |
| 262 | + " transforms.ToTensor(),\n", |
| 263 | + "])\n", |
| 264 | + "\n", |
| 265 | + "trainset = torchvision.datasets.MNIST(root=load_path, train=True,\n", |
| 266 | + " download=True, transform=transform)\n", |
| 267 | + "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n", |
| 268 | + " shuffle=True, num_workers=2)\n", |
| 269 | + "testset = torchvision.datasets.MNIST(root=load_path, train=False,\n", |
| 270 | + " download=True, transform=transform)\n", |
| 271 | + "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n", |
| 272 | + " shuffle=False, num_workers=2)\n", |
| 273 | + "\n", |
| 274 | + "net = SimpleModel()\n", |
| 275 | + "net.to(device)\n", |
| 276 | + "\n", |
| 277 | + "criterion = nn.CrossEntropyLoss()\n", |
| 278 | + "optimizer = optim.SGD(net.parameters(), lr=lr,\n", |
| 279 | + " momentum=0.9, weight_decay=5e-4)\n", |
| 280 | + "scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)\n", |
| 281 | + "\n", |
| 282 | + "for epoch in range(1, epochs+1):\n", |
| 283 | + " print(\"=====================================================================\")\n", |
| 284 | + " print(\"Epoch : {}\".format(epoch))\n", |
| 285 | + " train(net,criterion,optimizer,trainloader,device)\n", |
| 286 | + " test(net,criterion,testloader,device)\n", |
| 287 | + "\n", |
| 288 | + " torch.save(net.state_dict(), os.path.join(save_path,'epoch_{}.pth'.format(epoch)))\n", |
| 289 | + "\n", |
| 290 | + " scheduler.step()" |
| 291 | + ] |
| 292 | + } |
| 293 | + ], |
| 294 | + "metadata": { |
| 295 | + "kernelspec": { |
| 296 | + "display_name": "Python 3.10", |
| 297 | + "language": "python", |
| 298 | + "name": "python3" |
| 299 | + }, |
| 300 | + "language_info": { |
| 301 | + "codemirror_mode": { |
| 302 | + "name": "ipython", |
| 303 | + "version": 3 |
| 304 | + }, |
| 305 | + "file_extension": ".py", |
| 306 | + "mimetype": "text/x-python", |
| 307 | + "name": "python", |
| 308 | + "nbconvert_exporter": "python", |
| 309 | + "pygments_lexer": "ipython3", |
| 310 | + "version": "3.10.14" |
| 311 | + } |
| 312 | + }, |
| 313 | + "nbformat": 4, |
| 314 | + "nbformat_minor": 5 |
| 315 | +} |
0 commit comments