diff --git a/nanofed/data/mnist.py b/nanofed/data/mnist.py index 33a77ba..028d098 100644 --- a/nanofed/data/mnist.py +++ b/nanofed/data/mnist.py @@ -1,7 +1,6 @@ from pathlib import Path import numpy as np -from numpy.typing import NDArray from torch.utils.data import DataLoader, Subset from torchvision import datasets, transforms @@ -29,10 +28,9 @@ def load_mnist_data( if subset_fraction < 1.0: num_samples = int(len(dataset) * subset_fraction) - indices: NDArray[np.int64] = np.random.choice( - len(dataset), num_samples, replace=False - ) - subset_indices: list[int] = indices.tolist() + subset_indices = np.random.choice( + a=len(dataset), size=num_samples, replace=False + ).tolist() dataset = Subset(dataset, subset_indices) return DataLoader(