Module equilibrium-propagation.lib.train

Expand source code
import logging

import torch

from lib import config


def predict_batch(model, x_batch, dynamics, fast_init):
    """
    Return the model's softmax class probabilities for a single batch.

    The neural state is reset and the input layer is clamped to the
    flattened batch. The output target is cleared so that no nudging acts
    on the output layer. The output state is then produced either by the
    fast feedforward initialization or by running the relaxation dynamics.

    Args:
        model: EnergyBasedModel
        x_batch: Batch of input tensors; reshaped to
            (-1, model.dimensions[0]) before being clamped to layer 0
        dynamics: Dictionary containing the keyword arguments
            for the relaxation dynamics on u
        fast_init: Boolean to specify if fast feedforward initialization
            is used instead of the relaxation dynamics

    Returns:
        Softmax (over dim 1) of the detached output-layer state model.u[-1]
    """
    # Fresh state: clamped input and no target nudging on the output layer
    model.reset_state()
    flat_input = x_batch.view(-1, model.dimensions[0])
    model.clamp_layer(0, flat_input)
    model.set_C_target(None)

    # Let the network settle, either relaxing or with one feedforward sweep
    if not fast_init:
        model.u_relax(**dynamics)
    else:
        model.fast_init()

    output_state = model.u[-1].detach()
    return torch.nn.functional.softmax(output_state, dim=1)


def test(model, test_loader, dynamics, fast_init):
    """
    Evaluate prediction accuracy of an energy-based model on a given test set.

    Each batch is passed through ``predict_batch``, which runs without output
    nudging. The predicted class is the output unit with the highest softmax
    probability. Targets are compared through ``y_batch.argmax(dim=1)``, so
    ``y_batch`` is expected to be one-hot (or score) encoded.

    Args:
        model: EnergyBasedModel
        test_loader: Dataloader containing the test dataset
        dynamics: Dictionary containing the keyword arguments
            for the relaxation dynamics on u
        fast_init: Boolean to specify if fast feedforward initialization
            is used for the prediction

    Returns:
        Test accuracy (fraction of correctly classified samples)
        Mean energy of the model per sample: the sum of ``model.E`` over all
            batches divided by the total number of test samples (assumes
            ``model.E`` holds one energy value per sample — TODO confirm)

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no samples
    """
    test_E, correct, total = 0.0, 0.0, 0.0

    for x_batch, y_batch in test_loader:
        # Prepare the new batch
        x_batch, y_batch = x_batch.to(config.device), y_batch.to(config.device)

        # Extract prediction as the output unit with the strongest activity
        output = predict_batch(model, x_batch, dynamics, fast_init)
        prediction = torch.argmax(output, 1)

        with torch.no_grad():
            # Accumulate correct predictions, summed energy and number of seen samples
            correct += float(torch.sum(prediction == y_batch.argmax(dim=1)))
            test_E += float(torch.sum(model.E))
            total += x_batch.size(0)

    # Both statistics are normalised by the number of samples, not batches
    return correct / total, test_E / total


