From 383893e1129299880263b4ec71cdf61d69f33518 Mon Sep 17 00:00:00 2001 From: Ramanish Singh <48493179+ramanishsingh@users.noreply.github.com> Date: Thu, 12 Dec 2024 13:54:00 -0800 Subject: [PATCH] Add a notebook for torchdata.nodes basics (#1393) * initial commit * Delete examples/nodes/.ipynb_checkpoints/hf_datasets_nodes_mnist-checkpoint.ipynb * dlt ckp folder * update names * simplify * update markdowns and fix some language errors * add SamplerWrapper * update with suggestions --- examples/nodes/torchata_nodes_basics.ipynb | 485 +++++++++++++++++++++ 1 file changed, 485 insertions(+) create mode 100644 examples/nodes/torchata_nodes_basics.ipynb diff --git a/examples/nodes/torchata_nodes_basics.ipynb b/examples/nodes/torchata_nodes_basics.ipynb new file mode 100644 index 000000000..97fe01b78 --- /dev/null +++ b/examples/nodes/torchata_nodes_basics.ipynb @@ -0,0 +1,485 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ca0ddf27-b42d-4004-9f0c-4f43e55a245b", + "metadata": {}, + "source": [ + "### `torchdata.nodes` Basics" + ] + }, + { + "cell_type": "markdown", + "id": "f0ff94f5-fc60-4a7e-a3e5-0d7a4bfa2ddd", + "metadata": {}, + "source": [ + "All torchdata.nodes.BaseNode implementations are Iterators, adhering to the following API:\n", + "```Python\n", + "class BaseNode(Iterator[T]):\n", + " def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:\n", + " \"\"\"Resets the node to its initial state or a specified state.\"\"\"\n", + " ...\n", + " def __next__(self) -> T:\n", + " \"\"\"Returns the next value in the sequence.\"\"\"\n", + " ...\n", + " def get_state(self) -> Dict[str, Any]:\n", + " \"\"\"Returns a dictionary representing the current state of the node.\"\"\"\n", + " ...\n", + "```\n", + "This standardized interface enables seamless chaining of iterators, allowing for flexible, efficient, and composable data processing pipelines." + ] + }, + { + "cell_type": "markdown", + "id": "5542e957-f40c-4624-9e3d-b4130d1fd03a", + "metadata": {}, + "source": [ + "Let's see the functionalities of `torchdata.nodes` through the help of a very simple example." + ] + }, + { + "cell_type": "markdown", + "id": "b5168248-7c08-419d-a9a5-f831bd7b46e7", + "metadata": {}, + "source": [ + "#### IterableWrapper" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f42819f7-1019-48ca-ad23-00b38c72d63e", + "metadata": {}, + "outputs": [], + "source": [ + "from torchdata.nodes import IterableWrapper\n", + "# This Wrapper converts any Iterable in to a BaseNode.\n", + "\n", + "dataset = range(10) # creating a very simple dataset, and then converting it into a BaseNode\n", + "source = IterableWrapper(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8bb313fe-5cad-4c5f-b7f3-4bd27b0645a3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n", + "1\n", + "2\n", + "3\n", + "4\n", + "5\n", + "6\n", + "7\n", + "8\n", + "9\n" + ] + } + ], + "source": [ + "# Let's take a look at the items in the node\n", + "for item in source:\n", + " print(item)" + ] + }, + { + "cell_type": "markdown", + "id": "247845d5-ef4b-4bc3-92c3-93c55972578a", + "metadata": {}, + "source": [ + "#### Integrating with torch.data Dataloaders and Samplers" + ] + }, + { + "cell_type": "markdown", + "id": "97fb43c8-16c0-4c87-bb3a-ff6a0e126764", + "metadata": {}, + "source": [ + "We can also use `torch.data.utils` style dataloaders and samplers, and then wrap them into nodes.\n", + "Please refer to this [migration guide](https://pytorch.org/data/main/migrate_to_nodes_from_utils.html) to migrate from torch.utils.data to torchdata.nodes\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a14e03a0-29f7-4824-8860-75f1f2a91533", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6\n", + "9\n", + "2\n", + "1\n", + "7\n", + "5\n", + "4\n", + "3\n", + "8\n", + "0\n" + ] + } + ], + "source": [ + "from torchdata.nodes import MapStyleWrapper\n", + "from torch.utils.data import RandomSampler\n", + "\n", + "sampler = RandomSampler(dataset)\n", + "node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n", + "\n", + "for item in node:\n", + " print(item)" + ] + }, + { + "cell_type": "markdown", + "id": "73f6e1ce-6e33-49cd-b0bb-c49fee0e7cd6", + "metadata": {}, + "source": [ + "#### Map" + ] + }, + { + "cell_type": "markdown", + "id": "2a5246cc-6f40-4942-b514-df76c3946eee", + "metadata": {}, + "source": [ + "We can use the Mapper class, to apply a transformation defined using the `map_fn`" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "71a07a13-5d6f-46d4-b316-102698eb13c5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n", + "1\n", + "4\n", + "9\n", + "16\n", + "25\n", + "36\n", + "49\n", + "64\n", + "81\n" + ] + } + ], + "source": [ + "from torchdata.nodes import Mapper\n", + "node = Mapper(source, map_fn = lambda x : x**2)\n", + "for item in node:\n", + " print(item)" + ] + }, + { + "cell_type": "markdown", + "id": "0701898c-6c72-4065-aee9-23a6b516876d", + "metadata": {}, + "source": [ + "It can also be executed in parallel, using the multi threading/processing approaches, depending on the defined `method`" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0771661d-f0dc-4d35-a5af-abab84d7efd3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n", + "1\n", + "4\n", + "9\n", + "16\n", + "25\n", + "36\n", + "49\n", + "64\n", + "81\n" + ] + } + ], + "source": [ + "from torchdata.nodes import ParallelMapper\n", + "mapper = ParallelMapper(source, map_fn = lambda x : x**2, num_workers =2, method = \"thread\")\n", + "for item in mapper:\n", + " print(item)" + ] + }, + { + "cell_type": "markdown", + "id": "b517ebcf-9f92-43bf-b332-3875933ea244", + "metadata": {}, + "source": [ + "#### Batch" + ] + }, + { + "cell_type": "markdown", + "id": "6156bcb4-891f-484a-bf06-cbf39e282175", + "metadata": {}, + "source": [ + "A BaseNode can be passed into a Batcher, to get batches of size `batch_size`.\n", + "By default, `drop_last` is True, meaning if the last batch has a size smaller than the `batch_size`, it is not produced." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "dc40c637-90ed-4c09-9961-3b0d488fc6d8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0, 1, 2, 3]\n", + "[4, 5, 6, 7]\n" + ] + } + ], + "source": [ + "from torchdata.nodes import Batcher\n", + "batcher = Batcher(source, batch_size = 4)\n", + "for batch in batcher:\n", + " print(batch)" + ] + }, + { + "cell_type": "markdown", + "id": "5487e01d-6245-4940-a0df-c32aee0849cf", + "metadata": {}, + "source": [ + "We can make `drop_last = False` to produce the last batch" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "092d8e5b-aeb2-4de1-9b13-8b0dc0af31ce", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0, 1, 2, 3]\n", + "[4, 5, 6, 7]\n", + "[8, 9]\n" + ] + } + ], + "source": [ + "batcher = Batcher(source, batch_size = 4, drop_last = False)\n", + "for batch in batcher:\n", + " print(batch)" + ] + }, + { + "cell_type": "markdown", + "id": "e11543b9-c2dd-4861-af24-0eb3a1f3ddb9", + "metadata": {}, + "source": [ + "If we try to use this batcher over multiple epochs, we will need to reset it after every epoch" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d89e36a9-cc9c-4b55-acba-dbcaf8d53407", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch = 0 Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n", + "Epoch = 1 Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n" + ] + } + ], + "source": [ + "batcher = Batcher(source, batch_size = 10)\n", + "num_epochs = 2\n", + "\n", + "for epoch in range(num_epochs):\n", + " for batch in batcher:\n", + " print(f\"Epoch = {epoch}\", f\" Batch = {batch}\")\n", + " batcher.reset()\n", + " \n", + "# This is one extra step than traditional dataloader, we can actually wrap the batcher in a Loader to skip that\n", + "# Let's look at Loader in the next cell" + ] + }, + { + "cell_type": "markdown", + "id": "d2232823-0927-432b-9171-fb2bf38a7205", + "metadata": {}, + "source": [ + "#### Loader\n", + "As you can see, we get a batch in every epoch, without even needing to reset the loader!!" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "4f8a1830-49a2-4dfd-b17b-acaca7ac2e98", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch = 0 Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n", + "Epoch = 1 Batch = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n" + ] + } + ], + "source": [ + "from torchdata.nodes import Loader\n", + "batcher = Batcher(source, batch_size = 10)\n", + "loader = Loader(batcher)\n", + "\n", + "for epoch in range(num_epochs):\n", + " for batch in loader:\n", + " print(f\"Epoch = {epoch}\", f\" Batch = {batch}\")" + ] + }, + { + "cell_type": "markdown", + "id": "860142c7-a39b-441b-b2b2-1cbd81a7a0f3", + "metadata": {}, + "source": [ + "#### SamplerWrapper" + ] + }, + { + "cell_type": "markdown", + "id": "18134e7d-d0e8-4096-a674-23e715fa16a3", + "metadata": {}, + "source": [ + "As mentioned earlier, we can use `torch.data.utils` samplers using `MapStyleWrapper`.\n", + "Alternatively, we can employ the `SamplerWrapper`, which converts a `Sampler` into a `BaseNode`. `SamplerWrapper` differs from `IterableWrapper` because it will track the number of epochs, and call the sampler's `set_epoch` method if it is implemented." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f794c4d9-b3ed-4b7b-8c85-131d495e6605", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch = 0 Batch = [2, 9, 0, 5, 8, 6, 7, 3, 1, 4]\n", + "Epoch = 1 Batch = [2, 3, 6, 8, 5, 0, 1, 9, 7, 4]\n" + ] + } + ], + "source": [ + "from torchdata.nodes import SamplerWrapper\n", + "\n", + "sampler = RandomSampler(dataset)\n", + "node = SamplerWrapper(sampler)\n", + "batcher = Batcher(node, batch_size = 10)\n", + "loader = Loader(batcher)\n", + "for epoch in range(num_epochs):\n", + " \n", + " for batch in loader:\n", + " print(f\"Epoch = {node.epoch}\", f\" Batch = {batch}\")" + ] + }, + { + "cell_type": "markdown", + "id": "3e956793-df2d-4d0f-94b0-daa7fb860bf5", + "metadata": {}, + "source": [ + "`torchadata` nodes are composable, thus, many BaseNodes type nodes can be chained together for desired transformations" + ] + }, + { + "cell_type": "markdown", + "id": "5f1a6291-d688-45eb-bf2e-56153419513e", + "metadata": {}, + "source": [ + "#### Chaining multiple operations together\n", + "\n", + "Base nodes are iterators, and are designed to be chained together to create more complex dataloading graphs." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "cc4bcd30-9a76-4a1e-aa38-d7db8b9a8c2e", + "metadata": {}, + "outputs": [], + "source": [ + "sampler = RandomSampler(dataset)\n", + "node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n", + "node = Mapper(node, map_fn = lambda x : x**3)\n", + "node = Batcher(node, batch_size = 4, drop_last = False)\n", + "loader = Loader(node)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "798f64af-102d-4b24-8327-b47d4d0fc4f2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[64, 216, 27, 729]\n", + "[0, 8, 343, 1]\n", + "[512, 125]\n" + ] + } + ], + "source": [ + "for batch in loader:\n", + " print(batch)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}