How is PyTorch's Cross Entropy function related to softmax, log softmax, and NLL

Yang Zhang
Oct 11, 2018


This notebook breaks down how the `cross_entropy` function is implemented in PyTorch, and how it relates to softmax, log_softmax, and NLL (negative log-likelihood).

Link to notebook:

import torch
import torch.nn as nn
import torch.nn.functional as F

Simulated x variable:

batch_size, n_classes = 5, 3
x = torch.randn(batch_size, n_classes)
x.shape

Out:

torch.Size([5, 3])

Run:

x

Out:

tensor([[ 0.9826,  1.0630, -0.4096],
        [-0.6213,  0.2511,  0.5659],
        [ 0.5662,  0.7360, -0.6783],
        [-0.4638, -1.4961, -1.0877],
        [ 1.8186, -0.2998,  0.1128]])

Simulated y variable:

target = torch.randint(n_classes, size=(batch_size,), dtype=torch.long)
target

Out:

tensor([1, 0, 1, 1, 1])

softmax + nl (negative likelihood)

This version is most similar to the math formula, but not numerically stable.
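
Concretely, writing x_i for the logits of example i and y_i for its target class (notation added here for clarity, not from the original post), this computes the average negative log-probability of the correct class:

loss = -1/N * Σ_i log( softmax(x_i)[y_i] )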

# softmax: exponentiate, then normalize each row by its sum
def softmax(x): return x.exp() / (x.exp().sum(-1)).unsqueeze(-1)
# negative likelihood: pick each example's predicted probability of its target class, take the log, negate, and average
def nl(input, target): return -input[range(target.shape[0]), target].log().mean()

pred = softmax(x)
loss = nl(pred, target)
loss

Out:

tensor(1.4904)
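
To see the instability concretely (this example is an addition, not part of the original notebook): scaling the logits up makes x.exp() overflow to inf, so the naive version returns nan while the built-in stays finite.

big_x = x * 1000                  # exaggerated logits, purely illustrative
nl(softmax(big_x), target)        # exp() overflows to inf, the loss comes out as nan
F.cross_entropy(big_x, target)    # the numerically stable built-in stays finite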

log_softmax + nll (negative log-likelihood)

From the PyTorch docs on log_softmax (https://pytorch.org/docs/stable/nn.html?highlight=logsoftmax#torch-nn-functional):

While mathematically equivalent to log(softmax(x)), doing these two operations separately is slower, and numerically unstable. This function uses an alternative formulation to compute the output and gradient correctly.

Run:

# log softmax: log(exp(x) / sum(exp(x))) simplifies to x - log(sum(exp(x)))
def log_softmax(x): return x - x.exp().sum(-1).log().unsqueeze(-1)
# negative log-likelihood: the input is already log-probabilities, so no .log() here
def nll(input, target): return -input[range(target.shape[0]), target].mean()

pred = log_softmax(x)
loss = nll(pred, target)
loss

Out:

tensor(1.4904)
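
The usual way to get this stability by hand (a sketch added here, not from the original post) is the log-sum-exp trick: subtract the row-wise maximum before exponentiating. The result is mathematically unchanged, but exp() can no longer overflow.

def log_softmax_stable(x):
    # subtract the per-row max so the largest exponent is exp(0) = 1
    x = x - x.max(-1, keepdim=True).values
    return x - x.exp().sum(-1).log().unsqueeze(-1)

nll(log_softmax_stable(x), target)  # should match the value above, tensor(1.4904)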

F.log_softmax + F.nll_loss

The same computation as above, but using PyTorch's built-in functions.

pred = F.log_softmax(x, dim=-1)
loss = F.nll_loss(pred, target)
loss

Out:

tensor(1.4904)
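
As a quick sanity check (added here, not in the original post), the handwritten versions agree with the built-ins:

torch.allclose(log_softmax(x), F.log_softmax(x, dim=-1))  # True
torch.allclose(nll(log_softmax(x), target), F.nll_loss(F.log_softmax(x, dim=-1), target))  # True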

F.cross_entropy

PyTorch's `cross_entropy` does it all in a single function, combining log_softmax and nll_loss.

F.cross_entropy(x, target)

Out:

tensor(1.4904)
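
For completeness (a small addition to the original walkthrough): nn.CrossEntropyLoss is the module form of the same computation, typically used as the criterion in a training loop.

loss_fn = nn.CrossEntropyLoss()
loss_fn(x, target)  # tensor(1.4904)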

Reference:

For more details on the implementation of the functions above, see here for a side-by-side translation of all of PyTorch's built-in loss functions to Python and NumPy.
