
How Does PyTorch Backprop Through Argmax?

Last Updated : 13 Dec, 2023

Backpropagation is a fundamental algorithm for training neural networks, allowing them to learn from data. It iteratively updates the weights of a network to minimize the difference between predicted and actual outputs, relying on the chain rule to calculate gradients that determine how much each parameter should be adjusted. This is what allows the network to learn and improve its predictions over time.

However, certain operations, such as argmax, present challenges during backpropagation due to their non-differentiable nature. In this article, we delve into how PyTorch handles backpropagation through the argmax operation and explore techniques like the Straight-Through Estimator (STE) that make this possible.

Argmax Operation

In mathematical terms, the argmax function returns the argument (input) at which a given function attains its maximum value; for a tensor, torch.argmax returns the index of the largest element. The argmax function is not continuous at points where the location of the maximum changes: it jumps from one value to another as you move across such points. Because of these discontinuities, the argmax function is not differentiable.

Consider the example below, in which we find the index of the maximum value of the tensor along each row.

Python3

import torch
 
torch.argmax(torch.tensor([[32,11,12,14],[1,123,12,212]]), dim = 1)

                    

Output:

tensor([0, 3])

where

  • the first value, 0, is the index of the largest element of [32,11,12,14], i.e. 32.
  • the second value, 3, is the index of the largest element of [1,123,12,212], i.e. 212.

Non-Differentiable

The argmax function is discontinuous as it changes value suddenly. Also, the discontinuity point will vary depending on the values present in the tensor. Thus, the argmax function is non-differentiable. It does not have a well-defined derivative.

To train a neural network, every operation in the forward pass must be differentiable so that the loss can be backpropagated for parameter updates. If the forward pass contains an argmax, the loss cannot be backpropagated through the network to update the weights.
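A minimal sketch of the problem (the tensor values below are arbitrary): the output of torch.argmax is an integer tensor with no grad_fn, so autograd has nothing to backpropagate through.

Python3

import torch

x = torch.tensor([[0.2, 1.5, -0.3]], requires_grad=True)

probs = torch.softmax(x, dim=1)   # differentiable: carries a grad_fn
idx = torch.argmax(x, dim=1)      # non-differentiable: integer output, no grad_fn

print(probs.grad_fn)   # e.g. <SoftmaxBackward0 object at ...>
print(idx.grad_fn)     # None -> gradients cannot flow through argmax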

Application

In the context of neural networks, argmax is primarily a tool for interpreting or extracting the most likely result from the probabilistic outputs generated by a model. Some use cases where argmax appears:

  • Classification tasks: After training a neural network, argmax is often used in the inference phase to make predictions. For example, in a classification task, the argmax function is used to determine which class the network’s output corresponds to.
  • NLP: In sequence-to-sequence models for text generation, argmax is often used to choose the most likely next word or token at each step during decoding.
  • Conditional Sampling: In many generative neural networks such as GANs and VAEs, argmax is used when sampling discrete outcomes to produce different results.

Since argmax plays a central role in many neural network architectures, it becomes important to find ways to make it differentiable so that the loss can be backpropagated through the network and the model can update its weights.

Let us first see, using code, why argmax breaks training, and then discuss ways to counter it.

Error: Training a Neural Network with Argmax

Installations

!pip install torchinfo

Let us try to train a neural network with argmax as output and see what happens.

Import the necessary libraries

Python3

# Importing required libraries
import torchinfo
import torch
from torch import nn,optim
import torch.nn.functional as F

                    
  • Next, we create a model for image classification. We use two convolutional layers, each followed by a ReLU activation and max pooling, and then three fully connected layers. The final linear layer outputs a vector of dimension 10, which we pass through an argmax.

Python3

# Creating our own LeNet5
 
 
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # in_channels, out_channels, kernel_size
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d((2, 2))
 
        self.conv2 = nn.Conv2d(6, 16, 5)  # in_channels, out_channels, kernel_size
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d((2, 2))
 
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        # x = F.max_pool2d(F.relu(self.conv1(x)),(2, 2))
        # x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
 
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
 
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
 
        x = torch.argmax(x, dim=1)
 
        return x
 
 
model = LeNet5()

                    
  • Let us load the CIFAR-10 dataset. It consists of 60,000 32×32 color images in 10 classes, with 6,000 images per class: 50,000 training images and 10,000 test images. We use transforms to convert the images to tensors and normalize them.

Python3

# Loading the dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10
train_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))])
train_data = CIFAR10(root="./train/", train=True, download=True, transform=train_transforms)
trainloader = torch.utils.data.DataLoader( train_data,batch_size=16, shuffle=True)

                    
  • Define our loss function and optimizer

Python3

# Our loss function
def my_loss(output, target):
    output = torch.tensor(output, dtype=torch.float)
    loss = ((output - target)**2).mean()
    return loss
# Our optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001, momentum=0.9)

                    
  • Train our model

Python3

# Training the model
 
N_EPOCHS = 2
for epoch in range(N_EPOCHS):
  epoch_loss = 0.0
  for inputs, labels in trainloader:
 
    optimizer.zero_grad()
 
    outputs = model(inputs)
 
    loss = my_loss(outputs, labels.float())
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()
  print("Epoch: {} Loss: {}".format(epoch,epoch_loss/len(trainloader)))

                    

Output:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Training the model raises a RuntimeError. This happens because we used argmax inside the model: its output has no gradient function (grad_fn), so autograd cannot propagate the loss back to the parameters.

Strategies for Handling Argmax in PyTorch

Now, let’s discuss the strategies we can use to address the challenges posed by the non-differentiable nature of argmax in PyTorch.

1. Straight-Through Estimator (STE)

This technique is used to handle non-differentiable operations during backpropagation in neural networks. It uses a differentiable approximation during the backward pass to enable gradient flow through the argmax operation.

How it works :

In this method,

  • During the forward pass, we apply the argmax operation to determine the discrete category.
  • During the backward pass, we let the gradient pass through unchanged, i.e. argmax acts as an identity function: it takes the gradient from the layer above and passes it to the layer below without any modification, as if argmax were a continuous function. A minimal sketch is shown below.
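A minimal sketch of the straight-through trick (the helper name argmax_ste below is illustrative, not a PyTorch API): the forward pass produces a hard one-hot vector from the argmax, while the backward pass uses the gradient of the underlying softmax, because the detached difference contributes nothing to the gradient.

Python3

import torch
import torch.nn.functional as F


def argmax_ste(logits, dim=-1):
    # Forward: hard one-hot of the argmax.
    # Backward: gradient of the soft softmax (the detached term has no gradient).
    soft = F.softmax(logits, dim=dim)
    index = soft.argmax(dim=dim, keepdim=True)
    hard = torch.zeros_like(soft).scatter_(dim, index, 1.0)
    return (hard - soft).detach() + soft


logits = torch.randn(2, 4, requires_grad=True)
out = argmax_ste(logits)
out.sum().backward()            # gradients reach `logits` despite the hard output
print(out)                      # one-hot rows
print(logits.grad is not None)  # True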

2. Gumbel Softmax

Gumbel-Softmax is another technique for handling non-differentiable operations, particularly in the context of discrete variables such as the argmax operation. It introduces stochasticity into the argmax operation using the Gumbel-Softmax distribution.

How it works :

This method lets us sample from a categorical distribution in a differentiable way. Let's understand it with the help of a simple example:

  • Assume a discrete variable with two outcomes X1 and X2.
  • Let's say we have a model with a hidden layer that produces a score for each outcome of the discrete variable.
  • Let these scores be S(X1) and S(X2). We take the highest score (argmax). The probability distribution is given by
    • P(Xi) = softmax(S(X1), S(X2))
    • P(X_1) = \frac{e^{S(X_1)}}{e^{S(X_1)} + e^{S(X_2)}}
    • P(X_2) = \frac{e^{S(X_2)}}{e^{S(X_1)} + e^{S(X_2)}}
    • S(X1) and S(X2) are the logits obtained from the model
    • P(X1) and P(X2) form the probability distribution over the categorical outcomes
  • Now, the problem with the above approach is that the selection is deterministic: we always get the outcome with the maximum score.
  • What if we want to actually sample, so that we get [1,0] with probability P(X1) and [0,1] with probability P(X2)? This is where the Gumbel-Max trick comes in.
  • The Gumbel-Max equation is:
    • Y = \arg\max_i \left( S(X_i) + G_i \right)
    • where the G_i are i.i.d. samples from the Gumbel(0, 1) distribution.
    • NOTE: Mathematically, it can be shown that Y follows the same distribution as P(Xi) = softmax(S(X1), S(X2)), i.e. the equation yields [1,0] with probability P(X1) and [0,1] with probability P(X2).
  • Notice that there is still an argmax in the Gumbel-Max trick, which keeps it non-differentiable. Therefore, we use a softmax function to approximate the argmax:
  • y_i = \frac{e^{(S(X_i) + G_i)/\tau}}{\sum_j e^{(S(X_j) + G_j)/\tau}}
  • Here τ is a temperature hyperparameter that controls the output variability.
    • As τ → 0, the softmax becomes an argmax and the Gumbel-Softmax distribution approaches the categorical distribution.
    • During training, in order to let gradients flow past the sample, we start with a large value of τ and then gradually anneal the temperature (but not all the way to 0, as the gradients would blow up).
  • It’s important to observe that the output of the Gumbel Softmax function produces a vector that sums to 1, somewhat resembling a one-hot vector (although it technically isn’t one). This approximation does not completely replace the argmax operation.
  • To genuinely obtain a pure one-hot vector, the Straight-Through (ST) Gumbel Trick is applied.
    • During the forward pass, the code utilizes an argmax operation to obtain an actual one-hot vector.
    • However, during backpropagation, the softmax approximation of the argmax is used to maintain differentiability.

Thus, this method introduces stochasticity (via the Gumbel distribution) into the discrete decision-making process while using a differentiable approximation (softmax) of the argmax operation.

Gumbel-Softmax is often used where you want a stochastic decision-making process over discrete variables, for example in variational autoencoders. It enables backpropagation through random samples of discrete variables. The Gumbel-Max trick is very similar to the reparameterization trick, in that we combine a deterministic part (the model score) with a stochastic part (the Gumbel noise).
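PyTorch ships this technique as torch.nn.functional.gumbel_softmax. A small usage sketch (the logits below are random, just for illustration): with hard=False we get a soft, differentiable sample; with hard=True we get a one-hot forward pass with straight-through gradients.

Python3

import torch
import torch.nn.functional as F

logits = torch.randn(3, 5, requires_grad=True)

# Soft, differentiable sample: each row sums to 1 but is not one-hot
soft_sample = F.gumbel_softmax(logits, tau=1.0, hard=False)

# Straight-through sample: one-hot in the forward pass,
# softmax gradients in the backward pass
hard_sample = F.gumbel_softmax(logits, tau=1.0, hard=True)

hard_sample.sum().backward()
print(soft_sample.sum(dim=-1))   # each entry is 1.0
print(hard_sample)               # one-hot rows
print(logits.grad.shape)         # torch.Size([3, 5])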

3. Customized Operations

We can create custom PyTorch operations that mimic the behavior of argmax but are differentiable.

How it works :

Here we design a custom operation that approximates argmax while remaining differentiable. A common example is the softmax function.

The softmax function converts each input value to a value between 0 and 1, such that the outputs sum to 1.

Given an input vector z = [z1, z2, …, zn], where n is the number of classes, the softmax function computes the probability pi for class i as follows:

p_i = \frac{e^{z_i}}{e^{z_1} + e^{z_2} + \dots + e^{z_n}}

In this equation:

  • e^{z_i} represents the exponential of the i-th element of the input vector.
  • The denominator is the sum of the exponentials of all elements in the input vector, which ensures that the probabilities sum to 1.

Once we convert the output to probabilities, we use the negative log-likelihood to calculate the loss:

L(y, p) = -\sum_i y_i \log(p_i)

In this equation:

  • L(y, p) is the negative log-likelihood loss for the example.
  • y is a vector of true labels, where yi is 1 for the true class and 0 for other classes (one-hot encoded).
  • p is a vector of predicted probabilities, where pi is the predicted probability for class i.

The loss above is largest when the predicted probability of the true class is lowest and smallest when that probability is highest.

This log-based loss function is continuous and differentiable, so we can backpropagate the loss through the neural network.
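A minimal sketch of this approach (the logits and targets below are arbitrary): instead of passing the final linear layer's output through argmax, we keep the raw logits and apply log-softmax plus the negative log-likelihood loss (equivalent to nn.CrossEntropyLoss); argmax is used only at inference time, outside the backward graph.

Python3

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)   # e.g. the output of a final linear layer
targets = torch.tensor([3, 7, 1, 0])              # true class indices

# Differentiable: log-softmax + negative log-likelihood
# (equivalently: torch.nn.CrossEntropyLoss()(logits, targets))
loss = F.nll_loss(F.log_softmax(logits, dim=1), targets)

loss.backward()                           # gradients flow back to the logits
print(loss.item())

# argmax is only applied at inference time, outside the backward graph
predictions = torch.argmax(logits, dim=1)
print(predictions)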

Implementation using Gumbel Softmax

A variational autoencoder is an architecture composed of both an encoder and a decoder that is trained to minimize the reconstruction error between the encoded-decoded data and the initial data. In a VAE, the Gumbel-Softmax is commonly used to sample from categorical distributions that represent discrete latent variables.

  1. The input is encoded as a distribution over the latent space. In a classic VAE the encoded distributions are chosen to be Gaussian, so the encoder is trained to return the mean and covariance that describe them; here the latent variables are discrete, so the encoder outputs logits over the categories.
  2. A point from the latent space is sampled from that distribution. Here we use Gumbel Softmax for sampling
    1. Gumbel Sampling: Gumbel noise is added to the output of the encoder. The Gumbel noise is generated from a Gumbel distribution and is typically written as g = −log(−log(u)), where u is sampled from a uniform distribution.
    2. Gumbel-Softmax: The generated Gumbel noise is added to the output of the encoder, and then the softmax function is applied to obtain a differentiable sample from the categorical distribution. The softmax has a temperature parameter (τ) that controls the smoothness of the approximation. Initially the temperature is kept at 1 so that diverse latent vectors can be generated, and it is then slowly annealed toward zero so that the Gumbel-Softmax distribution approaches a true categorical distribution.
  3. The sampled point is decoded and the reconstruction error can be computed. The loss function consists of two terms:
    1. Generative loss: Compares the model output with the model input.
    2. Latent loss: A KL-divergence term that forces the distribution returned by the encoder to stay close to a chosen prior (a standard normal distribution in the classic VAE; a uniform categorical distribution in the Gumbel-Softmax version used here).

Now, let us train the model using Gumbel-Softmax. The notebook is available at the Notebook link.

1. Importing Necessary Libraries

Python3

import torch.utils.data
 
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
 
import torch.nn.functional as F
import numpy as np
import pandas as pd
import math

                    

2. Loading the data

We will use MNIST handwritten digit data to train our VAE.

MNIST Dataset Loading:

  • mnist_dataset_train: Loads the MNIST training dataset. Here, transforms.ToTensor() converts the images to PyTorch tensors.
  • mnist_dataset_test: Similar to mnist_dataset_train, but for the testing dataset.

Batch Size and DataLoader:

  • batch_size = 128: Defines the batch size for training and testing.
  • train_loader and test_loader: These are instances of torch.utils.data.DataLoader, which is used to efficiently load and batch data. It takes the MNIST datasets and uses the specified batch size for training and testing. The shuffle=True argument ensures that the data is randomly shuffled during training, which can improve the learning process.

Python3

from torch.utils.data import DataLoader
device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
# Load the MNIST dataset
mnist_dataset_train = datasets.MNIST(
    root='./data', train=True, download=True, transform=transforms.ToTensor())
# Load the MNIST test dataset
mnist_dataset_test = datasets.MNIST(
    root='./data', train=False, download=True, transform=transforms.ToTensor())
 
 
batch_size = 128
train_loader = torch.utils.data.DataLoader(
    mnist_dataset_train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    mnist_dataset_test, batch_size=batch_size, shuffle=True)

                    

3. Define Gumbel Softmax

Here is a breakdown of each function:

sample_gumbel

  • This function generates samples from a Gumbel distribution.
  • The function returns -torch.log(-torch.log(U + eps) + eps), which is a sample from the Gumbel distribution.

gumbel_softmax_sample

  • This function performs a Gumbel-Softmax sampling.
  • logits: The logits representing the output of the encoder block.
  • temperature: A temperature parameter controlling the level of smoothing in the sampling process.
  • Adds Gumbel noise to the logits and returns the softmax of the result at the specified temperature.

gumbel_softmax

  • This function represents the Gumbel-Softmax distribution and allows for sampling from it.
  • temperature: A temperature parameter controlling the level of smoothing in the sampling process
  • If hard is False, returns the soft sample as a flattened tensor.
  • If hard is True, returns a one-hot vector by finding the index of the maximum value in each row of the soft sample.

Python3

def sample_gumbel(shape, eps=1e-20):
    U = torch.rand(shape).to(device)
 
    return -torch.log(-torch.log(U + eps) + eps)
 
 
def gumbel_softmax_sample(logits, temperature):
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)
 
 
def gumbel_softmax(logits, temperature, hard=False):
    """
    ST Gumbel-Softmax
    input: [*, n_class]
    return: flattened [*, n_class] one-hot vector
    """
    y = gumbel_softmax_sample(logits, temperature)
 
    if not hard:
        return y.view(-1, latent_dim * categorical_dim)
 
    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Set gradients w.r.t. y_hard gradients w.r.t. y
    y_hard = (y_hard - y).detach() + y
    return y_hard.view(-1, latent_dim * categorical_dim)

                    

Check the Gumbel Softmax performance

Python3

temp=1
latent_dim = 2
categorical_dim=5
input = torch.rand(1,latent_dim*categorical_dim)
print('Input:', input)
# With hard = False
print('\nGumbel Softmax with hard=False\n',gumbel_softmax(input, temp,False))
#With hard = True
print('\nGumbel Softmax with hard=True\n',gumbel_softmax(input, temp,  True))

                    

Output:

Input: tensor([[0.2976, 0.1200, 0.3637, 0.3025, 0.3605, 0.7416, 0.5763, 0.8461, 0.9991,
0.3869]])

Gumbel Softmax with hard=False
tensor([[0.0973, 0.0399, 0.0676, 0.0300, 0.0443, 0.0617, 0.2640, 0.1006, 0.2470,
0.0475]])

Gumbel Softmax with hard=True
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])

4. Define the VAE

The encode method transforms input data into a latent space, the decode method reconstructs data from the latent space, and the forward method combines these steps, also returning the Gumbel-Softmax samples and softmax probabilities.

  1. Encoder
    • Performs 3 hidden layer transformations with ReLU activation. It reduces the size from 784 to latent_dim*categorical_dim
  2. Sampling
    • z = gumbel_softmax(q_y, temp, hard): Samples from the Gumbel-Softmax distribution. This is the latent variable. This latent variable is then passed to decoder
  3. Decoder
    • Reconstructs the image using 3 hidden layer transformations which is the reverse of encoder architecture

Python3

class VAE_gumbel(nn.Module):
    def __init__(self, temp):
        super(VAE_gumbel, self).__init__()
 
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, latent_dim * categorical_dim)
 
        self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)
        self.fc5 = nn.Linear(256, 512)
        self.fc6 = nn.Linear(512, 784)
 
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
 
    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.relu(self.fc3(h2))
 
    def decode(self, z):
        h4 = self.relu(self.fc4(z))
        h5 = self.relu(self.fc5(h4))
        return self.sigmoid(self.fc6(h5))
 
    def forward(self, x, temp=1.0, hard=False):
        q = self.encode(x.view(-1, 784))
        q_y = q.view(q.size(0), latent_dim, categorical_dim)
        z = gumbel_softmax(q_y, temp, hard)
        return self.decode(z), F.softmax(q_y, dim=-1).reshape(*q.size())

                    


5. Define the loss function

This is the standard VAE loss. The loss function consists of two components: the BCE loss, which measures the difference between the reconstructed data and the original data, and the KL divergence loss, which penalizes the divergence between the distribution of the latent variables and a chosen prior distribution. The final loss is the sum of these two components.

Python3

latent_dim = 20
categorical_dim = 10  # one-of-K vector
 
temp = 1
temp_min = 0.5
ANNEAL_RATE = 0.00003
 
model = VAE_gumbel(temp).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
 
 
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, qy):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum') / x.shape[0]
 
    log_ratio = torch.log(qy * categorical_dim + 1e-20)
    KLD = torch.sum(qy * log_ratio, dim=-1).mean()
 
    return BCE + KLD

                    

6. Define the train function

This training function iterates over batches of training data, performs forward and backward passes, updates the model parameters, and prints training progress. It also includes temperature annealing to control the exploration of the Gumbel-Softmax distribution during training.

  • model.train(): Sets the model to training mode.
  • Moves the input data to the specified device and zeroes out the gradients accumulated in the optimizer.
  • Forward pass through the model to obtain the reconstructed batch and Gumbel-Softmax samples
  • Computes the loss using the reconstruction loss and KL divergence loss.
  • Backward pass to compute the gradients.
  • Performs temperature annealing every 100 batches.
  • Prints training progress every 100 batches.

Python3

def train(epoch, model, train_loader, optimizer, temp, cuda=True, hard=False):
  model.train()
  train_loss = 0
  for batch_idx, (data, _) in enumerate(train_loader):
    data = data.to(device)
    optimizer.zero_grad()
 
    recon_batch, q_y = model(data, temp, hard)
 
    loss = loss_function(recon_batch, data, q_y)
    loss.backward()
 
    train_loss += loss.item() * len(data)
    optimizer.step()
 
    if batch_idx % 100 == 1:
        temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_idx), temp_min)
 
    if batch_idx % 100 == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),
                  100. * batch_idx / len(train_loader),
                  loss.item()))
 
  print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

                    

7. Define the test function

This code evaluates a given model on a test dataset, computing and printing the average reconstruction loss. Additionally, it saves a visual comparison of original and reconstructed images for the first batch.

  • model.eval(): Sets the model to evaluation mode. This is important because it switches off training-specific behaviour such as dropout, which is used during training but not during evaluation.
  • test_loss = 0: Initializes a variable to accumulate the test loss.
  • The function then iterates over the batches in the test_loader. For each batch:
    • Passes the input data through the model to obtain the reconstructed batch (recon_batch) and some output qy.
    • Computes the loss between the reconstructed batch and the original data using a loss_function. The loss is then added to the test_loss variable, scaled by the batch size.
    • Additionally, the temperature (temp) is annealed: if i % 100 == 1, the temperature is reduced, mirroring the annealing schedule used during training.
  • Finally, if i == 0 (i.e. for the first batch), we visualize the reconstruction by concatenating the original and reconstructed images and saving them with the save_image function.

Python3

def test(epoch, model, test_loader, temp, cuda=True, hard=False):
  model.eval()
  test_loss = 0
 
  for i, (data, _) in enumerate(test_loader):
    data = data.to(device)
 
    recon_batch, qy = model(data, temp, hard)
    test_loss += loss_function(recon_batch, data, qy).item() * len(data)
 
    if i % 100 == 1:
        temp = np.maximum(temp * np.exp(-ANNEAL_RATE * i), temp_min)
 
    if i == 0:
        n = min(data.size(0), 8)
        comparison = torch.cat([data[:n], recon_batch.view(-1, 1, 28, 28)[:n]])
        save_image(comparison.data.to(device),f"./reconstruction_{epoch:03d}.png", nrow=n)
 
  test_loss /= len(test_loader.dataset)
  print('====> Test set loss: {:.4f}'.format(test_loss))

                    

8. Training

We run the train and test functions for 10 epochs.

Python3

epochs = 10
prec = math.ceil(math.log10(epochs / 100))
 
 
model = VAE_gumbel(temp)
model.to(device)
 
optimizer = optim.Adam(model.parameters(), lr=1e-3)
 
for epoch in range(1, epochs + 1):
    train(epoch, model, train_loader, optimizer, temp, True)
    test(epoch, model, test_loader, temp, True)

                    

Output:

Train Epoch: 1 [0/60000 (0%)]    Loss: 543.420044
Train Epoch: 1 [12800/60000 (21%)] Loss: 203.588531
Train Epoch: 1 [25600/60000 (43%)] Loss: 200.128128
Train Epoch: 1 [38400/60000 (64%)] Loss: 193.303040
Train Epoch: 1 [51200/60000 (85%)] Loss: 186.463058
====> Epoch: 1 Average loss: 202.1153
====> Test set loss: 178.8977
Train Epoch: 2 [0/60000 (0%)] Loss: 178.230606
Train Epoch: 2 [12800/60000 (21%)] Loss: 168.383270
Train Epoch: 2 [25600/60000 (43%)] Loss: 149.937363
Train Epoch: 2 [38400/60000 (64%)] Loss: 144.123688
Train Epoch: 2 [51200/60000 (85%)] Loss: 146.283325
====> Epoch: 2 Average loss: 156.7775
====> Test set loss: 143.5305
Train Epoch: 3 [0/60000 (0%)] Loss: 145.444778
Train Epoch: 3 [12800/60000 (21%)] Loss: 141.762360

Conclusion

In this article, we saw how different methods can be used to make argmax differentiable. Making the argmax operation differentiable allows gradients to flow during backpropagation. Three commonly used methods are the Straight-Through Estimator, Gumbel-Softmax, and custom operations:

Straight-Through Estimator (STE): The Straight-Through Estimator is a simple and intuitive method for making the argmax operation differentiable; it passes the incoming gradient through unchanged.

Gumbel-Softmax: The Gumbel-Softmax method uses the reparameterization trick to introduce Gumbel noise, making sampling differentiable, and uses softmax to make the argmax differentiable. We saw a detailed implementation of this using a VAE.

Custom Operations: We can create custom, differentiable operations that approximate the argmax operation in a way that allows gradients to flow through.


