Skip to content

Instantly share code, notes, and snippets.

@basavyr
Created December 18, 2025 09:40
Show Gist options
  • Select an option

  • Save basavyr/d3d34ddf3fb486644b729aabb788e9f1 to your computer and use it in GitHub Desktop.

Select an option

Save basavyr/d3d34ddf3fb486644b729aabb788e9f1 to your computer and use it in GitHub Desktop.
Applying tensor denormalization on standard datasets (e.g., CIFAR10, MNIST) for `imshow()` representation

This code provides helpers for retrieving standard datasets from Torchvision, such as MNIST, CIFAR10, etc. It also provides the denormalization step, which is required before graphically rendering the normalized image tensors (e.g., with `imshow()`).

from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, ToTensor, Grayscale, Normalize, RandomResizedCrop
from torchvision.datasets import MNIST, CIFAR10, CIFAR100
import torch

from typing import Tuple
import os

# Root directory where torchvision will download/cache the datasets.
DEFAULT_DATA_DIR = os.getenv("DEFAULT_DATA_DIR", None)

# Fail fast at import time if the data directory is not configured.
# NOTE: an explicit raise is used instead of `assert`, because asserts
# are stripped when Python runs with optimizations enabled (-O).
if DEFAULT_DATA_DIR is None:
    raise RuntimeError("Environment variable not set for < DEFAULT_DATA_DIR >")


def denorm(x_norm: torch.Tensor, dataset_type: str):
    if dataset_type == "mnist":
        mean = (0.1307,)
        std = (0.3081,)
    elif dataset_type == "cifar10":
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif dataset_type == "cifar100":
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    else:
        raise ValueError("Invalid denorm attributes")
    mean_tensor = torch.tensor(mean).view(-1, 1, 1)
    std_tensor = torch.tensor(std).view(-1, 1, 1)

    return (x_norm * std_tensor + mean_tensor).clamp(0.0, 1.0)


def get_dataloader(dataset_type: str, batch_size: int, resize_to: int = -1, force_3_channels: bool = False) -> Tuple[DataLoader, DataLoader, int, int, int]:
    """Dispatch to the dataset-specific factory by name.

    Args:
        dataset_type: One of "mnist", "cifar10", "cifar100".
        batch_size: Batch size for both returned dataloaders.
        resize_to: Target square size; values <= 0 keep the native size.
        force_3_channels: Only honored for MNIST (expand to 3 channels).

    Returns:
        (train_loader, test_loader, img_size, in_channels, num_classes).

    Raises:
        ValueError: If `dataset_type` is not a supported dataset name.
    """
    if dataset_type == "mnist":
        return get_mnist(
            batch_size, resize_to=resize_to, force_3_channels=force_3_channels)
    if dataset_type == "cifar10":
        return get_cifar_10(batch_size, resize_to=resize_to)
    if dataset_type == "cifar100":
        return get_cifar_100(batch_size, resize_to=resize_to)
    raise ValueError("Incorrect dataset type")


def get_mnist(batch_size: int, resize_to: int = -1, force_3_channels: bool = False) -> Tuple[DataLoader, DataLoader, int, int, int]:
    """Build MNIST train/test dataloaders with normalization applied.

    Returns a tuple with:
    - A dataloader for the training data
    - A dataloader for the test/val data
    - Image width (images are square, so this is both height and width)
    - Number of input channels (`in_channels`)
    - Number of classification labels (`num_classes`)
    """
    # Assemble the transform pipeline as a plain list, then wrap once.
    steps = []
    if force_3_channels:
        # Replicate the single grayscale channel into 3 channels.
        steps.append(Grayscale(num_output_channels=3))
    steps.append(ToTensor())
    steps.append(Normalize(mean=(0.1307,), std=(0.3081,)))
    img_size = 28  # native MNIST resolution
    if resize_to > 0:
        # NOTE: applied after normalization, operating on the tensor.
        steps.append(RandomResizedCrop(resize_to))
        img_size = resize_to
    tf = Compose(steps)
    num_channels = 3 if force_3_channels else 1

    train_set = MNIST(root=DEFAULT_DATA_DIR, train=True,
                      transform=tf, download=True)
    test_set = MNIST(root=DEFAULT_DATA_DIR, train=False,
                     transform=tf, download=True)

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size, shuffle=False),
        img_size,
        num_channels,
        10,
    )


def get_cifar_10(batch_size: int, resize_to: int = -1) -> Tuple[DataLoader, DataLoader, int, int, int]:
    """Build CIFAR10 train/test dataloaders with normalization applied.

    Returns a tuple with:
    - A dataloader for the training data
    - A dataloader for the test/val data
    - Image width (images are square, so this is both height and width)
    - Number of input channels (`in_channels`)
    - Number of classification labels (`num_classes`)
    """
    # Assemble the transform pipeline as a plain list, then wrap once.
    steps = [
        ToTensor(),
        Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)),
    ]
    img_size = 32  # native CIFAR10 resolution
    if resize_to > 0:
        # NOTE: applied after normalization, operating on the tensor.
        steps.append(RandomResizedCrop(resize_to))
        img_size = resize_to
    tf = Compose(steps)

    train_set = CIFAR10(root=DEFAULT_DATA_DIR, train=True,
                        transform=tf, download=True)
    test_set = CIFAR10(root=DEFAULT_DATA_DIR, train=False,
                       transform=tf, download=True)

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size, shuffle=False),
        img_size,
        3,
        10,
    )


def get_cifar_100(batch_size: int, resize_to: int = -1) -> Tuple[DataLoader, DataLoader, int, int, int]:
    """Build CIFAR100 train/test dataloaders with normalization applied.

    Returns a tuple with:
    - A dataloader for the training data
    - A dataloader for the test/val data
    - Image width (images are square, so this is both height and width)
    - Number of input channels (`in_channels`)
    - Number of classification labels (`num_classes`)
    """
    # Assemble the transform pipeline as a plain list, then wrap once.
    steps = [
        ToTensor(),
        Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),
    ]
    img_size = 32  # native CIFAR100 resolution
    if resize_to > 0:
        # NOTE: applied after normalization, operating on the tensor.
        steps.append(RandomResizedCrop(resize_to))
        img_size = resize_to
    tf = Compose(steps)

    train_set = CIFAR100(root=DEFAULT_DATA_DIR, train=True,
                         transform=tf, download=True)
    test_set = CIFAR100(root=DEFAULT_DATA_DIR, train=False,
                        transform=tf, download=True)

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size, shuffle=False),
        img_size,
        3,
        100,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment