PyTorch Tensor Serialization
Introduction
Tensor serialization is a crucial skill for any deep learning practitioner. It involves saving tensors to files and loading them back when needed. This capability is essential for many real-world scenarios: saving trained model parameters, checkpointing during long training processes, transferring data between different applications, or preprocessing datasets for faster loading during training.
In this guide, we'll explore how PyTorch provides simple yet powerful methods for serializing tensors. You'll learn not only the basic commands but also best practices, performance considerations, and common pitfalls to avoid.
Basic Tensor Serialization
PyTorch offers several ways to save and load tensors. Let's start with the simplest methods.
Saving and Loading Individual Tensors
The most straightforward way to save and load a tensor is with torch.save() and torch.load():
import torch
# Create a tensor
x = torch.tensor([1, 2, 3, 4, 5])
print(f"Original tensor: {x}")
# Save the tensor to a file
torch.save(x, 'tensor.pt')
# Load the tensor from the file
loaded_x = torch.load('tensor.pt')
print(f"Loaded tensor: {loaded_x}")
Output:
Original tensor: tensor([1, 2, 3, 4, 5])
Loaded tensor: tensor([1, 2, 3, 4, 5])
The .pt or .pth file extension is commonly used for PyTorch tensors and models, though you could use any extension.
Saving and Loading Multiple Tensors
If you have multiple tensors to save, you can package them in a dictionary:
# Create multiple tensors
x = torch.tensor([1, 2, 3])
y = torch.randn(3, 4)
z = torch.zeros(2, 2)
# Save multiple tensors
tensor_dict = {
    'x': x,
    'y': y,
    'z': z
}
torch.save(tensor_dict, 'multiple_tensors.pt')
# Load multiple tensors
loaded_dict = torch.load('multiple_tensors.pt')
print(f"Loaded x: {loaded_dict['x']}")
print(f"Loaded y shape: {loaded_dict['y'].shape}")
print(f"Loaded z: {loaded_dict['z']}")
Output:
Loaded x: tensor([1, 2, 3])
Loaded y shape: torch.Size([3, 4])
Loaded z: tensor([[0., 0.],
[0., 0.]])
Serialization Formats
PyTorch uses Python's pickle module by default, but you can choose different approaches for specific needs.
Default Pickle Format
By default, torch.save() uses Python's pickle protocol to serialize tensors:
# Default serialization (uses pickle)
x = torch.randn(5, 3)
torch.save(x, 'default_tensor.pt')
Saving as NumPy Arrays
Sometimes you might want to use NumPy's .npy format for better interoperability with other libraries:
import numpy as np
# Create a tensor
x = torch.randn(5, 3)
# Save as NumPy array
np_array = x.numpy()
np.save('tensor_as_numpy.npy', np_array)
# Load NumPy array and convert back to tensor
loaded_np = np.load('tensor_as_numpy.npy')
loaded_tensor = torch.from_numpy(loaded_np)
print(f"Original tensor:\n{x}")
print(f"Loaded tensor:\n{loaded_tensor}")
print(f"Are they equal? {torch.all(x == loaded_tensor)}")
Output:
Original tensor:
tensor([[ 0.1234, -1.2345,  0.5678],
        [-0.8765,  0.2345, -0.3456],
        [ 1.2345,  0.5678, -0.8765],
        [-0.3456,  0.7654,  1.4567],
        [ 0.6543, -0.5432,  0.8765]])
Loaded tensor:
tensor([[ 0.1234, -1.2345,  0.5678],
        [-0.8765,  0.2345, -0.3456],
        [ 1.2345,  0.5678, -0.8765],
        [-0.3456,  0.7654,  1.4567],
        [ 0.6543, -0.5432,  0.8765]])
Are they equal? tensor(True)
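If you want to keep several arrays together in NumPy format, np.savez offers a rough equivalent of the dictionary approach shown earlier. Here's a brief sketch (the archive file name is just an example):
# Save several arrays into a single .npz archive
np.savez('tensors_as_numpy.npz', x=x.numpy(), ones=torch.ones(2, 2).numpy())
# Load them back by key and convert to tensors
archive = np.load('tensors_as_numpy.npz')
x_back = torch.from_numpy(archive['x'])
ones_back = torch.from_numpy(archive['ones'])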
Using Zip Compression
For large tensors, you might want to use compression to save disk space:
import zipfile
import os
# Create a large tensor
large_tensor = torch.randn(1000, 1000)
# Save without compression
torch.save(large_tensor, 'large_tensor.pt')
# Save with ZIP compression
with zipfile.ZipFile('large_tensor_compressed.zip', 'w', compression=zipfile.ZIP_DEFLATED) as myzip:
    torch.save(large_tensor, 'temp_tensor.pt')
    myzip.write('temp_tensor.pt')
os.remove('temp_tensor.pt')
# Check file sizes
uncompressed_size = os.path.getsize('large_tensor.pt') / (1024 * 1024)
compressed_size = os.path.getsize('large_tensor_compressed.zip') / (1024 * 1024)
print(f"Uncompressed size: {uncompressed_size:.2f} MB")
print(f"Compressed size: {compressed_size:.2f} MB")
Output:
Uncompressed size: 4.00 MB
Compressed size: 2.75 MB
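To read the compressed tensor back, you can extract the archived member into an in-memory buffer and hand it to torch.load(), which accepts any file-like object. A minimal sketch, using the archive and member name from the example above:
import io
import zipfile
# Read the saved member out of the archive into memory
with zipfile.ZipFile('large_tensor_compressed.zip', 'r') as myzip:
    with myzip.open('temp_tensor.pt') as f:
        buffer = io.BytesIO(f.read())
# torch.load() works with file-like objects as well as paths
restored_tensor = torch.load(buffer)
print(f"Restored tensor shape: {restored_tensor.shape}")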
Advanced Serialization Options
Pickle Protocol Versions
PyTorch allows you to specify the pickle protocol version:
# Using different pickle protocols
tensor = torch.randn(10, 10)
# Use the default settings
torch.save(tensor, 'tensor_default_protocol.pt')
# Explicitly use the legacy (non-zipfile) format with pickle protocol 2
# for compatibility with older Python and PyTorch versions
torch.save(tensor, 'tensor_protocol2.pt', _use_new_zipfile_serialization=False, pickle_protocol=2)
Memory Mapping for Large Files
For very large tensors that might not fit in memory, you can use memory mapping:
# Memory mapping example
# Create a large tensor and save it (in the default zipfile format)
large_tensor = torch.randn(10000, 1000)  # 10,000 x 1,000 = 10 million elements
torch.save(large_tensor, 'very_large_tensor.pt')
# Load with memory mapping: mmap=True (available in PyTorch 2.1+, for files
# saved in the default zipfile format) maps the tensor's storage from disk
# instead of reading the whole tensor into memory at once
mapped_tensor = torch.load('very_large_tensor.pt', map_location='cpu', mmap=True)
# Access only a portion
subset = mapped_tensor[:100, :10]
print(f"Subset shape: {subset.shape}")
Output:
Subset shape: torch.Size([100, 10])
Best Practices and Performance Considerations
Serialization for Different Devices
When loading tensors across different devices (CPU/GPU), you can specify the target device:
# Save a tensor from CPU
cpu_tensor = torch.randn(100, 100)
torch.save(cpu_tensor, 'cpu_tensor.pt')
# Load tensor to specific device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Option 1: Load and then move to device
loaded_tensor = torch.load('cpu_tensor.pt')
loaded_tensor = loaded_tensor.to(device)
# Option 2: Load directly to device
loaded_tensor = torch.load('cpu_tensor.pt', map_location=device)
print(f"Tensor loaded to: {loaded_tensor.device}")
Output:
Tensor loaded to: cuda:0 # or cpu if GPU is not available
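map_location also accepts a dictionary or a callable for finer-grained remapping, which is useful when a checkpoint was saved on a GPU that doesn't exist on the current machine. A small sketch (the file name and source device here are hypothetical):
# Remap storages that were saved on 'cuda:1' onto 'cuda:0' (hypothetical checkpoint file)
remapped = torch.load('gpu_tensor.pt', map_location={'cuda:1': 'cuda:0'})
# Or force everything onto the CPU with a callable
cpu_only = torch.load('gpu_tensor.pt', map_location=lambda storage, loc: storage)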
Handling Backward Compatibility
If models or tensors saved with one PyTorch version need to be loaded with a different (often newer) version, you may encounter compatibility issues. Here's how to handle them:
# For backward compatibility
import io
import torch
# Save tensor with specific format
buffer = io.BytesIO()
torch.save(torch.randn(3, 4), buffer, _use_new_zipfile_serialization=False)
buffer.seek(0) # Reset buffer position to beginning
# Load back
tensor = torch.load(buffer)
print(tensor.shape)
Output:
torch.Size([3, 4])
Checkpointing Large Tensor Operations
When working with large models, it's a good practice to save checkpoints:
# Example of checkpointing during a simulated training process
import time
# Simulate a training process
model_state = {
    'weights': torch.randn(1000, 1000),
    'biases': torch.randn(1000),
    'epoch': 0,
    'loss': float('inf')
}
for epoch in range(1, 6):
    # Simulate training (just updating some values)
    model_state['weights'] += 0.01 * torch.randn(1000, 1000)
    model_state['biases'] -= 0.01 * torch.randn(1000)
    model_state['epoch'] = epoch
    model_state['loss'] = 10.0 / epoch  # Simulated decreasing loss
    # Save checkpoint every 2 epochs
    if epoch % 2 == 0:
        torch.save(model_state, f'checkpoint_epoch{epoch}.pt')
        print(f"Saved checkpoint at epoch {epoch} with loss {model_state['loss']:.4f}")
    # Simulate computation time
    time.sleep(0.5)
print("Training complete!")
Output:
Saved checkpoint at epoch 2 with loss 5.0000
Saved checkpoint at epoch 4 with loss 2.5000
Training complete!
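To resume after an interruption, find the most recent checkpoint, load it, and continue from the stored epoch. Here's a minimal sketch (the glob pattern matches the checkpoint files written above):
import glob
# Find the latest checkpoint by epoch number encoded in the file name
checkpoints = glob.glob('checkpoint_epoch*.pt')
if checkpoints:
    latest = max(checkpoints, key=lambda p: int(p.split('epoch')[1].split('.')[0]))
    model_state = torch.load(latest)
    start_epoch = model_state['epoch'] + 1
    print(f"Resuming from {latest} at epoch {start_epoch}")
else:
    start_epoch = 1
    print("No checkpoint found, starting from scratch")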
Real-World Applications
Preprocessing and Caching Datasets
One practical use of tensor serialization is preprocessing and caching datasets:
import os
import time
import torch
# Simulating a dataset that needs expensive preprocessing
def expensive_preprocessing(data):
    # Simulate expensive computation
    time.sleep(1)  # Pretend this takes 1 second per sample
    return data * 2 + 1
# Original slow approach (without caching)
def prepare_dataset_slow(size=10):
    start_time = time.time()
    # Create raw data
    raw_data = torch.randn(size, 5)
    # Apply expensive preprocessing
    processed_data = []
    for i in range(size):
        processed_data.append(expensive_preprocessing(raw_data[i]))
    processed_tensor = torch.stack(processed_data)
    elapsed = time.time() - start_time
    print(f"Processing without cache took {elapsed:.2f} seconds")
    return processed_tensor
# Approach with caching
def prepare_dataset_cached(size=10, force_recompute=False):
    cache_file = 'preprocessed_dataset.pt'
    if not force_recompute and os.path.exists(cache_file):
        # Load from cache
        start_time = time.time()
        processed_tensor = torch.load(cache_file)
        elapsed = time.time() - start_time
        print(f"Loading from cache took {elapsed:.2f} seconds")
    else:
        # Need to compute and then cache
        start_time = time.time()
        # Create raw data
        raw_data = torch.randn(size, 5)
        # Apply expensive preprocessing
        processed_data = []
        for i in range(size):
            processed_data.append(expensive_preprocessing(raw_data[i]))
        processed_tensor = torch.stack(processed_data)
        # Save to cache
        torch.save(processed_tensor, cache_file)
        elapsed = time.time() - start_time
        print(f"Processing and caching took {elapsed:.2f} seconds")
    return processed_tensor
# First run will be slow and cache the results
dataset1 = prepare_dataset_cached(size=5, force_recompute=True)
# Second run will be fast (loading from cache)
dataset2 = prepare_dataset_cached(size=5)
Output:
Processing and caching took 5.03 seconds
Loading from cache took 0.01 seconds
Saving and Loading Train-Test Splits
Another common scenario is saving dataset splits for reproducibility:
import torch
# Create a simulated dataset
data = torch.randn(1000, 10) # 1000 samples with 10 features each
labels = torch.randint(0, 2, (1000,)) # Binary labels
# Create indices for train-test split (80% train, 20% test)
indices = torch.randperm(1000)
train_indices = indices[:800]
test_indices = indices[800:]
# Create the splits
train_data = data[train_indices]
train_labels = labels[train_indices]
test_data = data[test_indices]
test_labels = labels[test_indices]
# Save the splits for reproducibility
torch.save({
    'train_data': train_data,
    'train_labels': train_labels,
    'test_data': test_data,
    'test_labels': test_labels,
    'train_indices': train_indices,  # Save indices for reproducibility
    'test_indices': test_indices
}, 'dataset_split.pt')
# Later, you can load these exact splits
dataset = torch.load('dataset_split.pt')
print(f"Train data shape: {dataset['train_data'].shape}")
print(f"Test data shape: {dataset['test_data'].shape}")
print(f"First 5 training indices: {dataset['train_indices'][:5]}")
Output:
Train data shape: torch.Size([800, 10])
Test data shape: torch.Size([200, 10])
First 5 training indices: tensor([382, 129, 876, 503, 637])
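Once reloaded, the splits can be wrapped in PyTorch's standard dataset utilities for training. A brief sketch using torch.utils.data, building on the dataset dictionary loaded above (the batch size is arbitrary):
from torch.utils.data import TensorDataset, DataLoader
# Wrap the saved splits in datasets
train_set = TensorDataset(dataset['train_data'], dataset['train_labels'])
test_set = TensorDataset(dataset['test_data'], dataset['test_labels'])
# Iterate over mini-batches during training
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
for batch_data, batch_labels in train_loader:
    pass  # training step would go here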
Common Pitfalls and Troubleshooting
Version Compatibility Issues
# How to handle version compatibility issues
try:
    loaded_data = torch.load('old_format_tensor.pt')
except Exception as e:
    print(f"Error loading tensor: {e}")
    print("Trying with different settings...")
    # Try with a different map_location
    try:
        loaded_data = torch.load('old_format_tensor.pt', map_location='cpu')
        print("Successfully loaded with CPU mapping!")
    except Exception:
        # Fall back to full (unrestricted) unpickling; newer PyTorch versions
        # default to weights_only=True, which can reject older files.
        # Only do this for files you trust, since full unpickling can execute code.
        try:
            loaded_data = torch.load('old_format_tensor.pt',
                                     map_location='cpu',
                                     weights_only=False)
            print("Successfully loaded with full unpickling!")
        except Exception as e2:
            print(f"All attempts failed: {e2}")
Security Considerations
# Security note - only load tensors from trusted sources!
# Pickle can execute arbitrary code during unpickling
# Safer approach for untrusted data: use NumPy format
import numpy as np
tensor = torch.randn(5, 5)
# Save to NumPy format
np.save('safe_tensor.npy', tensor.numpy())
# Load back safely
loaded_array = np.load('safe_tensor.npy')
loaded_tensor = torch.from_numpy(loaded_array)
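Recent PyTorch releases (1.13 and later) also provide a built-in safeguard: passing weights_only=True to torch.load() restricts unpickling to tensors and other primitive types instead of arbitrary objects. A short sketch (the file name is arbitrary):
# Safer loading in PyTorch 1.13+: only plain tensors/primitives are unpickled
tensor = torch.randn(5, 5)
torch.save(tensor, 'restricted_tensor.pt')
loaded = torch.load('restricted_tensor.pt', weights_only=True)
print(loaded.shape)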
Summary
In this tutorial, we've explored PyTorch tensor serialization in depth:
- Basic saving and loading of tensors using torch.save() and torch.load()
- Handling multiple tensors by using dictionaries
- Different serialization formats including pickle and NumPy
- Advanced options like compression and memory mapping
- Best practices for cross-device serialization and backward compatibility
- Real-world applications like dataset caching and train-test splitting
- Common pitfalls and how to avoid them
Tensor serialization is a fundamental skill that enables efficient workflows, reproducible experiments, and production deployment of PyTorch models.
Exercises
- Basic Serialization: Create a tensor with values from 1 to 100, reshape it to a 10×10 matrix, save it to a file, and load it back.
- Format Comparison: Save a large tensor (e.g., 1000×1000) in both PyTorch's native format and as a NumPy array. Compare file sizes and loading times.
- Checkpointing System: Implement a simple checkpointing system that saves a model's state every N iterations during a simulated training loop, and includes a mechanism to resume training from the latest checkpoint.
- Memory Optimization: Experiment with memory-mapped loading for a very large tensor (e.g., 10,000×10,000) and measure the memory usage difference between regular loading and memory-mapped loading.
- Serialization Pipeline: Create a data processing pipeline that:
  - Loads raw data
  - Applies transformations
  - Caches the processed results
  - Provides functionality to invalidate and rebuild the cache when needed