PyTorch Tensor Indexing
Introduction
When working with PyTorch tensors, you often need to access specific elements, rows, columns, or subsets of data. This process, known as indexing, is fundamental for data manipulation in deep learning applications. In this tutorial, we'll explore the various ways to index and slice PyTorch tensors, skills that will help you efficiently manipulate data for your machine learning models.
PyTorch's indexing syntax is largely based on NumPy's, so if you're familiar with NumPy arrays, you'll find many similarities. There are a few differences worth knowing, however: for example, PyTorch does not support negative slice steps.
Basic Tensor Indexing
Let's start by creating a tensor and accessing individual elements.
import torch
# Create a 1D tensor
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
print(f"Original tensor: {x}")
# Access single elements
print(f"First element: {x[0]}")
print(f"Last element: {x[-1]}")
print(f"Third element: {x[2]}")
Output:
Original tensor: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
First element: 0
Last element: 9
Third element: 2
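Indexing a single position returns a zero-dimensional tensor, not a Python number (f-string formatting of a zero-dimensional tensor just happens to print the bare value). The minimal sketch below checks this and uses `.item()` to pull out a plain Python scalar, which is handy for logging or ordinary arithmetic.

```python
import torch

x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# Indexing one element gives a tensor with zero dimensions (an empty shape)
elem = x[2]
print(type(elem), elem.dim(), elem.shape)

# .item() converts a one-element tensor to a regular Python number
value = x[2].item()
print(type(value), value)
```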
For multi-dimensional tensors, we can specify indices for each dimension:
# Create a 2D tensor (matrix)
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
print(f"Original matrix:\n{matrix}")
# Access elements
print(f"Element at row 0, column 1: {matrix[0, 1]}")
print(f"Element at row 2, column 2: {matrix[2, 2]}")
Output:
Original matrix:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
Element at row 0, column 1: 2
Element at row 2, column 2: 9
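The comma form above indexes every dimension in one step. As a small sketch, chaining one index per bracket reaches the same element (it first selects a row, then a column within it), and negative indices count from the end in every dimension:

```python
import torch

matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# One index per dimension, separated by commas
a = matrix[1, 2]

# Chained indexing: matrix[1] selects row 1, then [2] selects column 2 of that row
b = matrix[1][2]

# Negative indices count from the end of each dimension
c = matrix[-1, -1]

print(a.item(), b.item(), c.item())
```

The comma form is usually preferred because it does the lookup in a single indexing operation instead of creating an intermediate row tensor first.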
Slicing Tensors
Slicing allows you to extract a range of elements from a tensor. The syntax is tensor[start:end:step], where:
- start is the starting index (inclusive, defaults to 0)
- end is the ending index (exclusive, defaults to the size of the dimension)
- step is the step size (defaults to 1; it must be positive, since PyTorch, unlike NumPy, does not support negative steps)
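A minimal sketch of these rules: each part can be omitted, the colon syntax is shorthand for a Python `slice` object, out-of-range bounds are clipped rather than raising an error, and a negative step is rejected with a `ValueError`:

```python
import torch

x = torch.arange(10)  # the integers 0 through 9

# Any of start, end, step can be omitted to use its default
first_three = x[:3]   # start defaults to 0
from_seven = x[7:]    # end defaults to the size of the dimension

# The colon syntax is shorthand for a Python slice object
s = slice(1, 8, 2)
same = torch.equal(x[s], x[1:8:2])

# Out-of-range bounds are clipped to the tensor's size instead of raising an error
clipped = x[5:100]

# Unlike NumPy, PyTorch rejects negative steps
try:
    x[::-1]
    negative_step_ok = True
except ValueError as err:
    negative_step_ok = False
    print("ValueError:", err)

print(first_three, from_seven, same, clipped)
```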
1D Tensor Slicing
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# Get elements from index 2 to 5 (exclusive)
print(f"Elements from index 2 to 5: {x[2:5]}")
# Get every second element
print(f"Every second element: {x[::2]}")
# Get elements from index 1 to 7 with step 2
print(f"Elements 1 to 7 with step 2: {x[1:8:2]}")
# Negative indexing (from the end)
print(f"Last 3 elements: {x[-3:]}")
print(f"All elements except the last 2: {x[:-2]}")
# Reverse the tensor (x[::-1] raises a ValueError in PyTorch, so use torch.flip)
print(f"Reversed tensor: {torch.flip(x, dims=[0])}")
Output:
Elements from index 2 to 5: tensor([2, 3, 4])
Every second element: tensor([0, 2, 4, 6, 8])
Elements 1 to 7 with step 2: tensor([1, 3, 5, 7])
Last 3 elements: tensor([7, 8, 9])
All elements except the last 2: tensor([0, 1, 2, 3, 4, 5, 6, 7])
Reversed tensor: tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
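One behavior worth knowing about basic slices: they return views that share memory with the original tensor, so writing into a slice changes the original. A short sketch, using `.clone()` when an independent copy is needed:

```python
import torch

x = torch.arange(10)

# A basic slice is a view: it shares storage with x
view = x[2:5]
view[0] = 100          # also changes x[2]

# .clone() makes an independent copy; writing to it leaves x alone
copy = x[2:5].clone()
copy[0] = -1

print(x)
print(copy)
```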
2D Tensor Slicing
With multi-dimensional tensors, you can slice along each dimension separately:
matrix = torch.tensor([[1, 2, 3, 4],
                       [5, 6, 7, 8],
                       [9, 10, 11, 12]])
print(f"Original matrix:\n{matrix}")
# Get the first row
print(f"First row: {matrix[0]}")
# Get the first column
print(f"First column: {matrix[:, 0]}")
# Get a 2x2 submatrix (top-left corner)
print(f"Top-left 2x2 submatrix:\n{matrix[0:2, 0:2]}")
# Get rows 0 and 2, and all columns
print(f"Rows 0 and 2:\n{matrix[[0, 2]]}")
# Get alternate rows and columns
print(f"Alternate rows and columns:\n{matrix[::2, ::2]}")
Output:
Original matrix:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
First row: tensor([1, 2, 3, 4])
First column: tensor([1, 5, 9])
Top-left 2x2 submatrix:
tensor([[1, 2],
        [5, 6]])
Rows 0 and 2:
tensor([[ 1,  2,  3,  4],
        [ 9, 10, 11, 12]])
Alternate rows and columns:
tensor([[ 1,  3],
        [ 9, 11]])
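Note that matrix[:, 0] gives a 1D tensor: an integer index removes its dimension, while a slice of length 1 keeps it. The sketch below checks the shapes and also shows two shorthands PyTorch accepts in index expressions: `...` (Ellipsis) for "all remaining dimensions" and `None` to insert a new dimension of size 1.

```python
import torch

matrix = torch.tensor([[1, 2, 3, 4],
                       [5, 6, 7, 8],
                       [9, 10, 11, 12]])

# An integer index drops that dimension
col_1d = matrix[:, 0]
print(col_1d.shape)

# A length-1 slice keeps it, giving a column vector
col_2d = matrix[:, 0:1]
print(col_2d.shape)

# ... stands for all remaining dimensions, so this matches matrix[:, 0]
same = torch.equal(matrix[..., 0], matrix[:, 0])

# None adds a new leading dimension of size 1
batched = matrix[None]
print(same, batched.shape)
```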
Advanced Indexing Techniques
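Beyond integers and slices, PyTorch also accepts boolean tensors and integer tensors as indices, which lets you select elements by a condition or by an arbitrary list of positions.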
Boolean Indexing
You can use boolean masks to filter tensor elements:
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# Create a boolean mask
mask = x > 5
print(f"Boolean mask: {mask}")
# Apply the mask to filter elements
filtered = x[mask]
print(f"Elements > 5: {filtered}")
# One-line filtering
print(f"Elements divisible by 2: {x[x % 2 == 0]}")
print(f"Elements between 3 and 7: {x[(x >= 3) & (x <= 7)]}")
Output:
Boolean mask: tensor([False, False, False, False, False, False,  True,  True,  True,  True])
Elements > 5: tensor([6, 7, 8, 9])
Elements divisible by 2: tensor([0, 2, 4, 6, 8])
Elements between 3 and 7: tensor([3, 4, 5, 6, 7])
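Conditions are combined with the element-wise operators &, | and ~, and each comparison needs its own parentheses because these operators bind more tightly than > or <=. A minimal sketch, including what happens if you use Python's `and` instead:

```python
import torch

x = torch.arange(10)

# | is element-wise OR, ~ is element-wise NOT
outside = x[(x < 3) | (x > 7)]
odd = x[~(x % 2 == 0)]

# `and` asks Python for a single True/False from a multi-element tensor, which fails
and_failed = False
try:
    x[(x >= 3) and (x <= 7)]
except RuntimeError as err:
    and_failed = True
    print("RuntimeError:", err)

print(outside, odd)
```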
For 2D tensors:
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
# Filter rows where the second column is greater than 5
print(f"Rows where second column > 5:\n{matrix[matrix[:, 1] > 5]}")
# Create a filter matrix and apply it
filter_matrix = matrix > 4
print(f"Filter matrix:\n{filter_matrix}")
print(f"Elements > 4: {matrix[filter_matrix]}")
Output:
Rows where second column > 5:
tensor([[7, 8, 9]])
Filter matrix:
tensor([[False, False, False],
        [False,  True,  True],
        [ True,  True,  True]])
Elements > 4: tensor([5, 6, 7, 8, 9])
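As the last line shows, applying a 2D mask returns a flat 1D tensor of the selected values. When you need to keep the matrix shape, or want the positions rather than the values, `torch.where` and `nonzero()` are common alternatives. A short sketch, where filling unselected entries with 0 is an arbitrary choice for illustration:

```python
import torch

matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
mask = matrix > 4

# Boolean indexing flattens the selected values into 1D (row-major order)
flat = matrix[mask]

# torch.where keeps the shape: value from matrix where mask is True, else 0
kept = torch.where(mask, matrix, torch.zeros_like(matrix))

# nonzero() returns the (row, column) coordinates of the True entries
coords = mask.nonzero()

print(flat.shape)
print(kept)
print(coords)
```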
Fancy Indexing
You can pass an integer tensor (or a Python list) as the index to select specific elements:
x = torch.tensor([10, 20, 30, 40, 50])
# Select specific elements using an index tensor
indices = torch.tensor([0, 2, 4])
print(f"Selected elements: {x[indices]}")
# Select elements in a different order
jumbled_indices = torch.tensor([4, 0, 2, 1])
print(f"Jumbled elements: {x[jumbled_indices]}")
Output:
Selected elements: tensor([10, 30, 50])
Jumbled elements: tensor([50, 10, 30, 20])
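Two further properties of integer-tensor indexing, shown in a small sketch: indices may repeat and the result takes the shape of the index tensor, and, unlike a basic slice, the result is a copy rather than a view.

```python
import torch

x = torch.tensor([10, 20, 30, 40, 50])

# A 2x2 index tensor with a repeated index produces a 2x2 result
idx = torch.tensor([[0, 0],
                    [4, 1]])
picked = x[idx]
print(picked)

# The result is a copy: modifying it does not touch x
picked[0, 0] = 999
print(x[0].item())
```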
For multi-dimensional tensors:
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
# Select specific rows
rows = torch.tensor([0, 2])
print(f"Selected rows:\n{matrix[rows]}")
# Select specific elements using row and column indices
row_indices = torch.tensor([0, 1, 2])
col_indices = torch.tensor([1, 2, 0])
print(f"Selected elements: {matrix[row_indices, col_indices]}")
Output:
Selected rows:
tensor([[1, 2, 3],
        [7, 8, 9]])
Selected elements: tensor([2, 6, 7])
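Pairing a row-index tensor with a column-index tensor is a common way to pick one entry per row, for example the score of each sample's target class. The sketch below uses hypothetical scores and targets for illustration and checks the result against `torch.gather`, which performs the same selection along a dimension:

```python
import torch

# Hypothetical scores for 4 samples over 3 classes, and each sample's target class
scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2],
                       [0.2, 0.2, 0.6],
                       [0.3, 0.4, 0.3]])
targets = torch.tensor([1, 0, 2, 1])

# Row i is paired with column targets[i], giving one score per sample
rows = torch.arange(scores.shape[0])
picked = scores[rows, targets]

# torch.gather along dim 1 needs an index of shape [N, 1]
gathered = torch.gather(scores, 1, targets.unsqueeze(1)).squeeze(1)

print(picked)
print(torch.equal(picked, gathered))
```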
Modifying Tensors Through Indexing
You can use indexing to modify tensor elements:
# Create a tensor
x = torch.tensor([0, 1, 2, 3, 4])
print(f"Original tensor: {x}")
# Modify a single element
x[0] = 10
print(f"After modifying first element: {x}")
# Modify a slice
x[2:4] = torch.tensor([20, 30])
print(f"After modifying a slice: {x}")
# Modify with broadcasting
x[1:4] = -1
print(f"After broadcasting -1: {x}")
Output:
Original tensor: tensor([0, 1, 2, 3, 4])
After modifying first element: tensor([10,  1,  2,  3,  4])
After modifying a slice: tensor([10,  1, 20, 30,  4])
After broadcasting -1: tensor([10, -1, -1, -1,  4])
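Boolean masks and integer index tensors can also appear on the left-hand side of an assignment. A minimal sketch that clamps negative values to zero in place, then writes new values at chosen positions:

```python
import torch

x = torch.tensor([3, -2, 5, -7, 0, 4])

# Masked assignment: every element where the mask is True receives 0
x[x < 0] = 0
print(x.tolist())

# Index assignment: the listed positions receive the given values, in order
idx = torch.tensor([0, 5])
x[idx] = torch.tensor([30, 40])
print(x.tolist())
```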
For 2D tensors:
matrix = torch.ones(3, 3)
print(f"Original matrix:\n{matrix}")
# Modify a row
matrix[0] = torch.tensor([5, 5, 5])
print(f"After modifying first row:\n{matrix}")
# Modify a column
matrix[:, 1] = torch.tensor([9, 9, 9])
print(f"After modifying second column:\n{matrix}")
# Modify a submatrix
matrix[1:3, 1:3] = torch.tensor([[7, 8], [7, 8]])
print(f"After modifying bottom-right submatrix:\n{matrix}")
Output:
Original matrix:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
After modifying first row:
tensor([[5., 5., 5.],
        [1., 1., 1.],
        [1., 1., 1.]])
After modifying second column:
tensor([[5., 9., 5.],
        [1., 9., 1.],
        [1., 9., 1.]])
After modifying bottom-right submatrix:
tensor([[5., 9., 5.],
        [1., 7., 8.],
        [1., 7., 8.]])
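A common pitfall when modifying tensors is chaining two indexing steps. Because a slice is a view, writing through a chained slice reaches the original; but integer-tensor (or list) indexing makes a copy, so a write chained after it is silently lost. A small sketch:

```python
import torch

m = torch.zeros(3, 3)

# m[0:2] is a view, so this write reaches m
m[0:2][:, 0] = 5
col0 = m[:, 0].tolist()

# m[[1, 2]] is a copy, so this write only changes the temporary copy
m[[1, 2]][:, 2] = 7
col2_before = m[:, 2].tolist()

# Indexing all dimensions in one expression modifies m directly
m[[1, 2], 2] = 7
col2_after = m[:, 2].tolist()

print(col0, col2_before, col2_after)
```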
Practical Example: Image Processing with PyTorch
Let's look at a practical example where tensor indexing is useful. We'll work with an image represented as a tensor and perform some basic operations:
import torch
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import urllib.request
import io
# Download a sample image
url = "https://github.com/pytorch/pytorch/raw/master/docs/source/_static/img/pytorch-logo-flame.png"
response = urllib.request.urlopen(url)
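# Convert to RGB so the tensor has exactly 3 channels (PNGs often carry an alpha channel)
img = Image.open(io.BytesIO(response.read())).convert("RGB")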
# Convert to PyTorch tensor (C×H×W format)
transform = transforms.ToTensor()
img_tensor = transform(img)
print(f"Image tensor shape: {img_tensor.shape}") # [C, H, W]
# Display original image
plt.figure(figsize=(10, 6))
plt.subplot(2, 2, 1)
plt.title("Original Image")
plt.imshow(transforms.ToPILImage()(img_tensor))
# Extract and display the red channel only
red_channel = img_tensor.clone()
red_channel[1:, :, :] = 0 # Set green and blue channels to 0
plt.subplot(2, 2, 2)
plt.title("Red Channel Only")
plt.imshow(transforms.ToPILImage()(red_channel))
# Crop a region of interest (center of the image)
h, w = img_tensor.shape[1], img_tensor.shape[2]
center_crop = img_tensor[:, h//4:3*h//4, w//4:3*w//4]
plt.subplot(2, 2, 3)
plt.title("Center Crop")
plt.imshow(transforms.ToPILImage()(center_crop))
# Flip the image horizontally
flipped = img_tensor[:, :, torch.arange(w-1, -1, -1)]
plt.subplot(2, 2, 4)
plt.title("Horizontal Flip")
plt.imshow(transforms.ToPILImage()(flipped))
plt.tight_layout()
plt.show()
This example demonstrates:
- Loading an image and converting it to a PyTorch tensor
- Extracting a specific color channel using indexing
- Cropping the image by slicing the height and width dimensions
- Flipping the image horizontally by using advanced indexing
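To try the same operations without downloading anything or opening a plot window, here is a sketch that substitutes a small random tensor for the logo (an assumption purely for illustration). It also checks that the index-based flip matches `torch.flip` along the width dimension.

```python
import torch

# A random 3-channel "image" in C x H x W layout, standing in for a real photo
torch.manual_seed(0)
img = torch.rand(3, 8, 10)
c, h, w = img.shape

# Red channel only: zero channels 1 (green) and 2 (blue) on a copy
red_only = img.clone()
red_only[1:] = 0

# Center crop: keep the middle half of the height and width
crop = img[:, h // 4:3 * h // 4, w // 4:3 * w // 4]

# Horizontal flip by reversing the column order with an index tensor...
flipped_idx = img[:, :, torch.arange(w - 1, -1, -1)]
# ...or with torch.flip along the width dimension (dim 2)
flipped = torch.flip(img, dims=[2])

print(crop.shape)
print(torch.equal(flipped_idx, flipped))
```

torch.flip avoids building an explicit index tensor, which makes the intent easier to read.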
Summary
In this tutorial, you've learned how to:
- Access individual elements in PyTorch tensors
- Slice tensors to extract ranges of data in one or more dimensions
- Use boolean indexing to filter tensor elements based on conditions
- Apply fancy indexing with integer arrays to select specific elements
- Modify tensor elements through indexing
- Apply these techniques in a practical image processing example
Mastering tensor indexing is crucial for efficient data manipulation in PyTorch. These techniques will help you prepare data for your deep learning models and analyze model outputs effectively.
Exercises
To solidify your understanding, try these exercises:
- Create a 3×3×3 tensor filled with random numbers and extract:
  - The center element
  - The diagonal elements
  - The corner elements
- Create a 10×10 matrix of random values in [0, 1) (for example, with torch.rand) and:
  - Extract even-indexed rows and odd-indexed columns
  - Replace all elements greater than 0.5 with 0
  - Normalize each row to sum to 1
- Load an RGB image as a tensor and:
  - Convert it to grayscale (hint: average the three color channels)
  - Create a "photo negative" by subtracting each pixel value from the maximum value (1.0 for tensors produced by transforms.ToTensor)
  - Create a checkerboard mask for the image
Additional Resources
- PyTorch Documentation on Tensors
- NumPy Indexing Documentation (many concepts apply to PyTorch)
- PyTorch Tutorial: Tensor Operations