PyTorch Loss Functions

In neural network training, loss functions play a critical role in guiding the learning process. A loss function (sometimes called a cost function or objective function) measures how well your model's predictions match the actual target values. In this tutorial, we'll explore the common loss functions available in PyTorch and learn how to implement them in your neural network projects.

Introduction to Loss Functions

Loss functions quantify the "error" between predicted values and actual values. During training, the optimization process aims to minimize this error by adjusting the network's weights and biases. Think of loss functions as a compass that guides your model toward better performance.

The choice of loss function depends on your specific task:

  • For regression problems (predicting continuous values), functions like Mean Squared Error (MSE) are common
  • For classification problems (predicting categories), functions like Cross-Entropy Loss are typically used

Let's dive into how PyTorch implements these crucial components.

Common Loss Functions in PyTorch

PyTorch provides a comprehensive set of loss functions in the torch.nn module. We'll explore the most commonly used ones and see practical examples of each.

Mean Squared Error (MSE)

MSE is the go-to loss function for regression problems. It calculates the average of the squared differences between predictions and actual values.

python
import torch
import torch.nn as nn

# Create MSE loss function
criterion = nn.MSELoss()

# Example data
predictions = torch.tensor([0.5, 1.8, 2.2], dtype=torch.float32)
targets = torch.tensor([1.0, 2.0, 2.0], dtype=torch.float32)

# Calculate loss
loss = criterion(predictions, targets)

print(f"Predictions: {predictions}")
print(f"Targets: {targets}")
print(f"MSE Loss: {loss.item()}")

Output:

Predictions: tensor([0.5000, 1.8000, 2.2000])
Targets: tensor([1.0000, 2.0000, 2.0000])
MSE Loss: 0.1100

The MSE formula is:

MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2

Where:

  • y_i is the actual target value
  • \hat{y}_i is the predicted value
  • n is the number of samples
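
To connect the formula to the example above, here is a quick sketch that computes the same value by hand with basic tensor operations; it should match the value reported by nn.MSELoss:

python
# Manual MSE computation, reusing the predictions and targets defined above
squared_errors = (predictions - targets) ** 2  # element-wise (y_i - y_hat_i)^2
manual_mse = squared_errors.mean()             # average over the n samples

print(f"Manual MSE: {manual_mse.item():.4f}")  # 0.1100 -- same as nn.MSELoss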

L1 Loss (Mean Absolute Error)

L1 Loss computes the mean absolute difference between the predicted and target values. Because the errors are not squared, it is less sensitive to outliers than MSE.

python
# Create L1 loss function
criterion = nn.L1Loss()

# Calculate loss
loss = criterion(predictions, targets)

print(f"L1 Loss: {loss.item()}")

Output:

L1 Loss: 0.3000

The L1 Loss formula is:

L1 = \frac{1}{n} \sum_{i=1}^{n} |y_i - \hat{y}_i|
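
For the example tensors above this gives (|0.5 - 1.0| + |1.8 - 2.0| + |2.2 - 2.0|) / 3 = 0.9 / 3 = 0.3, matching the value reported by nn.L1Loss.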

Binary Cross-Entropy Loss

Binary Cross-Entropy Loss is used for binary classification problems (when there are only two classes). It measures the performance of a classification model whose output is a probability value between 0 and 1.

python
# Create BCE loss function
criterion = nn.BCELoss()

# Example data (must be between 0 and 1)
predictions = torch.tensor([0.7, 0.2, 0.9], dtype=torch.float32)
targets = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float32)

# Calculate loss
loss = criterion(predictions, targets)

print(f"Predictions: {predictions}")
print(f"Targets: {targets}")
print(f"BCE Loss: {loss.item()}")

Output:

Predictions: tensor([0.7000, 0.2000, 0.9000])
Targets: tensor([1.0000, 0.0000, 1.0000])
BCE Loss: 0.2284
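
The BCE formula is:

BCE = -\frac{1}{n} \sum_{i=1}^{n} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right]

Where \hat{y}_i is the predicted probability and y_i \in \{0, 1\} is the true label.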

Note that for BCELoss, your model must output values between 0 and 1 (typically by applying a sigmoid activation). For convenience and better numerical stability, PyTorch also provides BCEWithLogitsLoss, which combines the sigmoid activation and BCE in a single function:

python
# BCEWithLogitsLoss includes the sigmoid function
criterion = nn.BCEWithLogitsLoss()

# Example data (can be any real number)
logits = torch.tensor([0.5, -1.0, 2.0], dtype=torch.float32)
targets = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float32)

# Calculate loss
loss = criterion(logits, targets)

print(f"Logits: {logits}")
print(f"Targets: {targets}")
print(f"BCE With Logits Loss: {loss.item()}")

Output:

Logits: tensor([0.5000, -1.0000, 2.0000])
Targets: tensor([1.0000, 0.0000, 1.0000])
BCE With Logits Loss: 0.3048
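
As a quick check on the claim that BCEWithLogitsLoss is simply sigmoid followed by BCE, applying torch.sigmoid to the logits and feeding the result to BCELoss gives the same value (up to floating-point rounding):

python
# Equivalent computation: sigmoid first, then BCELoss
probs = torch.sigmoid(logits)
bce = nn.BCELoss()

print(f"Sigmoid + BCE Loss: {bce(probs, targets).item():.4f}")  # 0.3048, same as above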

Cross-Entropy Loss (for multi-class classification)

Cross-Entropy Loss is the standard loss function for multi-class classification problems. PyTorch provides CrossEntropyLoss which combines LogSoftmax and NLLLoss in a single function.

python
# Create Cross Entropy loss function
criterion = nn.CrossEntropyLoss()

# Example data - [batch_size, num_classes]
# Raw logits (unnormalized scores)
logits = torch.tensor([[0.2, 0.6, 0.3],
                       [1.2, 0.1, 0.5],
                       [0.2, 2.3, 0.1]], dtype=torch.float32)

# Class indices (not one-hot encoded)
targets = torch.tensor([1, 0, 1], dtype=torch.long)

# Calculate loss
loss = criterion(logits, targets)

print(f"Logits:\n{logits}")
print(f"Targets: {targets}")
print(f"Cross Entropy Loss: {loss.item()}")

Output:

Logits:
tensor([[0.2000, 0.6000, 0.3000],
        [1.2000, 0.1000, 0.5000],
        [0.2000, 2.3000, 0.1000]])
Targets: tensor([1, 0, 1])
Cross Entropy Loss: 0.5646

Important note: CrossEntropyLoss expects raw logits as input (not softmax outputs), and it expects targets to be class indices (not one-hot encoded).
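
A common mistake is to apply softmax yourself before calling CrossEntropyLoss, which ends up applying softmax twice. A quick sketch, reusing the logits and targets above, shows that the value changes:

python
import torch.nn.functional as F

# Wrong: softmax is applied here and then again inside CrossEntropyLoss
probs = F.softmax(logits, dim=1)
wrong_loss = criterion(probs, targets)

print(f"Loss on softmaxed inputs: {wrong_loss.item():.4f}")  # differs from the correct 0.5646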

Negative Log-Likelihood Loss (NLL Loss)

NLL Loss is used when your model outputs log-probabilities. It's often used after a LogSoftmax layer.

python
# Create NLL loss function
criterion = nn.NLLLoss()

# Log probabilities (typically from a LogSoftmax layer)
log_probs = torch.tensor([[-0.5, -1.2, -0.8],
                          [-0.3, -2.1, -0.7],
                          [-2.3, -0.1, -1.8]], dtype=torch.float32)

# Class indices
targets = torch.tensor([0, 2, 1], dtype=torch.long)

# Calculate loss
loss = criterion(log_probs, targets)

print(f"Log Probabilities:\n{log_probs}")
print(f"Targets: {targets}")
print(f"NLL Loss: {loss.item()}")

Output:

Log Probabilities:
tensor([[-0.5000, -1.2000, -0.8000],
        [-0.3000, -2.1000, -0.7000],
        [-2.3000, -0.1000, -1.8000]])
Targets: tensor([0, 2, 1])
NLL Loss: 0.4333
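
To make the relationship to CrossEntropyLoss explicit, here is a short sketch: applying log_softmax to raw logits and then NLLLoss produces the same value as CrossEntropyLoss applied directly to the logits.

python
import torch.nn.functional as F

# Same logits and targets as in the Cross-Entropy example above
raw_logits = torch.tensor([[0.2, 0.6, 0.3],
                           [1.2, 0.1, 0.5],
                           [0.2, 2.3, 0.1]], dtype=torch.float32)
class_targets = torch.tensor([1, 0, 1], dtype=torch.long)

nll_result = nn.NLLLoss()(F.log_softmax(raw_logits, dim=1), class_targets)
ce_result = nn.CrossEntropyLoss()(raw_logits, class_targets)

print(f"LogSoftmax + NLLLoss: {nll_result.item():.4f}")  # 0.5646
print(f"CrossEntropyLoss:     {ce_result.item():.4f}")   # 0.5646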

Implementing Loss Functions in Neural Networks

Now let's see how loss functions fit into a complete neural network training loop. We'll create a simple model to classify the MNIST digits:

python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define a simple CNN model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # No activation here - CrossEntropyLoss expects raw logits
        return x

# Example of training loop (not executed)
def train_example():
    # Set up transformation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load training data
    train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                               download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

    # Initialize model, loss function and optimizer
    model = SimpleModel()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    num_epochs = 2
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)

            # Calculate loss
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Print statistics
            running_loss += loss.item()
            if (i+1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {running_loss/100:.4f}')
                running_loss = 0.0

    print('Finished Training')

# Note: This function is not executed in this tutorial
# train_example()

This example shows the standard pattern:

  1. Define your model architecture
  2. Choose an appropriate loss function
  3. Set up an optimizer
  4. In each training iteration:
    • Forward pass through the model to get predictions
    • Calculate loss between predictions and targets
    • Backpropagate to compute gradients
    • Update model parameters using the optimizer

Custom Loss Functions

Sometimes, you might need a loss function that isn't available in PyTorch. You can create your own by subclassing nn.Module:

python
class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        # Binary case
        bce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # Apply focal scaling
        pt = torch.exp(-bce_loss)  # pt is the probability of being correct
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss

        return torch.mean(focal_loss)

# Example usage
focal_criterion = FocalLoss(gamma=2.0)
logits = torch.tensor([0.5, -1.0, 2.0], dtype=torch.float32)
targets = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float32)
loss = focal_criterion(logits, targets)
print(f"Focal Loss: {loss.item()}")

Output:

Focal Loss: 0.0077

The Focal Loss is particularly useful for imbalanced classification problems, as it puts more emphasis on hard-to-classify examples.
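
To see this emphasis in action, compare the loss for an "easy" example (a confident, correct prediction) with a "hard" one (a near-chance prediction). This is only a rough sketch with made-up logits, but the (1 - pt)^gamma factor shrinks the easy example's contribution far more aggressively than plain BCE does:

python
# Easy example: large logit with the correct label; hard example: logit near zero
easy_logit, easy_target = torch.tensor([4.0]), torch.tensor([1.0])
hard_logit, hard_target = torch.tensor([0.1]), torch.tensor([1.0])

bce = nn.BCEWithLogitsLoss()
print(f"BCE   - easy: {bce(easy_logit, easy_target).item():.4f}, "
      f"hard: {bce(hard_logit, hard_target).item():.4f}")
print(f"Focal - easy: {focal_criterion(easy_logit, easy_target).item():.4f}, "
      f"hard: {focal_criterion(hard_logit, hard_target).item():.4f}")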

Choosing the Right Loss Function

Selecting the appropriate loss function depends on your specific task:

  • Regression: MSE, L1Loss, SmoothL1Loss
  • Binary Classification: BCELoss, BCEWithLogitsLoss
  • Multi-class Classification: CrossEntropyLoss
  • Imbalanced Classification: Weighted CrossEntropyLoss, Focal Loss
  • Generative Models: Custom losses (e.g., Adversarial Loss)
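
SmoothL1Loss appears in the list above but isn't covered earlier in this tutorial. As a brief sketch: it behaves quadratically (like MSE) for small errors and linearly (like L1) for large ones, which makes it a useful middle ground for regression data with occasional outliers:

python
# SmoothL1Loss (Huber-style): quadratic for small errors, linear for large ones
criterion = nn.SmoothL1Loss()
predictions = torch.tensor([0.5, 1.8, 10.0], dtype=torch.float32)  # last value is an outlier
targets = torch.tensor([1.0, 2.0, 2.0], dtype=torch.float32)

print(f"Smooth L1 Loss: {criterion(predictions, targets).item():.4f}")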

Consider these factors when choosing:

  1. Type of problem: Classification vs. regression
  2. Data distribution: Balanced vs. imbalanced classes
  3. Outlier sensitivity: MSE is sensitive to outliers, L1 is more robust
  4. Mathematical properties: Some losses have better optimization characteristics
  5. Implementation details: Some losses include the activation function internally (e.g., BCEWithLogitsLoss, CrossEntropyLoss)

Loss Function Weighting

Sometimes certain examples or classes should have more influence on training. PyTorch allows for weighted loss functions:

python
# Weighted Cross Entropy - giving more importance to some classes
# Create weights for classes 0-9 (high weight = more important)
weights = torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0])
criterion = nn.CrossEntropyLoss(weight=weights)

# Example prediction and target
logits = torch.tensor([[0.2, 0.6, 0.3, 0.1, 0.2, 0.1, 0.2, 0.1, 0.1, 0.1]], dtype=torch.float32)
target = torch.tensor([1], dtype=torch.long) # class 1, which has weight 2.0

loss = criterion(logits, target)
print(f"Weighted Loss: {loss.item()}")

Output:

Weighted Loss: 1.9146
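
Note that with the default reduction='mean', CrossEntropyLoss divides the weighted sum of per-sample losses by the sum of the weights of the target classes in the batch, so with a single sample the weight cancels out and the value above equals the unweighted loss. The effect of the weights shows up when a batch mixes classes with different weights, as in this sketch (the second row of logits is made up for illustration):

python
# Two samples: one from class 1 (weight 2.0) and one from class 0 (weight 1.0)
batch_logits = torch.tensor([[0.2, 0.6, 0.3, 0.1, 0.2, 0.1, 0.2, 0.1, 0.1, 0.1],
                             [0.9, 0.1, 0.3, 0.1, 0.2, 0.1, 0.2, 0.1, 0.1, 0.1]], dtype=torch.float32)
batch_targets = torch.tensor([1, 0], dtype=torch.long)

weighted = nn.CrossEntropyLoss(weight=weights)  # same class weights as above
unweighted = nn.CrossEntropyLoss()
print(f"Weighted loss:   {weighted(batch_logits, batch_targets).item():.4f}")
print(f"Unweighted loss: {unweighted(batch_logits, batch_targets).item():.4f}")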

Summary

Loss functions are the compass that guides neural network training. In this tutorial, we've covered:

  1. Common PyTorch Loss Functions: Including MSE, L1, Cross-Entropy, and NLL losses
  2. Implementation in Neural Networks: How to incorporate loss functions into a training loop
  3. Custom Loss Functions: Creating your own loss functions when needed
  4. Loss Function Selection: Guidelines for choosing the right loss for your task
  5. Loss Weighting: How to give more importance to certain classes or examples

Understanding loss functions is critical for effective neural network training. They define what your model is trying to optimize and directly impact the behavior and performance of your trained model.

Exercises

  1. Implement a neural network for a regression problem using MSE loss
  2. Compare the training behavior using different loss functions (MSE vs. L1) on the same dataset
  3. Create a custom loss function that combines aspects of MSE and L1 loss
  4. Modify the MNIST example to use weighted Cross-Entropy loss to emphasize certain digits
  5. Experiment with different gamma values in Focal Loss and observe how they affect training on an imbalanced dataset

