From 55a7c089c13a015d0adbfd6d64be7a31fc85d32c Mon Sep 17 00:00:00 2001 From: camille-004 Date: Sun, 8 Dec 2024 09:10:49 -0800 Subject: [PATCH] fix: mypy issue --- nanofed/data/mnist.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/nanofed/data/mnist.py b/nanofed/data/mnist.py index 33a77ba..028d098 100644 --- a/nanofed/data/mnist.py +++ b/nanofed/data/mnist.py @@ -1,7 +1,6 @@ from pathlib import Path import numpy as np -from numpy.typing import NDArray from torch.utils.data import DataLoader, Subset from torchvision import datasets, transforms @@ -29,10 +28,9 @@ def load_mnist_data( if subset_fraction < 1.0: num_samples = int(len(dataset) * subset_fraction) - indices: NDArray[np.int64] = np.random.choice( - len(dataset), num_samples, replace=False - ) - subset_indices: list[int] = indices.tolist() + subset_indices = np.random.choice( + a=len(dataset), size=num_samples, replace=False + ).tolist() dataset = Subset(dataset, subset_indices) return DataLoader(