Module equilibrium-propagation.lib.data

Expand source code
import torch
from torchvision import datasets, transforms


def _one_hot_ten(label):
    """
    Helper function to convert to a one hot encoding with 10 classes.

    Args:
        label: target label as single number (a Python int or an int64
            tensor); values must lie in the range 0..9

    Returns:
        One-hot tensor with dimension (*, 10) encoding label
    """
    # as_tensor converts plain numbers exactly like torch.tensor, but reuses an
    # existing tensor instead of copy-constructing it (which torch.tensor warns about)
    index = torch.as_tensor(label)
    return torch.nn.functional.one_hot(index, num_classes=10)


def create_mnist_loaders(batch_size):
    """
    Build the dataloaders for the MNIST training and test splits.

    Both splits are fetched into '../data/' (downloading if needed), converted
    to tensors, normalized with mean 0.1307 and std 0.3081, and their labels
    are one-hot encoded with _one_hot_ten. Both loaders drop the last
    incomplete batch; only the training loader shuffles.

    Args:
        batch_size: Number of samples per batch

    Returns:
        train_loader: torch.utils.data.DataLoader for the MNIST training set
        test_loader: torch.utils.data.DataLoader for the MNIST test set
    """

    def _load_split(is_train):
        # Same preprocessing for both splits; only the `train` flag differs
        preprocessing = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        return datasets.MNIST(
            '../data/',
            train=is_train,
            download=True,
            transform=preprocessing,
            target_transform=_one_hot_ten,
        )

    train_set = _load_split(True)
    test_set = _load_split(False)

    # With CUDA available, load through one worker process into pinned
    # (page-locked) memory; otherwise keep the DataLoader defaults
    if torch.cuda.is_available():
        loader_options = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_options = {}

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        **loader_options,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        **loader_options,
    )

    return train_loader, test_loader

Functions

def create_mnist_loaders(batch_size)

Create dataloaders for the training and test set of MNIST.

Args

batch_size
Number of samples per batch

Returns

train_loader
torch.utils.data.DataLoader for the MNIST training set
test_loader
torch.utils.data.DataLoader for the MNIST test set
Expand source code
def create_mnist_loaders(batch_size):
    """
    Build the dataloaders for the MNIST training and test splits.

    Both splits are fetched into '../data/' (downloading if needed), converted
    to tensors, normalized with mean 0.1307 and std 0.3081, and their labels
    are one-hot encoded with _one_hot_ten. Both loaders drop the last
    incomplete batch; only the training loader shuffles.

    Args:
        batch_size: Number of samples per batch

    Returns:
        train_loader: torch.utils.data.DataLoader for the MNIST training set
        test_loader: torch.utils.data.DataLoader for the MNIST test set
    """

    def _load_split(is_train):
        # Same preprocessing for both splits; only the `train` flag differs
        preprocessing = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        return datasets.MNIST(
            '../data/',
            train=is_train,
            download=True,
            transform=preprocessing,
            target_transform=_one_hot_ten,
        )

    train_set = _load_split(True)
    test_set = _load_split(False)

    # With CUDA available, load through one worker process into pinned
    # (page-locked) memory; otherwise keep the DataLoader defaults
    if torch.cuda.is_available():
        loader_options = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_options = {}

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        **loader_options,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        **loader_options,
    )

    return train_loader, test_loader