From 23289fcdcc19d50acc71a00578cc7bbae54e2d2d Mon Sep 17 00:00:00 2001
From: Ramanish Singh <48493179+ramanishsingh@users.noreply.github.com>
Date: Thu, 12 Dec 2024 11:45:39 -0800
Subject: [PATCH] Add examples for HF datasets (#1371)

* intial_commit
* more examples
* run precommit
* hf mnist
* add complete MNIST example
* remove old py file
* add imdb bert
* update mnist example
* update hf_bert example
* add mds example
* update bert example
* update function name
* run precommit
* update bert example
* update mnist notebook
* update mds
* delete ipynb ckpts
* remove mds and simplify examples
* fix some typos
* simplify and remove test train mentiond
* remove headings
* add titles
* fix typo
* update url
---
 examples/nodes/hf_datasets_nodes_mnist.ipynb | 182 +++++++++++++++++++
 examples/nodes/hf_imdb_bert.ipynb            | 161 ++++++++++++++++
 2 files changed, 343 insertions(+)
 create mode 100644 examples/nodes/hf_datasets_nodes_mnist.ipynb
 create mode 100644 examples/nodes/hf_imdb_bert.ipynb

diff --git a/examples/nodes/hf_datasets_nodes_mnist.ipynb b/examples/nodes/hf_datasets_nodes_mnist.ipynb
new file mode 100644
index 000000000..92a478ad8
--- /dev/null
+++ b/examples/nodes/hf_datasets_nodes_mnist.ipynb
@@ -0,0 +1,182 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "1cb04783-4c83-43bc-91c6-b980d836de34",
+   "metadata": {},
+   "source": [
+    "### Loading and processing the MNIST dataset\n",
+    "In this example, we will load the MNIST dataset from Hugging Face and\n",
+    "use `torchdata.nodes` to process it and generate training batches."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "93026478-6dbd-4ac0-8507-360a3a2000c5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import load_dataset\n",
+    "\n",
+    "# Load the MNIST dataset from Hugging Face datasets and convert the format to \"torch\"\n",
+    "dataset = load_dataset(\"ylecun/mnist\").with_format(\"torch\")\n",
+    "\n",
+    "# Get the \"train\" split\n",
+    "dataset = dataset[\"train\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "b0c46c4b-0194-4127-a218-e24ec54a3149",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch.utils.data import default_collate, RandomSampler\n",
+    "\n",
+    "torch.manual_seed(42)\n",
+    "\n",
+    "# Define a sampler\n",
+    "# Since the dataset is a map-style dataset, we can set up a sampler to shuffle the data\n",
+    "sampler = RandomSampler(dataset)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "6f643c48-c6fb-4e8a-9461-fdf96b45b04b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Now we can set up some torchdata.nodes to create our pre-processing pipeline\n",
+    "from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n",
+    "\n",
+    "# All torchdata.nodes.BaseNode implementations are Iterators.\n",
+    "# MapStyleWrapper combines a sampler and a map-style dataset into such an Iterator.\n",
+    "#\n",
+    "# Under the hood, MapStyleWrapper just does:\n",
+    "# > node = IterableWrapper(sampler)\n",
+    "# > node = Mapper(node, map_fn=dataset.__getitem__)  # You can parallelize this with ParallelMapper\n",
+    "\n",
+    "node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n",
+    "\n",
+    "# Now we want to transform the raw inputs. We can just use another Mapper with\n",
+    "# a custom map_fn to perform this.\n",
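+    "# For reference, each item produced by the node above is a dict of tensors,\n",
+    "# e.g. {\"image\": uint8 image tensor, \"label\": integer label tensor} (this follows\n",
+    "# from the \"torch\" format requested when loading the dataset). map_fn below\n",
+    "# receives one such item at a time.\n",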
+    "# Using ParallelMapper allows us to use multiple\n",
+    "# threads (or processes) to parallelize this work and have it run in the background.\n",
+    "# We need a map_fn that converts the image dtype and normalizes pixel values to [0, 1]\n",
+    "def map_fn(item):\n",
+    "    image = item[\"image\"].to(torch.float32) / 255\n",
+    "    label = item[\"label\"]\n",
+    "\n",
+    "    return {\"image\": image, \"label\": label}\n",
+    "\n",
+    "node = ParallelMapper(node, map_fn=map_fn, num_workers=2)  # output items are Dict[str, tensor]\n",
+    "\n",
+    "# Hyperparameters\n",
+    "batch_size = 2\n",
+    "\n",
+    "# Next we batch the inputs, and then apply a collate_fn with another Mapper\n",
+    "# to stack the tensors. We use torch.utils.data.default_collate for this\n",
+    "node = Batcher(node, batch_size=batch_size)  # output items are List[Dict[str, tensor]]\n",
+    "node = ParallelMapper(node, map_fn=default_collate, num_workers=2)  # outputs are Dict[str, tensor]\n",
+    "\n",
+    "# We can optionally apply pin_memory to the batches\n",
+    "if torch.cuda.is_available():\n",
+    "    node = PinMemory(node)\n",
+    "\n",
+    "# Since nodes are iterators, they need to be manually .reset() between epochs.\n",
+    "# Instead, we can wrap the root node in Loader to convert it to a more conventional Iterable.\n",
+    "loader = Loader(node)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "c97a79ba-e6b3-4ac7-a4c5-edc8f9c58ff4",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'image': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],\n",
+      "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
+      "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
+      "          ...,\n",
+      "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
+      "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
+      "          [0., 0., 0.,  ..., 0., 0., 0.]]],\n",
+      "\n",
+      "\n",
+      "        [[[0., 0., 0.,  ..., 0., 0., 0.],\n",
+      "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
+      "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
+      "          ...,\n",
+      "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
+      "          [0., 0., 0.,  ..., 0., 0., 0.],\n",
+      "          [0., 0., 0.,  ..., 0., 0., 0.]]]]), 'label': tensor([1, 4])}\n",
+      "There are 2 samples in this batch\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": "<base64-encoded PNG of the two sample digits omitted for brevity>",
+      "text/plain": [
+       "<Figure size 800x600 with 2 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Once we have the loader, we can get batches from it over multiple epochs to train a model.\n",
+    "# Let us look at one batch\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "fig, axs = plt.subplots(2, figsize=(8, 6))\n",
+    "\n",
+    "batch = next(iter(loader))\n",
+    "\n",
+    "print(batch)\n",
+    "print(f\"There are {len(batch['image'])} samples in this batch\")\n",
+    "\n",
+    "# Since we used default_collate, each batch is a dictionary with two keys: \"image\" and \"label\".\n",
+    "# The value of key \"image\" is a stacked tensor of the images in the batch;\n",
+    "# similarly, the value of key \"label\" is a stacked tensor of the labels in the batch.\n",
+    "images = batch[\"image\"]\n",
+    "labels = batch[\"label\"]\n",
+    "\n",
+    "# Let's also display the two items\n",
+    "for i in range(len(images)):\n",
+    "    axs[i].imshow(images[i].squeeze(), cmap='gray')\n",
+    "    axs[i].set_title(f\"Label: {labels[i]}\")\n",
+    "    axs[i].set_axis_off()\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/nodes/hf_imdb_bert.ipynb b/examples/nodes/hf_imdb_bert.ipynb
new file mode 100644
index 000000000..dd13cad59
--- /dev/null
+++ b/examples/nodes/hf_imdb_bert.ipynb
@@ -0,0 +1,161 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "d8513771-36ac-4d03-b890-35108bce2211",
+   "metadata": {},
+   "source": [
+    "### Loading and processing the IMDB movie review dataset\n",
+    "In this example, we will load the IMDB dataset from Hugging Face and\n",
+    "use `torchdata.nodes` to process it and generate training batches."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "eb3b507c-2ad1-410d-a834-6847182de684",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import load_dataset\n",
+    "from transformers import BertTokenizer, BertForSequenceClassification"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "089f1126-7125-4274-9d71-5c949ccc7bbd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch.utils.data import default_collate, RandomSampler"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "2afac7d9-3d66-4195-8647-dc7034d306f2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load the IMDB dataset from Hugging Face datasets and select the \"train\" split\n",
+    "dataset = load_dataset(\"imdb\", streaming=False)\n",
+    "dataset = dataset[\"train\"]\n",
+    "# Since the dataset is a map-style dataset, we can set up a sampler to shuffle the data.\n",
+    "# Please refer to the migration guide at https://pytorch.org/data/main/migrate_to_nodes_from_utils.html\n",
+    "# to migrate from torch.utils.data to torchdata.nodes\n",
+    "\n",
+    "sampler = RandomSampler(dataset)\n",
+    "# Use a standard BERT tokenizer\n",
+    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
+    "# Now we can set up some torchdata.nodes to create our pre-processing pipeline"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "09e08a47-573c-4d32-9a02-36cd8150db60",
+   "metadata": {},
+   "source": [
+    "All torchdata.nodes.BaseNode implementations are Iterators.\n",
+    "MapStyleWrapper combines a sampler and a map-style dataset into such an Iterator.\n",
+    "Under the hood, MapStyleWrapper just does:\n",
+    "```python\n",
+    "node = IterableWrapper(sampler)\n",
+    "node = Mapper(node, map_fn=dataset.__getitem__)  # You can parallelize this with ParallelMapper\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "02af5479-ee69-41d8-ab2d-bf154b84bc15",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n",
+    "\n",
+    "node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n",
+    "\n",
+    "# Now we want to transform the raw inputs. We can just use another Mapper with\n",
+    "# a custom map_fn to perform this. Using ParallelMapper allows us to use multiple\n",
+    "# threads (or processes) to parallelize this work and have it run in the background.\n",
+    "max_len = 512\n",
+    "batch_size = 2\n",
+    "\n",
+    "def bert_transform(item):\n",
+    "    encoding = tokenizer.encode_plus(\n",
+    "        item[\"text\"],\n",
+    "        add_special_tokens=True,\n",
+    "        max_length=max_len,\n",
+    "        padding=\"max_length\",\n",
+    "        truncation=True,\n",
+    "        return_attention_mask=True,\n",
+    "        return_tensors=\"pt\",\n",
+    "    )\n",
+    "    return {\n",
+    "        \"input_ids\": encoding[\"input_ids\"].flatten(),\n",
+    "        \"attention_mask\": encoding[\"attention_mask\"].flatten(),\n",
+    "        \"labels\": torch.tensor(item[\"label\"], dtype=torch.long),\n",
+    "    }\n",
+    "\n",
+    "node = ParallelMapper(node, map_fn=bert_transform, num_workers=2)  # output items are Dict[str, tensor]\n",
+    "\n",
+    "# Next we batch the inputs, and then apply a collate_fn with another Mapper\n",
+    "# to stack the tensors. We use torch.utils.data.default_collate for this.\n",
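+    "# (default_collate turns a list of dicts into a single dict of stacked tensors, i.e.\n",
+    "# [{\"input_ids\": a, ...}, {\"input_ids\": b, ...}] -> {\"input_ids\": torch.stack([a, b]), ...})\n",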
+    "node = Batcher(node, batch_size=batch_size)  # output items are List[Dict[str, tensor]]\n",
+    "node = ParallelMapper(node, map_fn=default_collate, num_workers=2)  # outputs are Dict[str, tensor]\n",
+    "\n",
+    "# We can optionally apply pin_memory to the batches\n",
+    "if torch.cuda.is_available():\n",
+    "    node = PinMemory(node)\n",
+    "\n",
+    "# Since nodes are iterators, they need to be manually .reset() between epochs.\n",
+    "# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n",
+    "loader = Loader(node)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "60fd54f3-62ef-47aa-a790-853cb4899f13",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'input_ids': tensor([[ 101, 1045, 2572,  ..., 2143, 2000,  102],\n",
+      "        [ 101, 2004, 1037,  ...,    0,    0,    0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],\n",
+      "        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([0, 1])}\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Inspect a batch\n",
+    "batch = next(iter(loader))\n",
+    "print(batch)\n",
+    "# In a batch we get the three keys defined in the function `bert_transform`.\n",
+    "# Since the batch size is 2, two samples are stacked together for each key."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}