PyTorch Geometric
Introduction
PyTorch Geometric (PyG) is a library built upon PyTorch that extends its capabilities to handle graph-structured data efficiently. It provides methods for deep learning on graphs and other irregular structures, like point clouds and manifolds.
Graphs are mathematical structures consisting of nodes (vertices) connected by edges, which can represent a wide variety of real-world data, such as:
- Social networks (users as nodes, friendships as edges)
- Molecular structures (atoms as nodes, bonds as edges)
- Citation networks (papers as nodes, citations as edges)
- Knowledge graphs (entities as nodes, relationships as edges)
PyTorch Geometric simplifies working with these complex data structures by providing:
- Efficient data handling for graph structures
- Graph neural network (GNN) layers
- Various datasets for benchmarking
- Utilities for graph manipulation and transformation
Let's dive in and explore how to use PyTorch Geometric in your projects!
Installation
Before we start using PyG, we need to install it:
# Install PyTorch first (if not already installed)
pip install torch torchvision
# Install PyTorch Geometric
pip install torch-geometric
# Install additional dependencies (optional)
pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html
Replace ${TORCH_VERSION} and ${CUDA_VERSION} with your specific versions (e.g., 1.13.0 and cu117).
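To confirm the installation worked, a quick sanity check is to import the package and print its version:
import torch
import torch_geometric

print(f"PyTorch version: {torch.__version__}")
print(f"PyG version: {torch_geometric.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")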
Core Concepts
1. Data Structures
The fundamental data structure in PyG is the Data class, which represents a single graph:
from torch_geometric.data import Data
import torch
# Creating a simple graph with 3 nodes and 2 edges
edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long).t()
x = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float) # Node features
data = Data(x=x, edge_index=edge_index)
print(data)
# Output: Data(x=[3, 1], edge_index=[2, 2])
In this example:
- x is a node feature matrix of shape [num_nodes, num_node_features]
- edge_index is a graph connectivity matrix of shape [2, num_edges]
  - The first row contains the source nodes
  - The second row contains the target nodes
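The Data object also exposes convenient helper properties for inspecting a graph. Continuing the example above (a small sketch; the printed values correspond to the 3-node graph we just built):
print(data.num_nodes)          # 3
print(data.num_edges)          # 2
print(data.num_node_features)  # 1
print(data.is_directed())      # True, since we did not add reverse edges
print(data.has_self_loops())   # False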
2. Loading Datasets
PyG comes with many built-in datasets:
from torch_geometric.datasets import Planetoid
# Load the Cora dataset (a citation network)
dataset = Planetoid(root='/tmp/Cora', name='Cora')
print(f"Dataset: {dataset}")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of features: {dataset.num_features}")
print(f"Number of classes: {dataset.num_classes}")
# Access the graph
data = dataset[0]
print(data)
# Output:
# Dataset: Cora()
# Number of graphs: 1
# Number of features: 1433
# Number of classes: 7
# Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
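Cora is a single graph, and the Planetoid split marks which nodes are used for training, validation, and testing via boolean masks. A short sketch of inspecting them (the counts below are the standard public split), plus an optional transform that row-normalizes node features at load time:
print(f"Training nodes: {int(data.train_mask.sum())}")    # 140
print(f"Validation nodes: {int(data.val_mask.sum())}")    # 500
print(f"Test nodes: {int(data.test_mask.sum())}")         # 1000

from torch_geometric.transforms import NormalizeFeatures
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())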
3. Graph Neural Networks
Let's implement a simple Graph Convolutional Network (GCN) using PyG:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(GCN, self).__init__()
        # First Graph Convolutional Layer
        self.conv1 = GCNConv(num_features, hidden_channels)
        # Second Graph Convolutional Layer
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        # Apply first layer
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        # Apply second layer
        x = self.conv2(x, edge_index)
        return x
Let's train this model on the Cora dataset:
from torch_geometric.datasets import Planetoid
import torch.optim as optim
# Load dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
# Initialize model
model = GCN(dataset.num_features, hidden_channels=16, num_classes=dataset.num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# Training function
def train():
    model.train()
    optimizer.zero_grad()
    # Forward pass
    out = model(data.x, data.edge_index)
    # Compute loss only on training nodes
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    # Backward pass
    loss.backward()
    optimizer.step()
    return loss.item()
# Evaluation function
def test():
    model.eval()
    out = model(data.x, data.edge_index)
    # Calculate accuracy on the test set
    pred = out.argmax(dim=1)
    test_correct = pred[data.test_mask] == data.y[data.test_mask]
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())
    return test_acc
# Training loop
for epoch in range(200):
    loss = train()
    if epoch % 20 == 0:
        test_acc = test()
        print(f'Epoch: {epoch}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')
# Final evaluation
test_acc = test()
print(f'Final Test Accuracy: {test_acc:.4f}')
# Output example:
# Epoch: 0, Loss: 1.9456, Test Accuracy: 0.0950
# Epoch: 20, Loss: 1.3468, Test Accuracy: 0.6900
# Epoch: 40, Loss: 0.8952, Test Accuracy: 0.7600
# ...
# Epoch: 180, Loss: 0.4125, Test Accuracy: 0.8160
# Final Test Accuracy: 0.8190
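In practice you would usually also monitor validation accuracy and keep the parameters that do best on data.val_mask, reporting test accuracy only for that model. A minimal sketch reusing the model, data, and train() defined above:
def evaluate(mask):
    model.eval()
    with torch.no_grad():
        pred = model(data.x, data.edge_index).argmax(dim=1)
        correct = pred[mask] == data.y[mask]
    return int(correct.sum()) / int(mask.sum())

best_val_acc = 0.0
best_state = None
for epoch in range(200):
    train()
    val_acc = evaluate(data.val_mask)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_state = {k: v.clone() for k, v in model.state_dict().items()}

model.load_state_dict(best_state)
print(f'Test accuracy of best model: {evaluate(data.test_mask):.4f}')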
Advanced Features
1. Graph Pooling
Graph pooling operations reduce the size of graphs by combining nodes based on certain criteria. PyG provides several pooling methods:
from torch_geometric.nn import global_mean_pool
class GNN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        # Node embeddings
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        # Graph-level embeddings using mean pooling
        x = global_mean_pool(x, batch)
        # Final classification
        x = self.lin(x)
        return x
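The batch argument in the forward pass is produced by PyG's DataLoader, which stacks many small graphs into one large disconnected graph and records which node belongs to which example. A hedged sketch using the MUTAG graph classification dataset (any dataset of many small graphs would work the same way):
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = GNN(dataset.num_features, hidden_channels=64, num_classes=dataset.num_classes)

for batch in loader:
    out = model(batch.x, batch.edge_index, batch.batch)
    print(out.shape)  # [number of graphs in this batch, num_classes]
    break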
2. Graph Visualization
You can visualize your graphs using networkx:
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx
def visualize_graph(data, node_colors=None):
    G = to_networkx(data, to_undirected=True)
    plt.figure(figsize=(8, 8))
    plt.axis('off')
    # Default color if not provided
    if node_colors is None:
        node_colors = '#1f78b4'
    nx.draw_networkx(
        G,
        pos=nx.spring_layout(G, seed=42),
        with_labels=True,
        node_color=node_colors,
        cmap="Set2",
        node_size=100,
        edge_color='#888',
        width=0.5,
    )
    plt.show()
# Example usage
data = dataset[0]
visualize_graph(data, node_colors=data.y)
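Drawing all 2,708 Cora nodes at once is slow and hard to read; one option is to plot only a local neighborhood. A sketch using PyG's k_hop_subgraph utility to extract the 2-hop neighborhood around a node of interest:
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph

# Extract the 2-hop neighborhood around node 0 and relabel its nodes to 0..n-1
subset, sub_edge_index, mapping, edge_mask = k_hop_subgraph(
    node_idx=0, num_hops=2, edge_index=data.edge_index, relabel_nodes=True
)
sub_data = Data(x=data.x[subset], edge_index=sub_edge_index, y=data.y[subset])
visualize_graph(sub_data, node_colors=sub_data.y)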
3. Handling Heterogeneous Graphs
Heterogeneous graphs have different types of nodes and edges. PyG provides tools to work with them:
from torch_geometric.data import HeteroData
# Create a heterogeneous graph (user-item interaction network)
data = HeteroData()
# Add user nodes
data['user'].x = torch.randn(5, 16) # 5 users with 16 features each
# Add item nodes
data['item'].x = torch.randn(10, 8) # 10 items with 8 features each
# Add user-to-item edges (e.g., purchases)
data['user', 'purchases', 'item'].edge_index = torch.tensor([
    [0, 1, 2, 3, 4],  # User indices
    [0, 2, 4, 5, 7],  # Item indices
])
# Add item-to-item edges (e.g., similar items)
data['item', 'similar', 'item'].edge_index = torch.tensor([
    [0, 1, 2, 3],  # Source item indices
    [1, 2, 3, 4],  # Target item indices
])
print(data)
# Output: HeteroData(
# user={ x=[5, 16] },
# item={ x=[10, 8] },
# (user, purchases, item)={ edge_index=[2, 5] },
# (item, similar, item)={ edge_index=[2, 4] }
# )
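To run message passing over a heterogeneous graph, PyG can convert an ordinary homogeneous model with its to_hetero utility. The sketch below is illustrative (the layer choice, dimensions, and aggregation are assumptions, not requirements): SAGEConv with in_channels=(-1, -1) lets PyG lazily infer the input size of each node type, and the ToUndirected transform adds reverse edge types so that user nodes also receive messages:
import torch
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero

# Add reverse edge types (e.g., item -> user) so every node type gets updated
data = T.ToUndirected()(data)

class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1) lazily infers the per-type input dimensions
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

model = HeteroGNN(hidden_channels=32, out_channels=16)
model = to_hetero(model, data.metadata(), aggr='sum')

# One forward pass returns a dict of embeddings keyed by node type
out = model(data.x_dict, data.edge_index_dict)
print(out['user'].shape, out['item'].shape)  # torch.Size([5, 16]) torch.Size([10, 16])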
Real-World Applications
Let's explore some practical applications of PyTorch Geometric:
1. Molecular Property Prediction
Predicting molecular properties is crucial for drug discovery. We can use GNNs to analyze molecular graphs:
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
# Load the dataset (ESOL - water solubility)
dataset = MoleculeNet(root='/tmp/ESOL', name='ESOL')
# Shuffle, then split the dataset
dataset = dataset.shuffle()
train_dataset = dataset[:1000]
test_dataset = dataset[1000:]
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# Model for regression task
class MoleculeGNN(torch.nn.Module):
    def __init__(self):
        super(MoleculeGNN, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.lin = torch.nn.Linear(64, 1)  # For regression output

    def forward(self, x, edge_index, batch):
        # MoleculeNet stores atom features as integers; cast to float for GCNConv
        x = x.float()
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        # Global mean pooling (nodes -> graph)
        x = global_mean_pool(x, batch)
        # Regression head
        x = self.lin(x)
        return x
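A sketch of a training loop for this regression model: it minimizes mean squared error between predicted and true solubility over the mini-batches produced by the DataLoader (the hyperparameters here are illustrative):
model = MoleculeGNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train_epoch():
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = F.mse_loss(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(train_loader.dataset)

for epoch in range(50):
    loss = train_epoch()
    if epoch % 10 == 0:
        print(f'Epoch: {epoch}, Train MSE: {loss:.4f}')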
2. Social Network Analysis
Here's how you can use PyG for detecting communities in a social network:
from torch_geometric.datasets import KarateClub
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# Load the famous Zachary's Karate Club dataset
dataset = KarateClub()
data = dataset[0]
# Define a GNN model
class CommunityDetector(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(CommunityDetector, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
# Initialize and train model
model = CommunityDetector(dataset.num_features, 16, dataset.num_classes)  # one output per community (PyG's KarateClub labels 4 communities)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out, data.y)
    loss.backward()
    optimizer.step()
    return loss

for epoch in range(100):
    loss = train()
    if epoch % 10 == 0:
        print(f'Epoch: {epoch}, Loss: {loss.item():.4f}')
# Get node embeddings
model.eval()
with torch.no_grad():
    node_embeddings = model(data.x, data.edge_index)
# Visualize
tsne = TSNE(n_components=2)
node_embeddings_2d = tsne.fit_transform(node_embeddings.cpu().numpy())
plt.figure(figsize=(10, 8))
scatter = plt.scatter(
node_embeddings_2d[:, 0], node_embeddings_2d[:, 1],
c=data.y.cpu().numpy(), cmap='Set1', s=100, alpha=0.8
)
plt.colorbar(scatter)
plt.title("Community Visualization in Zachary's Karate Club")
plt.show()
Summary
PyTorch Geometric provides a powerful set of tools for working with graph-structured data:
- Data Structures: Specialized classes like Data and HeteroData for efficient graph representation
- Graph Neural Networks: Pre-implemented layers like GCNConv, GraphSAGE, and more
- Dataset Handling: Built-in datasets and utilities for processing and transforming graphs
- Pooling Operations: Various methods to aggregate node information into graph-level representations
- Advanced Features: Support for heterogeneous graphs, batching, and sampling techniques
With these capabilities, PyG enables you to solve complex problems involving relational data, from molecule property prediction to social network analysis.
Additional Resources
- PyTorch Geometric Official Documentation
- PyG Paper - The original research paper
- PyG Examples Repository
- Graph Neural Networks: A Review of Methods and Applications
Exercises
- Basic: Create a simple graph with 5 nodes and 6 edges using PyG's Data class. Visualize it using networkx.
- Intermediate: Implement a Graph Attention Network (GAT) for node classification on the Cora dataset.
- Advanced: Create a link prediction model for a social network that predicts potential friendships between users.
- Challenge: Design and implement a graph neural network for traffic prediction on a road network, where nodes are intersections and edges are road segments.
Happy graph learning with PyTorch Geometric!
If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)