Skip to content

Commit

Permalink
Add an example for MultiNodeWeightedSampler (#1385)
Browse files Browse the repository at this point in the history
* initial commit

* update typos

* simplify

* include GH suggestions

* simplify

* increase n

* apply Andrew's suggestions
  • Loading branch information
ramanishsingh authored Dec 12, 2024
1 parent c0edd90 commit 9b1127a
Showing 1 changed file with 100 additions and 0 deletions.
100 changes: 100 additions & 0 deletions examples/nodes/multi_dataset_weighted_sampling.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "79a14c63-a085-493f-8db9-6af3e1d744b5",
"metadata": {},
"source": [
"### `MultiNodeWeightedSampler` example\n",
"In this notebook, we will explore the usage of `MultiNodeWeightedSampler` in `torchdata.nodes`.\n",
"\n",
"`MultiNodeWeightedSampler` allows us to sample with a probability from multiple datsets. We will make three datasets, and then see how does the composition of the output depends on the weights defined in the `MultiNodeWeightedSampler`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0b283748-9b3f-4b9e-bbc5-db0791f4d900",
"metadata": {},
"outputs": [],
"source": [
"from torchdata.nodes import Mapper, MultiNodeWeightedSampler, IterableWrapper, Loader\n",
"import collections\n",
"\n",
"# defining a simple map_fn as a place holder example\n",
"def map_fn(item):\n",
" return {\"x\":item}\n",
"\n",
"\n",
"def constant_stream(value: int):\n",
" while True:\n",
" yield value\n",
"\n",
"# First, we create a dictionary of three datasets, with each dataset converted into BaseNode using the IterableWrapper\n",
"num_datasets = 3\n",
"datasets = {\n",
" \"ds0\": IterableWrapper(constant_stream(0)),\n",
" \"ds1\": IterableWrapper(constant_stream(1)),\n",
" \"ds2\": IterableWrapper(constant_stream(2)),\n",
"}\n",
"\n",
"# Next, we have to define weights for sampling from a particular dataset\n",
"weights = {\"ds0\": 0.5, \"ds1\": 0.25, \"ds2\": 0.25}\n",
"\n",
"# Finally we instatiate the MultiNodeWeightedSampler to sample from our datasets\n",
"multi_node_sampler = MultiNodeWeightedSampler(datasets, weights)\n",
"\n",
"# Since nodes are iterators, they need to be manually .reset() between epochs.\n",
"# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n",
"loader = Loader(multi_node_sampler)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "77784ba3-b917-4083-aed4-dba2374110d5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fractions = {0: 0.49791, 2: 0.25067, 1: 0.25142}\n",
"The original weights were = {'ds0': 0.5, 'ds1': 0.25, 'ds2': 0.25}\n"
]
}
],
"source": [
"# Let's take a look at the output for 100k numbers, compute the fraction of each dataset in that batch\n",
"# and compare the batch composition with our given weights\n",
"n = 100000\n",
"it = iter(loader)\n",
"samples = [next(it) for _ in range(n)]\n",
"fractions = {k: v/len(samples) for k, v in collections.Counter(samples).items()}\n",
"print(f\"fractions = {fractions}\")\n",
"print(f\"The original weights were = {weights}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 9b1127a

Please sign in to comment.