def train(model, train_loader, dynamics, w_optimizer, fast_init):
    """
    Use equilibrium propagation to train an energy-based model for one epoch.

    For every batch the input layer is clamped and a free phase is run, or
    replaced by a fast feedforward initialization. A nudged phase towards
    the target follows. The parameters are then updated through
    ``model.w_optimize`` from the free-phase and nudged-phase weight
    gradients. Progress is logged roughly every 10% of the epoch.

    Args:
        model: EnergyBasedModel
        train_loader: Dataloader containing the training dataset
        dynamics: Dictionary containing the keyword arguments
            for the relaxation dynamics on u
        w_optimizer: torch.optim.Optimizer object for the model parameters
        fast_init: Boolean to specify if fast feedforward initialization
            is used instead of the free phase
    """
    num_batches = len(train_loader)
    # Log about ten times per epoch; clamp to 1 so that loaders with fewer
    # than 10 batches do not cause a modulo-by-zero below
    log_interval = max(1, num_batches // 10)

    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
        x_batch, y_batch = x_batch.to(config.device), y_batch.to(config.device)

        # Reinitialize the neural state variables
        model.reset_state()

        # Clamp the input to the training sample
        model.clamp_layer(0, x_batch.view(-1, model.dimensions[0]))

        # Free phase
        if fast_init:
            # Skip the free phase using fast feed-forward initialization instead;
            # zero tensors stand in for the free-phase gradients.
            # NOTE(review): C_target is not cleared on this path, so it may still
            # hold the previous batch's target — confirm fast_init ignores it.
            model.fast_init()
            free_grads = [torch.zeros_like(p) for p in model.parameters()]
        else:
            # Run free phase until settled to fixed point and collect the free phase derivatives
            model.set_C_target(None)
            model.u_relax(**dynamics)
            free_grads = model.w_get_gradients()

        # Run nudged phase until settled to fixed point and collect the nudged phase derivatives
        model.set_C_target(y_batch)
        dE = model.u_relax(**dynamics)
        nudged_grads = model.w_get_gradients()

        # Optimize the parameters using the contrastive Hebbian style update
        model.w_optimize(free_grads, nudged_grads, w_optimizer)

        # Logging key statistics
        if batch_idx % log_interval == 0:

            # Extract prediction as the output unit with the strongest activity.
            # predict_batch resets and re-settles the state, so the logged E is
            # the energy of that prediction, while dE comes from the nudged phase.
            output = predict_batch(model, x_batch, dynamics, fast_init)
            prediction = torch.argmax(output, 1)

            # Log energy and batch accuracy
            batch_acc = float(torch.sum(prediction == y_batch.argmax(dim=1))) / x_batch.size(0)
            logging.info('{:.0f}%:\tE: {:.2f}\tdE {:.2f}\tbatch_acc {:.4f}'.format(
                100. * batch_idx / num_batches, torch.mean(model.E), dE, batch_acc))

Functions

def predict_batch(model, x_batch, dynamics, fast_init)

Compute the softmax prediction probabilities for a given data batch.

Args

model
EnergyBasedModel
x_batch
Batch of input tensors
dynamics
Dictionary containing the keyword arguments for the relaxation dynamics on u
fast_init
Boolean to specify if fast feedforward initialization is used for the prediction

Returns

Softmax classification probabilities for the given data batch
 
Expand source code
def predict_batch(model, x_batch, dynamics, fast_init):
    """
    Return the model's softmax class probabilities for a single batch.

    The neural state is reset and the input layer is clamped to the
    flattened batch. The output target is cleared so that no nudging acts
    on the output layer. The output state is then produced either by the
    fast feedforward initialization or by running the relaxation dynamics.

    Args:
        model: EnergyBasedModel
        x_batch: Batch of input tensors; reshaped to
            (-1, model.dimensions[0]) before being clamped to layer 0
        dynamics: Dictionary containing the keyword arguments
            for the relaxation dynamics on u
        fast_init: Boolean to specify if fast feedforward initialization
            is used instead of the relaxation dynamics

    Returns:
        Softmax (over dim 1) of the detached output-layer state model.u[-1]
    """
    # Fresh state: clamped input and no target nudging on the output layer
    model.reset_state()
    flat_input = x_batch.view(-1, model.dimensions[0])
    model.clamp_layer(0, flat_input)
    model.set_C_target(None)

    # Let the network settle, either relaxing or with one feedforward sweep
    if not fast_init:
        model.u_relax(**dynamics)
    else:
        model.fast_init()

    output_state = model.u[-1].detach()
    return torch.nn.functional.softmax(output_state, dim=1)
def test(model, test_loader, dynamics, fast_init)

Evaluate prediction accuracy of an energy-based model on a given test set.

Args

model
EnergyBasedModel
test_loader
Dataloader containing the test dataset
dynamics
Dictionary containing the keyword arguments for the relaxation dynamics on u
fast_init
Boolean to specify if fast feedforward initialization is used for the prediction

Returns

Test accuracy
 
Mean energy of the model per sample (summed energy divided by the number of test samples)
 
Expand source code
def test(model, test_loader, dynamics, fast_init):
    """
    Evaluate prediction accuracy of an energy-based model on a given test set.

    Each batch is passed through ``predict_batch``, which runs without output
    nudging. The predicted class is the output unit with the highest softmax
    probability. Targets are compared through ``y_batch.argmax(dim=1)``, so
    ``y_batch`` is expected to be one-hot (or score) encoded.

    Args:
        model: EnergyBasedModel
        test_loader: Dataloader containing the test dataset
        dynamics: Dictionary containing the keyword arguments
            for the relaxation dynamics on u
        fast_init: Boolean to specify if fast feedforward initialization
            is used for the prediction

    Returns:
        Test accuracy (fraction of correctly classified samples)
        Mean energy of the model per sample: the sum of ``model.E`` over all
            batches divided by the total number of test samples (assumes
            ``model.E`` holds one energy value per sample — TODO confirm)

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no samples
    """
    test_E, correct, total = 0.0, 0.0, 0.0

    for x_batch, y_batch in test_loader:
        # Prepare the new batch
        x_batch, y_batch = x_batch.to(config.device), y_batch.to(config.device)

        # Extract prediction as the output unit with the strongest activity
        output = predict_batch(model, x_batch, dynamics, fast_init)
        prediction = torch.argmax(output, 1)

        with torch.no_grad():
            # Accumulate correct predictions, summed energy and number of seen samples
            correct += float(torch.sum(prediction == y_batch.argmax(dim=1)))
            test_E += float(torch.sum(model.E))
            total += x_batch.size(0)

    # Both statistics are normalised by the number of samples, not batches
    return correct / total, test_E / total
def train(model, train_loader, dynamics, w_optimizer, fast_init)

Use equilibrium propagation to train an energy-based model.

Args

model
EnergyBasedModel
train_loader
Dataloader containing the training dataset
dynamics
Dictionary containing the keyword arguments for the relaxation dynamics on u
w_optimizer
torch.optim.Optimizer object for the model parameters
fast_init
Boolean to specify if fast feedforward initialization is used for the prediction
Expand source code
def train(model, train_loader, dynamics, w_optimizer, fast_init):
    """
    Use equilibrium propagation to train an energy-based model for one epoch.

    For every batch the input layer is clamped and a free phase is run, or
    replaced by a fast feedforward initialization. A nudged phase towards
    the target follows. The parameters are then updated through
    ``model.w_optimize`` from the free-phase and nudged-phase weight
    gradients. Progress is logged roughly every 10% of the epoch.

    Args:
        model: EnergyBasedModel
        train_loader: Dataloader containing the training dataset
        dynamics: Dictionary containing the keyword arguments
            for the relaxation dynamics on u
        w_optimizer: torch.optim.Optimizer object for the model parameters
        fast_init: Boolean to specify if fast feedforward initialization
            is used instead of the free phase
    """
    num_batches = len(train_loader)
    # Log about ten times per epoch; clamp to 1 so that loaders with fewer
    # than 10 batches do not cause a modulo-by-zero below
    log_interval = max(1, num_batches // 10)

    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
        x_batch, y_batch = x_batch.to(config.device), y_batch.to(config.device)

        # Reinitialize the neural state variables
        model.reset_state()

        # Clamp the input to the training sample
        model.clamp_layer(0, x_batch.view(-1, model.dimensions[0]))

        # Free phase
        if fast_init:
            # Skip the free phase using fast feed-forward initialization instead;
            # zero tensors stand in for the free-phase gradients.
            # NOTE(review): C_target is not cleared on this path, so it may still
            # hold the previous batch's target — confirm fast_init ignores it.
            model.fast_init()
            free_grads = [torch.zeros_like(p) for p in model.parameters()]
        else:
            # Run free phase until settled to fixed point and collect the free phase derivatives
            model.set_C_target(None)
            model.u_relax(**dynamics)
            free_grads = model.w_get_gradients()

        # Run nudged phase until settled to fixed point and collect the nudged phase derivatives
        model.set_C_target(y_batch)
        dE = model.u_relax(**dynamics)
        nudged_grads = model.w_get_gradients()

        # Optimize the parameters using the contrastive Hebbian style update
        model.w_optimize(free_grads, nudged_grads, w_optimizer)

        # Logging key statistics
        if batch_idx % log_interval == 0:

            # Extract prediction as the output unit with the strongest activity.
            # predict_batch resets and re-settles the state, so the logged E is
            # the energy of that prediction, while dE comes from the nudged phase.
            output = predict_batch(model, x_batch, dynamics, fast_init)
            prediction = torch.argmax(output, 1)

            # Log energy and batch accuracy
            batch_acc = float(torch.sum(prediction == y_batch.argmax(dim=1))) / x_batch.size(0)
            logging.info('{:.0f}%:\tE: {:.2f}\tdE {:.2f}\tbatch_acc {:.4f}'.format(
                100. * batch_idx / num_batches, torch.mean(model.E), dE, batch_acc))