diff --git a/examples/nodes/multi_dataset_weighted_sampling.ipynb b/examples/nodes/multi_dataset_weighted_sampling.ipynb new file mode 100644 index 000000000..0a775552c --- /dev/null +++ b/examples/nodes/multi_dataset_weighted_sampling.ipynb @@ -0,0 +1,100 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "79a14c63-a085-493f-8db9-6af3e1d744b5", + "metadata": {}, + "source": [ + "### `MultiNodeWeightedSampler` example\n", + "In this notebook, we will explore the usage of `MultiNodeWeightedSampler` in `torchdata.nodes`.\n", + "\n", + "`MultiNodeWeightedSampler` allows us to sample with a probability from multiple datsets. We will make three datasets, and then see how does the composition of the output depends on the weights defined in the `MultiNodeWeightedSampler`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "0b283748-9b3f-4b9e-bbc5-db0791f4d900", + "metadata": {}, + "outputs": [], + "source": [ + "from torchdata.nodes import Mapper, MultiNodeWeightedSampler, IterableWrapper, Loader\n", + "import collections\n", + "\n", + "# defining a simple map_fn as a place holder example\n", + "def map_fn(item):\n", + " return {\"x\":item}\n", + "\n", + "\n", + "def constant_stream(value: int):\n", + " while True:\n", + " yield value\n", + "\n", + "# First, we create a dictionary of three datasets, with each dataset converted into BaseNode using the IterableWrapper\n", + "num_datasets = 3\n", + "datasets = {\n", + " \"ds0\": IterableWrapper(constant_stream(0)),\n", + " \"ds1\": IterableWrapper(constant_stream(1)),\n", + " \"ds2\": IterableWrapper(constant_stream(2)),\n", + "}\n", + "\n", + "# Next, we have to define weights for sampling from a particular dataset\n", + "weights = {\"ds0\": 0.5, \"ds1\": 0.25, \"ds2\": 0.25}\n", + "\n", + "# Finally we instatiate the MultiNodeWeightedSampler to sample from our datasets\n", + "multi_node_sampler = MultiNodeWeightedSampler(datasets, weights)\n", + "\n", + "# Since nodes are iterators, they need to be manually .reset() between epochs.\n", + "# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n", + "loader = Loader(multi_node_sampler)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "77784ba3-b917-4083-aed4-dba2374110d5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "fractions = {0: 0.49791, 2: 0.25067, 1: 0.25142}\n", + "The original weights were = {'ds0': 0.5, 'ds1': 0.25, 'ds2': 0.25}\n" + ] + } + ], + "source": [ + "# Let's take a look at the output for 100k numbers, compute the fraction of each dataset in that batch\n", + "# and compare the batch composition with our given weights\n", + "n = 100000\n", + "it = iter(loader)\n", + "samples = [next(it) for _ in range(n)]\n", + "fractions = {k: v/len(samples) for k, v in collections.Counter(samples).items()}\n", + "print(f\"fractions = {fractions}\")\n", + "print(f\"The original weights were = {weights}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}