"""Helpers for retrieving standard Torchvision datasets (MNIST, CIFAR10,
CIFAR100) as DataLoaders, plus the denormalization step required before
graphical representation of the image tensors."""
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, ToTensor, Grayscale, Normalize, RandomResizedCrop
from torchvision.datasets import MNIST, CIFAR10, CIFAR100
import torch
from typing import Tuple
import os
# Root directory where Torchvision downloads/caches datasets; must be
# provided via the environment before importing this module.
DEFAULT_DATA_DIR = os.getenv("DEFAULT_DATA_DIR", None)
if DEFAULT_DATA_DIR is None:
    # An explicit raise instead of `assert`: asserts are stripped under `-O`,
    # which would let the module load with a None data directory.
    raise RuntimeError("Environment variable not set for < DEFAULT_DATA_DIR >")
def denorm(x_norm: torch.Tensor, dataset_type: str) -> torch.Tensor:
    """Undo the per-channel normalization applied by the dataset transforms.

    Args:
        x_norm: Normalized image tensor of shape (C, H, W) or (N, C, H, W);
            C must match the channel count of ``dataset_type``.
        dataset_type: One of ``"mnist"``, ``"cifar10"``, ``"cifar100"``.

    Returns:
        A tensor of the same shape with values clamped to [0, 1], suitable
        for plotting.

    Raises:
        ValueError: If ``dataset_type`` is not recognized.
    """
    if dataset_type == "mnist":
        mean = (0.1307,)
        std = (0.3081,)
    elif dataset_type == "cifar10":
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif dataset_type == "cifar100":
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    else:
        raise ValueError("Invalid denorm attributes")
    # Build the statistics on the input's device/dtype so denormalization
    # also works for CUDA and non-float32 tensors (original created CPU
    # float32 tensors, which fails for GPU inputs).
    # view(-1, 1, 1) broadcasts over trailing H/W dims for both (C, H, W)
    # and (N, C, H, W) inputs.
    mean_tensor = torch.tensor(mean, device=x_norm.device, dtype=x_norm.dtype).view(-1, 1, 1)
    std_tensor = torch.tensor(std, device=x_norm.device, dtype=x_norm.dtype).view(-1, 1, 1)
    return (x_norm * std_tensor + mean_tensor).clamp(0.0, 1.0)
def get_dataloader(dataset_type: str, batch_size: int, resize_to: int = -1, force_3_channels: bool = False) -> Tuple[DataLoader, DataLoader, int, int, int]:
    """Dispatch to the dataset-specific loader factory.

    Args:
        dataset_type: One of ``"mnist"``, ``"cifar10"``, ``"cifar100"``.
        batch_size: Batch size for both returned DataLoaders.
        resize_to: If > 0, images are random-resized-cropped to this size.
        force_3_channels: Only honored for MNIST (CIFAR is already 3-channel).

    Returns:
        ``(train_loader, test_loader, img_size, in_channels, num_classes)``.

    Raises:
        ValueError: If ``dataset_type`` is not recognized.
    """
    if dataset_type == "mnist":
        return get_mnist(batch_size, resize_to=resize_to, force_3_channels=force_3_channels)
    if dataset_type == "cifar10":
        return get_cifar_10(batch_size, resize_to=resize_to)
    if dataset_type == "cifar100":
        return get_cifar_100(batch_size, resize_to=resize_to)
    raise ValueError("Incorrect dataset type")
def get_mnist(batch_size: int, resize_to: int = -1, force_3_channels: bool = False) -> Tuple[DataLoader, DataLoader, int, int, int]:
    """Build MNIST train/test DataLoaders with normalization applied.

    Returns a tuple with:
        - A dataloader for the training data
        - A dataloader for the test/val data
        - Image width (images are square, so height == width in pixels)
        - Number of input channels (`in_channels`)
        - Number of classification labels (`num_classes`)
    """
    transform_steps = []
    if force_3_channels:
        # Replicate the single grayscale channel to 3 (applied on the PIL image).
        transform_steps.append(Grayscale(num_output_channels=3))
    transform_steps.append(ToTensor())
    transform_steps.append(Normalize(mean=(0.1307,), std=(0.3081,)))
    num_channels = 3 if force_3_channels else 1
    img_size = 28  # native MNIST resolution (C, H, W)
    if resize_to > 0:
        # NOTE(review): RandomResizedCrop is applied to the test set as well,
        # which makes evaluation stochastic — confirm this is intended.
        transform_steps.append(RandomResizedCrop(resize_to))
        img_size = resize_to
    tf = Compose(transform_steps)
    train_set = MNIST(root=DEFAULT_DATA_DIR, train=True,
                      transform=tf, download=True)
    test_set = MNIST(root=DEFAULT_DATA_DIR, train=False,
                     transform=tf, download=True)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, img_size, num_channels, 10
def get_cifar_10(batch_size: int, resize_to: int = -1) -> Tuple[DataLoader, DataLoader, int, int, int]:
    """Build CIFAR10 train/test DataLoaders with normalization applied.

    Returns a tuple with:
        - A dataloader for the training data
        - A dataloader for the test/val data
        - Image width (images are square, so height == width in pixels)
        - Number of input channels (`in_channels`)
        - Number of classification labels (`num_classes`)
    """
    transform_steps = [
        ToTensor(),
        Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)),
    ]
    img_size = 32  # native CIFAR resolution (C, H, W)
    if resize_to > 0:
        # NOTE(review): RandomResizedCrop is applied to the test set as well,
        # which makes evaluation stochastic — confirm this is intended.
        transform_steps.append(RandomResizedCrop(resize_to))
        img_size = resize_to
    tf = Compose(transform_steps)
    train_set = CIFAR10(root=DEFAULT_DATA_DIR, train=True,
                        transform=tf, download=True)
    test_set = CIFAR10(root=DEFAULT_DATA_DIR, train=False,
                       transform=tf, download=True)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, img_size, 3, 10
def get_cifar_100(batch_size: int, resize_to: int = -1) -> Tuple[DataLoader, DataLoader, int, int, int]:
    """Build CIFAR100 train/test DataLoaders with normalization applied.

    Returns a tuple with:
        - A dataloader for the training data
        - A dataloader for the test/val data
        - Image width (images are square, so height == width in pixels)
        - Number of input channels (`in_channels`)
        - Number of classification labels (`num_classes`)
    """
    transform_steps = [
        ToTensor(),
        Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),
    ]
    img_size = 32  # native CIFAR resolution (C, H, W)
    if resize_to > 0:
        # NOTE(review): RandomResizedCrop is applied to the test set as well,
        # which makes evaluation stochastic — confirm this is intended.
        transform_steps.append(RandomResizedCrop(resize_to))
        img_size = resize_to
    tf = Compose(transform_steps)
    train_set = CIFAR100(root=DEFAULT_DATA_DIR, train=True,
                         transform=tf, download=True)
    test_set = CIFAR100(root=DEFAULT_DATA_DIR, train=False,
                        transform=tf, download=True)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, img_size, 3, 100