PyTorch Tensor Indexing

Introduction

When working with PyTorch tensors, you often need to access specific elements, rows, columns, or subsets of data. This process, known as indexing, is fundamental for data manipulation in deep learning applications. In this tutorial, we'll explore the various ways to index and slice PyTorch tensors—skills that will help you efficiently manipulate data for your machine learning models.

PyTorch's indexing syntax is largely based on NumPy's, so if you're familiar with NumPy arrays, you'll find many similarities. However, there are some PyTorch-specific features worth learning.
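One such difference is the slice step. NumPy accepts a negative step such as a[::-1], but PyTorch slicing requires a positive step and raises a ValueError otherwise. Here is a minimal sketch of that check and of reversing with torch.flip instead:

```python
import torch

x = torch.arange(5)

# NumPy-style reversal with a negative step is rejected by PyTorch
try:
    x[::-1]
    negative_step_ok = True
except ValueError as err:
    negative_step_ok = False
    print(f"Negative step rejected: {err}")

# torch.flip returns a reversed copy along the listed dimensions
reversed_x = torch.flip(x, dims=[0])
print(reversed_x)  # tensor([4, 3, 2, 1, 0])
```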

Basic Tensor Indexing

Let's start by creating a tensor and accessing individual elements.

python
import torch

# Create a 1D tensor
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
print(f"Original tensor: {x}")

# Access single elements
print(f"First element: {x[0]}")
print(f"Last element: {x[-1]}")
print(f"Third element: {x[2]}")

Output:

Original tensor: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
First element: 0
Last element: 9
Third element: 2
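Calling print(x[2]) directly would show tensor(2). Indexing a single element of a 1D tensor returns a zero-dimensional tensor, not a Python number. The short sketch below checks that and converts the value with .item():

```python
import torch

x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

elem = x[2]
print(type(elem), elem.dim())  # <class 'torch.Tensor'> 0

# .item() copies the single value out as a plain Python scalar
value = elem.item()
print(type(value), value)  # <class 'int'> 2
```

Note that .item() only works on tensors that hold exactly one element.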

For multi-dimensional tensors, we can specify indices for each dimension:

python
# Create a 2D tensor (matrix)
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
print(f"Original matrix:\n{matrix}")

# Access elements
print(f"Element at row 0, column 1: {matrix[0, 1]}")
print(f"Element at row 2, column 2: {matrix[2, 2]}")

Output:

Original matrix:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
Element at row 0, column 1: 2
Element at row 2, column 2: 9
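The comma form indexes every dimension in a single operation. Chained brackets such as matrix[0][1] reach the same element but index twice: first they produce the row, then the element. The same one-index-per-dimension rule extends to 3D tensors. The cube tensor below is only an illustration:

```python
import torch

matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

a = matrix[0, 1]   # single indexing operation
b = matrix[0][1]   # row first, then element
print(a.item(), b.item())  # 2 2

# A 3D tensor takes one index per dimension: [depth, row, column]
cube = torch.arange(27).reshape(3, 3, 3)
print(cube[1, 2, 0].item())  # 15
```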

Slicing Tensors

Slicing allows you to extract a range of elements from a tensor. The syntax is tensor[start:end:step], where:

  • start is the starting index (inclusive, defaults to 0)
  • end is the ending index (exclusive, defaults to the size of the dimension)
  • step is the step size (defaults to 1)
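A quick sketch of those defaults follows. Omitting part of a slice is the same as writing out its default, and a Python slice object carries the same start/end/step triple. As with Python lists, an end past the dimension size is clamped rather than raising an error:

```python
import torch

x = torch.arange(10)

# Omitted parts fall back to start=0, end=len(x), step=1
explicit = x[0:len(x):1]
implicit = x[:]
print(torch.equal(explicit, implicit))  # True

# slice(start, end, step) is what the colon syntax builds
s = slice(2, 8, 3)
print(x[s])  # tensor([2, 5])

# An end beyond the dimension size is clamped
print(x[7:100])  # tensor([7, 8, 9])
```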

1D Tensor Slicing

python
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# Get elements from index 2 to 5 (exclusive)
print(f"Elements from index 2 to 5: {x[2:5]}")

# Get every second element
print(f"Every second element: {x[::2]}")

# Get elements from index 1 through 7 with step 2 (the end index 8 is exclusive)
print(f"Elements 1 to 7 with step 2: {x[1:8:2]}")

# Negative indexing (from the end)
print(f"Last 3 elements: {x[-3:]}")
print(f"All elements except the last 2: {x[:-2]}")

# Reverse the tensor (PyTorch slices do not accept negative steps, so x[::-1] raises an error)
print(f"Reversed tensor: {torch.flip(x, dims=[0])}")

Output:

Elements from index 2 to 5: tensor([2, 3, 4])
Every second element: tensor([0, 2, 4, 6, 8])
Elements 1 to 7 with step 2: tensor([1, 3, 5, 7])
Last 3 elements: tensor([7, 8, 9])
All elements except the last 2: tensor([0, 1, 2, 3, 4, 5, 6, 7])
Reversed tensor: tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
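Basic slices like the ones above do not copy data. They return views that share storage with the original tensor, so writing through a slice changes the original. The minimal sketch below shows this effect, then uses .clone() to get an independent copy:

```python
import torch

x = torch.arange(10)

view = x[2:5]   # a view into x's storage
view[0] = 100   # this write is visible through x
print(x[2].item())  # 100

copy = x[2:5].clone()  # independent copy with its own storage
copy[0] = -1           # x is not affected
print(x[2].item())  # 100
```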

2D Tensor Slicing

With multi-dimensional tensors, you can slice along each dimension separately:

python
matrix = torch.tensor([[1, 2, 3, 4],
                       [5, 6, 7, 8],
                       [9, 10, 11, 12]])
print(f"Original matrix:\n{matrix}")

# Get the first row
print(f"First row: {matrix[0]}")

# Get the first column
print(f"First column: {matrix[:, 0]}")

# Get a 2x2 submatrix (top-left corner)
print(f"Top-left 2x2 submatrix:\n{matrix[0:2, 0:2]}")

# Get rows 0 and 2, and all columns
print(f"Rows 0 and 2:\n{matrix[[0, 2]]}")

# Get alternate rows and columns
print(f"Alternate rows and columns:\n{matrix[::2, ::2]}")

Output:

Original matrix:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
First row: tensor([1, 2, 3, 4])
First column: tensor([1, 5, 9])
Top-left 2x2 submatrix:
tensor([[1, 2],
        [5, 6]])
Rows 0 and 2:
tensor([[ 1,  2,  3,  4],
        [ 9, 10, 11, 12]])
Alternate rows and columns:
tensor([[ 1,  3],
        [ 9, 11]])
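These results also differ in shape. An integer index removes its dimension, while a slice of length one keeps it. An Ellipsis (...) stands in for all the leading dimensions. The sketch below rebuilds the same 3×4 matrix with torch.arange to compare the shapes:

```python
import torch

matrix = torch.arange(1, 13).reshape(3, 4)  # same values as above

row = matrix[0]             # integer index drops dim 0 -> shape [4]
row_2d = matrix[0:1]        # length-1 slice keeps it -> shape [1, 4]
col_2d = matrix[:, 0:1]     # column kept as a [3, 1] matrix
last_col = matrix[..., -1]  # Ellipsis fills in the leading dims

print(row.shape, row_2d.shape, col_2d.shape)
print(last_col)  # tensor([ 4,  8, 12])
```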

Advanced Indexing Techniques

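Beyond basic indexing and slicing, PyTorch supports advanced indexing with boolean masks and integer index tensors, which lets you select elements by condition or by arbitrary position.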

Boolean Indexing

You can use boolean masks to filter tensor elements:

python
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# Create a boolean mask
mask = x > 5
print(f"Boolean mask: {mask}")

# Apply the mask to filter elements
filtered = x[mask]
print(f"Elements > 5: {filtered}")

# One-line filtering
print(f"Elements divisible by 2: {x[x % 2 == 0]}")
print(f"Elements between 3 and 7: {x[(x >= 3) & (x <= 7)]}")

Output:

Boolean mask: tensor([False, False, False, False, False, False,  True,  True,  True,  True])
Elements > 5: tensor([6, 7, 8, 9])
Elements divisible by 2: tensor([0, 2, 4, 6, 8])
Elements between 3 and 7: tensor([3, 4, 5, 6, 7])
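Two practical notes apply when you combine conditions. First, Python's and/or try to reduce a whole tensor to one bool and fail. Use the element-wise &, | and ~ instead, with parentheses around each comparison. Second, torch.where keeps the tensor's shape instead of dropping elements, because it picks between two values for each element. A sketch of both:

```python
import torch

x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# `and` needs a single truth value, which a 10-element tensor cannot give
try:
    (x > 2) and (x < 5)
except RuntimeError as err:
    print(f"`and` fails: {err}")

# Element-wise operators; parentheses matter because & and | bind tighter than comparisons
outside = x[(x < 2) | (x > 7)]
print(outside)  # tensor([0, 1, 8, 9])
odd = x[~(x % 2 == 0)]
print(odd)  # tensor([1, 3, 5, 7, 9])

# Keep the shape: values that fail the condition become 0
kept = torch.where(x > 5, x, torch.zeros_like(x))
print(kept)  # tensor([0, 0, 0, 0, 0, 0, 6, 7, 8, 9])
```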

For 2D tensors:

python
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# Filter rows where the second column is greater than 5
print(f"Rows where second column > 5:\n{matrix[matrix[:, 1] > 5]}")

# Create a filter matrix and apply it
filter_matrix = matrix > 4
print(f"Filter matrix:\n{filter_matrix}")
print(f"Elements > 4: {matrix[filter_matrix]}")

Output:

Rows where second column > 5:
tensor([[7, 8, 9]])
Filter matrix:
tensor([[False, False, False],
        [False,  True,  True],
        [ True,  True,  True]])
Elements > 4: tensor([5, 6, 7, 8, 9])
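Indexing with a full-shape mask always flattens the matches into a 1D tensor in row-major order. If you need to keep the 2D layout, or only want to count or locate matches, the mask itself is often more useful. A short sketch using masked_fill, sum, any and nonzero:

```python
import torch

matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
mask = matrix > 4

flat = matrix[mask]                    # 1D: tensor([5, 6, 7, 8, 9])
count = mask.sum().item()              # number of True entries: 5
zeroed = matrix.masked_fill(~mask, 0)  # same 3x3 shape, non-matches set to 0
rows_with_match = mask.any(dim=1)      # which rows contain at least one match
positions = mask.nonzero()             # (row, col) pairs of the matches

print(zeroed)
print(positions)
```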

Fancy Indexing

You can index with integer tensors (or Python lists) to select specific elements in any order:

python
x = torch.tensor([10, 20, 30, 40, 50])

# Select specific elements using an index tensor
indices = torch.tensor([0, 2, 4])
print(f"Selected elements: {x[indices]}")

# Select elements in a different order
jumbled_indices = torch.tensor([4, 0, 2, 1])
print(f"Jumbled elements: {x[jumbled_indices]}")

Output:

Selected elements: tensor([10, 30, 50])
Jumbled elements: tensor([50, 10, 30, 20])
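Unlike slicing, indexing with an integer tensor returns a new tensor rather than a view. Repeated indices are allowed, and the result takes the shape of the index tensor. A small sketch:

```python
import torch

x = torch.tensor([10, 20, 30, 40, 50])

repeated = x[torch.tensor([1, 1, 3])]
print(repeated)  # tensor([20, 20, 40])

# The output has the same shape as the index tensor (2x2 here)
grid = x[torch.tensor([[0, 1], [3, 4]])]
print(grid.shape)  # torch.Size([2, 2])

# The result is a copy, so writing to it leaves x unchanged
picked = x[torch.tensor([0, 2])]
picked[0] = -1
print(x[0].item())  # 10
```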

For multi-dimensional tensors:

python
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# Select specific rows
rows = torch.tensor([0, 2])
print(f"Selected rows:\n{matrix[rows]}")

# Select specific elements using row and column indices
row_indices = torch.tensor([0, 1, 2])
col_indices = torch.tensor([1, 2, 0])
print(f"Selected elements: {matrix[row_indices, col_indices]}")

Output:

Selected rows:
tensor([[1, 2, 3],
        [7, 8, 9]])
Selected elements: tensor([2, 6, 7])
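A common way to pick one element from each row is to pair torch.arange over the rows with one column index per row. torch.gather expresses the same selection. torch.index_select instead takes whole rows or columns along one dimension. A sketch comparing them on the same matrix:

```python
import torch

matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
col_indices = torch.tensor([1, 2, 0])

# One element per row via advanced indexing
via_indexing = matrix[torch.arange(matrix.shape[0]), col_indices]

# gather along dim 1 needs the index shaped [rows, 1]
via_gather = matrix.gather(1, col_indices.unsqueeze(1)).squeeze(1)

print(via_indexing, via_gather)  # tensor([2, 6, 7]) tensor([2, 6, 7])

# index_select keeps columns 0 and 2 for every row
cols = torch.index_select(matrix, dim=1, index=torch.tensor([0, 2]))
print(cols)
```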

Modifying Tensors Through Indexing

Indexing also works on the left side of an assignment, which modifies the tensor in place:

python
# Create a tensor
x = torch.tensor([0, 1, 2, 3, 4])
print(f"Original tensor: {x}")

# Modify a single element
x[0] = 10
print(f"After modifying first element: {x}")

# Modify a slice
x[2:4] = torch.tensor([20, 30])
print(f"After modifying a slice: {x}")

# Modify with broadcasting
x[1:4] = -1
print(f"After broadcasting -1: {x}")

Output:

Original tensor: tensor([0, 1, 2, 3, 4])
After modifying first element: tensor([10,  1,  2,  3,  4])
After modifying a slice: tensor([10,  1, 20, 30,  4])
After broadcasting -1: tensor([10, -1, -1, -1,  4])
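Boolean masks and integer index tensors can also be assigned to. When an index appears more than once, plain assignment does not add the contributions together. The index_put_ documentation describes duplicate indices without accumulation as undefined behavior, whereas index_add_ accumulates them explicitly. A sketch, with the counting use case as an illustrative assumption:

```python
import torch

x = torch.tensor([10, -1, -1, -1, 4])

# Boolean-mask assignment: replace negatives with zero
x[x < 0] = 0
print(x)  # tensor([10,  0,  0,  0,  4])

# Integer-index assignment: one value per listed position
x[torch.tensor([0, 4])] = torch.tensor([7, 8])
print(x)  # tensor([7, 0, 0, 0, 8])

# Accumulate over repeated indices (here: counting occurrences of 0, 1, 2)
counts = torch.zeros(3, dtype=torch.long)
idx = torch.tensor([0, 0, 2])
counts.index_add_(0, idx, torch.ones(3, dtype=torch.long))
print(counts)  # tensor([2, 0, 1])
```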

For 2D tensors:

python
matrix = torch.ones(3, 3)
print(f"Original matrix:\n{matrix}")

# Modify a row
matrix[0] = torch.tensor([5, 5, 5])
print(f"After modifying first row:\n{matrix}")

# Modify a column
matrix[:, 1] = torch.tensor([9, 9, 9])
print(f"After modifying second column:\n{matrix}")

# Modify a submatrix
matrix[1:3, 1:3] = torch.tensor([[7, 8], [7, 8]])
print(f"After modifying bottom-right submatrix:\n{matrix}")

Output:

Original matrix:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
After modifying first row:
tensor([[5., 5., 5.],
        [1., 1., 1.],
        [1., 1., 1.]])
After modifying second column:
tensor([[5., 9., 5.],
        [1., 9., 1.],
        [1., 9., 1.]])
After modifying bottom-right submatrix:
tensor([[5., 9., 5.],
        [1., 7., 8.],
        [1., 7., 8.]])
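The value on the right must broadcast to the shape of the indexed region. A 1D row of length 3 broadcasts over every row, but writing a column through a slice needs shape [3, 1], and a shape that cannot broadcast raises a RuntimeError. Assigned values are also converted to the target tensor's dtype, which is why the integer tensors above ended up as floats. A sketch:

```python
import torch

matrix = torch.zeros(3, 3)

# A length-3 row broadcasts across all three rows
matrix[:] = torch.tensor([1, 2, 3])

# Writing a column through a [3, 1] slice
matrix[:, 0:1] = torch.tensor([[7], [8], [9]])
print(matrix)
print(matrix.dtype)  # torch.float32 (integer values were converted)

# A shape that cannot broadcast is rejected
try:
    matrix[0] = torch.tensor([1, 2])
except RuntimeError as err:
    print(f"Shape mismatch: {err}")
```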

Practical Example: Image Processing with PyTorch

Let's look at a practical example where tensor indexing is useful. We'll work with an image represented as a tensor and perform some basic operations:

python
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import urllib.request
import io

# Download a sample image
url = "https://github.com/pytorch/pytorch/raw/master/docs/source/_static/img/pytorch-logo-flame.png"
response = urllib.request.urlopen(url)
img = Image.open(io.BytesIO(response.read())).convert("RGB")  # drop any alpha channel so the tensor has exactly 3 channels

# Convert to PyTorch tensor (C×H×W format)
transform = transforms.ToTensor()
img_tensor = transform(img)
print(f"Image tensor shape: {img_tensor.shape}") # [C, H, W]

# Display original image
plt.figure(figsize=(10, 6))
plt.subplot(2, 2, 1)
plt.title("Original Image")
plt.imshow(transforms.ToPILImage()(img_tensor))

# Extract and display the red channel only
red_channel = img_tensor.clone()
red_channel[1:, :, :] = 0 # Set green and blue channels to 0
plt.subplot(2, 2, 2)
plt.title("Red Channel Only")
plt.imshow(transforms.ToPILImage()(red_channel))

# Crop a region of interest (center of the image)
h, w = img_tensor.shape[1], img_tensor.shape[2]
center_crop = img_tensor[:, h//4:3*h//4, w//4:3*w//4]
plt.subplot(2, 2, 3)
plt.title("Center Crop")
plt.imshow(transforms.ToPILImage()(center_crop))

# Flip the image horizontally
flipped = img_tensor[:, :, torch.arange(w-1, -1, -1)]
plt.subplot(2, 2, 4)
plt.title("Horizontal Flip")
plt.imshow(transforms.ToPILImage()(flipped))

plt.tight_layout()
plt.show()

This example demonstrates:

  1. Loading an image and converting it to a PyTorch tensor
  2. Extracting a specific color channel using indexing
  3. Cropping the image by slicing the height and width dimensions
  4. Flipping the image horizontally by using advanced indexing
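You can check the same four operations without downloading anything. The sketch below uses a random tensor as a hypothetical stand-in for the image (the 8×6 size is arbitrary). It also confirms that the reversed-index flip matches torch.flip:

```python
import torch

# Hypothetical stand-in for a C×H×W image with values in [0, 1)
img = torch.rand(3, 8, 6)
c, h, w = img.shape

# Keep only the red channel (index 0) in a copy
red_only = img.clone()
red_only[1:, :, :] = 0

# Center crop: the middle half of height and width
crop = img[:, h // 4:3 * h // 4, w // 4:3 * w // 4]
print(crop.shape)  # torch.Size([3, 4, 3])

# Horizontal flip by indexing the width dimension in reverse
flipped = img[:, :, torch.arange(w - 1, -1, -1)]
print(torch.equal(flipped, torch.flip(img, dims=[2])))  # True
```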

Summary

In this tutorial, you've learned how to:

  • Access individual elements in PyTorch tensors
  • Slice tensors to extract ranges of data in one or more dimensions
  • Use boolean indexing to filter tensor elements based on conditions
  • Apply fancy indexing with integer arrays to select specific elements
  • Modify tensor elements through indexing
  • Apply these techniques in a practical image processing example

Mastering tensor indexing is crucial for efficient data manipulation in PyTorch. These techniques will help you prepare data for your deep learning models and analyze model outputs effectively.

Exercises

To solidify your understanding, try these exercises:

  1. Create a 3×3×3 tensor filled with random numbers and extract:

    • The center element
    • The diagonal elements
    • The corner elements
  2. Create a 10×10 matrix of random values (for example with torch.rand) and:

    • Extract even-indexed rows and odd-indexed columns
    • Replace all elements greater than 0.5 with 0
    • Normalize each row to sum to 1
  3. Load an RGB image as a tensor and:

    • Convert it to grayscale (hint: average the three color channels)
    • Create a "photo negative" by subtracting each pixel value from the maximum value
    • Create a checkerboard mask for the image
