PyTorch Style Transfer

Introduction

Neural Style Transfer is one of the most fascinating applications of deep learning in computer vision. It allows us to take two images - a content image (like a photograph) and a style image (like a painting) - and blend them together so the output looks like the content image painted in the style of the style image.

In this tutorial, we'll learn how to implement Neural Style Transfer using PyTorch. The technique was first introduced in the paper "A Neural Algorithm of Artistic Style" by Gatys et al., and has since become a popular application of Convolutional Neural Networks (CNNs).

[Style transfer example image would be displayed here]

Prerequisites

Before diving into style transfer, make sure you're familiar with:

  • Basic Python programming
  • PyTorch fundamentals
  • Convolutional Neural Networks (CNNs)
  • Gradient descent optimization

Understanding Neural Style Transfer

Neural Style Transfer works by defining and optimizing two different types of losses:

  1. Content Loss: Ensures the generated image has the same content as the content image
  2. Style Loss: Makes the generated image adopt the artistic style of the style image

We'll be using a pre-trained CNN (typically VGG19) and manipulating its intermediate features to achieve our goal.
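
Concretely, the generated image is obtained by minimizing a weighted sum of the two losses:

total_loss = content_weight * content_loss + style_weight * style_loss

A larger style-to-content weight ratio yields a more heavily stylized result; these weights appear later as the style_weight and content_weight arguments of run_style_transfer.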

Implementation

Let's walk through the step-by-step implementation of Neural Style Transfer in PyTorch:

Step 1: Import the necessary libraries

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
import copy

Step 2: Set up the device and other configurations

python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
imsize = 512 if torch.cuda.is_available() else 128 # Use smaller size if no GPU

loader = transforms.Compose([
    transforms.Resize(imsize),  # scale the imported image
    transforms.ToTensor()])     # transform it into a tensor
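
Note that transforms.Resize with a single integer scales the shorter edge only, so a content image and a style image with different aspect ratios will come out with different shapes. If your two images don't already match, forcing an exact size is a simple workaround:

python
# Optional: force both images to exactly the same shape, regardless of aspect ratio
loader = transforms.Compose([
    transforms.Resize((imsize, imsize)),  # a (height, width) tuple forces an exact size
    transforms.ToTensor()])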

Step 3: Function to load images

python
def image_loader(image_name):
    image = Image.open(image_name)
    # Add a fake batch dimension required to fit the network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)

# Load images
content_img = image_loader('content_image.jpg')
style_img = image_loader('style_image.jpg')

# Create a white noise image as the starting point for our generated image
input_img = torch.randn(content_img.data.size(), device=device)

# Display images
def imshow(tensor, title=None):
    image = tensor.cpu().clone()  # clone the tensor so we don't change it
    image = image.squeeze(0)      # remove the fake batch dimension
    unloader = transforms.ToPILImage()
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
imshow(style_img, title='Style Image')
plt.figure()
imshow(content_img, title='Content Image')
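
The content and style losses defined below compare feature maps element-wise, so both images must come out of the loader with identical shapes. A quick sanity check right after loading saves a confusing shape-mismatch error later:

python
# Both tensors must have the same shape for the losses defined below
assert content_img.size() == style_img.size(), \
    "content and style images must be the same size"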

Step 4: Define the loss functions

These losses operate on feature maps extracted from the pre-trained VGG19 model we'll load in the next step:

python
class ContentLoss(nn.Module):
    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # We 'detach' the target content from the graph used to compute the
        # gradient, so it is treated as a constant rather than a variable
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input

def gram_matrix(input):
    batch_size, f_maps, h, w = input.size()  # batch size, feature maps, height, width
    features = input.view(batch_size * f_maps, h * w)
    G = torch.mm(features, features.t())  # compute the Gram matrix
    # Normalize the Gram matrix by the total number of elements
    return G.div(batch_size * f_maps * h * w)

class StyleLoss(nn.Module):
    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input
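
To build intuition for gram_matrix, here's a quick standalone check on a small random tensor (the shapes are arbitrary, chosen only for illustration):

python
x = torch.randn(1, 3, 4, 4)  # 1 image, 3 feature maps, each 4x4
G = gram_matrix(x)
print(G.shape)  # torch.Size([3, 3]): one correlation value per pair of feature maps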

Step 5: Import and set up the model

python
# VGG networks are trained on images with each channel normalized
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # Reshape to [C, 1, 1] so the stats broadcast over [B, C, H, W] image tensors
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)

    def forward(self, img):
        return (img - self.mean) / self.std

# Import the model
cnn = models.vgg19(pretrained=True).features.to(device).eval()
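
On recent torchvision releases (0.13 and later), the pretrained flag is deprecated in favor of the weights argument; the equivalent call is:

python
# Equivalent on torchvision >= 0.13, where pretrained=True is deprecated
cnn = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.to(device).eval()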

Step 6: Set up the model with content and style loss layers

python
# Desired layers to compute content/style losses
content_layers = ['conv_4']
style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img,
                               content_layers=content_layers,
                               style_layers=style_layers):
    normalization = Normalization(normalization_mean, normalization_std).to(device)

    # Keep iterable lists of the content/style loss modules
    content_losses = []
    style_losses = []

    # Assuming that cnn is an nn.Sequential
    model = nn.Sequential(normalization)

    i = 0  # increment every time we see a conv layer
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = f'conv_{i}'
        elif isinstance(layer, nn.ReLU):
            name = f'relu_{i}'
            # The in-place version doesn't play nicely with the ContentLoss
            # and StyleLoss we insert below, so we replace it with an out-of-place one
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = f'pool_{i}'
        elif isinstance(layer, nn.BatchNorm2d):
            name = f'bn_{i}'
        else:
            raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')

        model.add_module(name, layer)

        if name in content_layers:
            # Add content loss
            target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module(f"content_loss_{i}", content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            # Add style loss
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module(f"style_loss_{i}", style_loss)
            style_losses.append(style_loss)

    # Now we trim off the layers after the last content and style losses
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], (ContentLoss, StyleLoss)):
            break

    model = model[:(i + 1)]

    return model, style_losses, content_losses
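
If you're curious where the loss modules land, you can build the model once and print it; the trimmed nn.Sequential lists every named layer in order, including the inserted content_loss_* and style_loss_* modules:

python
model, style_losses, content_losses = get_style_model_and_losses(
    cnn, cnn_normalization_mean, cnn_normalization_std, style_img, content_img)
print(model)  # inspect the trimmed network and the inserted loss layers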

Step 7: Define the optimization process

python
def get_input_optimizer(input_img):
    # Use L-BFGS as the optimizer, since we're optimizing the image itself
    optimizer = optim.LBFGS([input_img])
    return optimizer

def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_img, input_img, num_steps=300,
                       style_weight=1000000, content_weight=1):
    """Run the style transfer."""
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, normalization_mean, normalization_std, style_img, content_img)

    # We want to optimize the input image, not the model parameters
    input_img.requires_grad_(True)
    # We also put the model in evaluation mode and freeze its parameters
    model.eval()
    model.requires_grad_(False)

    optimizer = get_input_optimizer(input_img)

    print('Optimizing..')
    run = [0]
    while run[0] <= num_steps:

        def closure():
            # Correct the values of the updated input image
            with torch.no_grad():
                input_img.clamp_(0, 1)

            optimizer.zero_grad()
            model(input_img)

            style_score = 0
            content_score = 0

            for sl in style_losses:
                style_score += sl.loss
            for cl in content_losses:
                content_score += cl.loss

            style_score *= style_weight
            content_score *= content_weight

            loss = style_score + content_score
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                print(f"run {run[0]}:")
                print(f'Style Loss : {style_score.item():4f} Content Loss: {content_score.item():4f}')
                print()

            return style_score + content_score

        optimizer.step(closure)

    # Final correction
    with torch.no_grad():
        input_img.clamp_(0, 1)

    return input_img
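
L-BFGS converges in relatively few iterations here, but if the closure-based API feels opaque, a plain Adam loop also works. The sketch below assumes the same setup as above (model, style_losses, content_losses, and input_img already built); the learning rate, step count, and loss weights are rough values to tune, not the tutorial's defaults:

python
# Alternative optimization loop with Adam (a sketch; lr and steps are assumptions)
input_img.requires_grad_(True)
optimizer = optim.Adam([input_img], lr=0.02)
for step in range(500):
    with torch.no_grad():
        input_img.clamp_(0, 1)  # keep pixel values in a displayable range
    optimizer.zero_grad()
    model(input_img)  # the forward pass populates each loss module's .loss
    loss = 1e6 * sum(sl.loss for sl in style_losses) \
         + sum(cl.loss for cl in content_losses)
    loss.backward()
    optimizer.step()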

Step 8: Run the optimization and display results

python
output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                            content_img, style_img, input_img)

plt.figure()
imshow(output, title='Output Image')

# Save the image
output_img = output.cpu().clone()
output_img = output_img.squeeze(0) # remove the fake batch dimension
unloader = transforms.ToPILImage()
output_img = unloader(output_img)
output_img.save('output.jpg')

plt.show()

Example Results

Here's an example of the neural style transfer in action:

Content Image (Photograph of a city):

[Content image would be displayed here]

Style Image (Van Gogh's Starry Night):

[Style image would be displayed here]

Result after Style Transfer:

[Result image would be displayed here]

Understanding the Code Components

1. Content Loss

The content loss ensures that the high-level features in the generated image match those in the content image. We:

  • Extract feature maps from specific layers of the VGG network
  • Compare the feature maps of the content image with those of the generated image
  • Minimize the mean squared error between them

2. Style Loss

The style loss ensures that the generated image captures the artistic style of the style image. We:

  • Compute Gram matrices of feature maps from specific layers; the Gram matrix captures texture information by measuring correlations between feature maps
  • Compare the Gram matrices from the style image with those from the generated image
  • Minimize the mean squared error between them

3. Optimization Process

We optimize the pixel values of the generated image by:

  • Starting with random noise or the content image
  • Feeding it through the VGG network
  • Computing content and style losses
  • Using backpropagation to update pixel values
  • Balancing content and style with weights (style_weight and content_weight)

Customizing Style Transfer

Here are some ways to experiment with and customize style transfer:

Adjusting the Weights

python
# Give more importance to content preservation
output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                            content_img, style_img, input_img,
                            style_weight=100000, content_weight=10)

# Give more importance to style transfer
output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                            content_img, style_img, input_img,
                            style_weight=10000000, content_weight=0.1)

Using Different Layers

You can adjust which layers are used for content and style computation:

python
# For more abstract content representation
content_layers = ['conv_5']

# For more fine-grained style details
style_layers = ['conv_1', 'conv_2', 'conv_3']

Using Different Starting Images

Instead of random noise, you can start with the content image:

python
input_img = content_img.clone()
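
Starting from the content image usually converges faster and preserves structure better than starting from noise. An optional middle ground (a tweak beyond the original paper) is to add mild noise to the content image; the 0.1 noise scale here is arbitrary and worth experimenting with:

python
# Content image plus mild noise as the starting point (noise scale is a guess)
input_img = (content_img.clone() + 0.1 * torch.randn_like(content_img)).clamp_(0, 1)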

Real-World Applications

Neural Style Transfer has several practical applications:

  1. Art and Design: Creating unique artworks by applying famous painting styles to photographs
  2. Film and Entertainment: Stylizing movie scenes or video game environments
  3. Mobile Apps: Photo filters that apply artistic styles (like Prisma app)
  4. Fashion: Designing new patterns and textures for clothing
  5. Advertising: Creating visually striking marketing materials by combining brand imagery with artistic styles

Summary

In this tutorial, we've learned how to implement Neural Style Transfer using PyTorch. We've covered:

  • The theory behind style transfer using CNNs
  • How to extract and manipulate features from a pre-trained VGG network
  • Computing content and style losses to blend images
  • Optimizing an image to achieve the desired stylistic effect
  • Customizing the process for different results

Neural Style Transfer is not only a fascinating application of deep learning but also a window into understanding how CNNs represent and process visual information.

Additional Resources

  • "A Neural Algorithm of Artistic Style" by Gatys et al. (2015): https://arxiv.org/abs/1508.06576
  • The official PyTorch neural style transfer tutorial, which this implementation closely follows: https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

Exercises

  1. Try applying style transfer with your own content and style images
  2. Experiment with different weight values for content and style losses
  3. Try using different layers of the VGG network for content and style computation
  4. Implement a version that works on video by maintaining consistency between frames
  5. Create a simple web app that allows users to upload images and apply different styles
  6. Research and implement the fast neural style transfer approach that uses a feed-forward network

Happy styling!


