How is PyTorch’s cross_entropy function related to softmax, log_softmax, and NLL?
This notebook breaks down how the `cross_entropy` function is implemented in PyTorch, and how it is related to softmax, log_softmax, and NLL (negative log-likelihood).
Link to notebook:
import torch
import torch.nn as nn
import torch.nn.functional as F
Simulated `x` variable:
batch_size, n_classes = 5, 3
x = torch.randn(batch_size, n_classes)
x.shape
Out:
torch.Size([5, 3])
Run:
x
Out:
tensor([[ 0.9826, 1.0630, -0.4096],
[-0.6213, 0.2511, 0.5659],
[ 0.5662, 0.7360, -0.6783],
[-0.4638, -1.4961, -1.0877],
[ 1.8186, -0.2998, 0.1128]])
Simulated `y` variable:
target = torch.randint(n_classes, size=(batch_size,), dtype=torch.long)
target
Out:
tensor([1, 0, 1, 1, 1])
softmax + nl (negative likelihood)
This version is closest to the math formula, but it is not numerically stable.
def softmax(x): return x.exp() / (x.exp().sum(-1)).unsqueeze(-1)
# Pick out each sample's predicted probability for its true class, then average the negative logs.
def nl(input, target): return -input[range(target.shape[0]), target].log().mean()
pred = softmax(x)
loss = nl(pred, target)
loss
Out:
tensor(1.4904)
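To see why the naive softmax above is numerically unstable, feed it large logits: x.exp() overflows to inf and the division produces nan, while PyTorch's log-space functions stay finite. A small sketch (big_x is a made-up tensor for illustration):
big_x = torch.tensor([[1000., 1000., 1000.]])
softmax(big_x)                 # exp(1000) overflows to inf, so the result is nan
F.log_softmax(big_x, dim=-1)   # stays finite: each entry is log(1/3) ≈ -1.0986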
log_softmax + nll (negative log-likelihood)
https://pytorch.org/docs/stable/nn.html?highlight=logsoftmax#torch-nn-functional:
While mathematically equivalent to log(softmax(x)), doing these two operations separately is slower and numerically unstable. This function uses an alternative formulation to compute the output and gradient correctly.
Run:
def log_softmax(x): return x - x.exp().sum(-1).log().unsqueeze(-1)
# The log has already been applied by log_softmax, so nll only indexes and negates.
def nll(input, target): return -input[range(target.shape[0]), target].mean()
pred = log_softmax(x)
loss = nll(pred, target)
loss
Out:
tensor(1.4904)
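The hand-rolled log_softmax above still calls x.exp() on the raw logits, so it can overflow just like the naive softmax. One common numerically stable formulation (a sketch of the log-sum-exp trick; not necessarily the exact formulation PyTorch uses internally) subtracts the per-row maximum before exponentiating:
def stable_log_softmax(x):
    # Subtracting the row max leaves the result unchanged mathematically,
    # but keeps the exponentials in a safe range.
    shifted = x - x.max(-1, keepdim=True).values
    return shifted - shifted.exp().sum(-1, keepdim=True).log()

nll(stable_log_softmax(x), target)  # matches the loss values above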
F.log_softmax + F.nll_loss
The same computation, but using PyTorch's built-in functions.
pred = F.log_softmax(x, dim=-1)
loss = F.nll_loss(pred, target)
loss
Out:
tensor(1.4904)
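Note that F.nll_loss expects log-probabilities: it only picks out the target entries and negates them, with no log applied internally. A sketch of a common mistake (not from the original notebook), passing probabilities instead of log-probabilities:
F.nll_loss(F.softmax(x, dim=-1), target)      # wrong: silently returns a different value
F.nll_loss(F.log_softmax(x, dim=-1), target)  # correct: same 1.4904 as above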
F.cross_entropy
PyTorch's single cross_entropy function, which combines log_softmax and nll_loss in one call.
F.cross_entropy(x, target)
Out:
tensor(1.4904)
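As a final check (an illustrative snippet, not part of the original notebook), F.cross_entropy really is log_softmax followed by nll_loss:
torch.allclose(F.cross_entropy(x, target),
               F.nll_loss(F.log_softmax(x, dim=-1), target))  # returns True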
Reference:
For more details on the implementation of the functions above, see here for a side-by-side translation of all of PyTorch’s built-in loss functions to Python and NumPy.