PyTorch Graph Neural Networks
Introduction
Graph Neural Networks (GNNs) are a powerful class of deep learning models designed specifically to work with graph-structured data. Unlike traditional neural networks that process data in Euclidean space (like images or sequences), GNNs operate on data represented as graphs: collections of nodes (vertices) connected by edges.
In this tutorial, we'll explore how to implement GNNs using PyTorch, particularly with the PyTorch Geometric (PyG) library, which provides ready-to-use implementations of popular GNN architectures.
Graphs are everywhere in real life:
- Social networks (users as nodes, friendships as edges)
- Molecules (atoms as nodes, bonds as edges)
- Citation networks (papers as nodes, citations as edges)
- Recommendation systems (users and items as nodes, interactions as edges)
Prerequisites
Before diving into GNNs, you should have:
- Basic knowledge of PyTorch
- Understanding of neural networks fundamentals
- Familiarity with Python programming
Setting Up the Environment
First, let's install the necessary libraries:
pip install torch torch_geometric
Depending on your PyTorch and CUDA versions, PyTorch Geometric may also need companion packages such as torch-scatter and torch-sparse; the PyG installation guide lists the matching commands.
Let's verify the installation and import the required libraries:
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")
Expected output (your installed versions may differ):
PyTorch version: 1.10.0
PyTorch Geometric version: 2.0.4
Understanding Graph Data Structure
In PyTorch Geometric, a graph is represented by a Data object with several important attributes:
- x: Node features (shape: [num_nodes, num_node_features])
- edge_index: Graph connectivity in COO format (shape: [2, num_edges])
- edge_attr: Edge features (shape: [num_edges, num_edge_features])
- y: Target labels (shape: [num_nodes] for node-level or [1] for graph-level tasks)
Let's create a simple graph with three nodes and three edges:
# Create node features: 3 nodes, each with 4 features
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]], dtype=torch.float)

# Define the edges (in COO format)
# Each column represents a directed edge: [source_node, target_node]
edge_index = torch.tensor([[0, 1, 1],
                           [1, 0, 2]], dtype=torch.long)
# Create the graph
graph = Data(x=x, edge_index=edge_index)
print(graph)
print(f"Number of nodes: {graph.num_nodes}")
print(f"Number of edges: {graph.num_edges}")
print(f"Node feature dimensions: {graph.num_node_features}")
Expected output:
Data(x=[3, 4], edge_index=[2, 3])
Number of nodes: 3
Number of edges: 3
Node feature dimensions: 4
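One detail worth noting: edge_index lists directed edges, and the graph above includes both directions of the 0-1 connection but only one direction of 1-2. If you want a fully undirected graph, PyTorch Geometric ships a utility that adds the missing reverse edges; a small sketch:

from torch_geometric.utils import to_undirected

# Add the missing reverse edges so every connection is bidirectional
undirected_edge_index = to_undirected(edge_index)
print(undirected_edge_index)
# tensor([[0, 1, 1, 2],
#         [1, 0, 2, 1]])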
Graph Neural Network Basics
The core idea behind GNNs is message passing: nodes exchange information with their neighbors to update their representations. After multiple rounds of message passing, each node's representation captures both its own features and the structural information of its neighborhood.
The message passing framework consists of three main steps:
- Message: Compute messages from source nodes to target nodes
- Aggregate: Combine messages from neighboring nodes
- Update: Update node representations based on aggregated messages
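PyTorch Geometric exposes these steps through its MessagePassing base class: you pick an aggregation scheme in the constructor and implement message() (and optionally update()). As a minimal illustrative sketch (a hypothetical layer, not one of PyG's built-ins), here is a convolution that linearly transforms each neighbor's features and mean-aggregates them:

import torch
from torch_geometric.nn import MessagePassing

class SimpleMeanConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean')  # Step 2: mean aggregation over neighbors
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Kick off message passing over all edges
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # Step 1: x_j holds the features of each edge's source node
        return self.lin(x_j)

    # Step 3 (update) defaults to returning the aggregated messages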
Implementing a Simple Graph Convolutional Network (GCN)
Let's implement a basic GCN for node classification:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        # First Graph Convolutional layer
        self.conv1 = GCNConv(num_node_features, 16)
        # Second Graph Convolutional layer
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        # First layer with ReLU activation
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        # Second layer
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
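A quick smoke test on the toy graph from earlier confirms the shapes: the model maps the [3, 4] node-feature matrix to one row of log-probabilities per node.

# Sanity check on the three-node toy graph defined above
toy_model = GCN(num_node_features=4, num_classes=2)
print(toy_model(graph).shape)  # torch.Size([3, 2])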
Working with a Real Dataset: Cora Citation Network
Let's use the Cora dataset, a citation network where nodes represent scientific papers and edges represent citations between papers:
from torch_geometric.datasets import Planetoid
# Load the Cora dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0] # Get the first graph
print(f"Dataset: {dataset}:")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.num_edges}")
print(f"Number of features: {data.num_features}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Has isolated nodes: {data.has_isolated_nodes()}")
print(f"Has self-loops: {data.has_self_loops()}")
print(f"Is undirected: {data.is_undirected()}")
Expected output:
Dataset: Cora()
Number of graphs: 1
Number of nodes: 2708
Number of edges: 10556
Number of features: 1433
Number of classes: 7
Has isolated nodes: False
Has self-loops: False
Is undirected: True
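The dataset also comes with predefined split masks: boolean vectors over the nodes that select which nodes are used for training, validation, and testing. For the standard Cora split these contain 140, 500, and 1000 nodes respectively:

# Inspect the predefined split masks
print(f"Training nodes: {data.train_mask.sum().item()}")    # 140
print(f"Validation nodes: {data.val_mask.sum().item()}")    # 500
print(f"Test nodes: {data.test_mask.sum().item()}")         # 1000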
Now let's train our GCN model on the Cora dataset:
import torch.optim as optim
# Initialize the model
model = GCN(num_node_features=dataset.num_features,
            num_classes=dataset.num_classes)

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Training function
def train():
    model.train()
    optimizer.zero_grad()
    # Forward pass
    out = model(data)
    # Calculate loss (only on training nodes)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    # Backward pass
    loss.backward()
    optimizer.step()
    return loss.item()
# Evaluation function
def test():
    model.eval()
    with torch.no_grad():
        out = model(data)
    # For validation data
    val_loss = F.nll_loss(out[data.val_mask], data.y[data.val_mask]).item()
    val_acc = (out[data.val_mask].argmax(dim=1) == data.y[data.val_mask]).float().mean().item()
    # For test data
    test_loss = F.nll_loss(out[data.test_mask], data.y[data.test_mask]).item()
    test_acc = (out[data.test_mask].argmax(dim=1) == data.y[data.test_mask]).float().mean().item()
    return val_loss, val_acc, test_loss, test_acc
# Train the model
best_val_acc = 0
best_test_acc = 0
for epoch in range(200):
    loss = train()
    val_loss, val_acc, test_loss, test_acc = test()
    # Keep the test accuracy from the epoch with the best validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc
    if epoch % 10 == 0:
        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

print(f'Best validation accuracy: {best_val_acc:.4f}')
print(f'Best test accuracy: {best_test_acc:.4f}')
Example output:
Epoch 000, Loss: 1.9457, Val Acc: 0.2220, Test Acc: 0.2130
Epoch 010, Loss: 0.6734, Val Acc: 0.5860, Test Acc: 0.5620
Epoch 020, Loss: 0.2451, Val Acc: 0.7480, Test Acc: 0.7670
...
Epoch 190, Loss: 0.0358, Val Acc: 0.8140, Test Acc: 0.8230
Best validation accuracy: 0.8140
Best test accuracy: 0.8230
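Your numbers will differ slightly from run to run. Training on Cora is quick on a CPU; for larger graphs you will likely want a GPU, and the standard PyTorch device pattern applies, since Data objects support .to() just like tensors (a sketch):

# Move both the model and the graph to the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)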
More Advanced GNN Architectures
1. Graph Attention Networks (GAT)
GAT uses attention mechanisms to weigh the importance of neighbor nodes:
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super(GAT, self).__init__()
        # 8 attention heads, each producing 8 features (concatenated to 64)
        self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
        # Output layer: a single attention head mapping 64 features to class scores
        self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
2. GraphSAGE
GraphSAGE is designed for inductive learning on large graphs:
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(num_features, 16)
        self.conv2 = SAGEConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
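Both classes take a Data object and return log-probabilities, so either can be dropped into the Cora training loop from the previous section; only the model and optimizer change. A sketch (the smaller learning rate follows the GAT paper's Cora setup; note this rebinds model, so re-run the GCN training before the visualization section at the end if you want to visualize GCN embeddings):

# Swap the architecture; train() and test() from the Cora section work unchanged
model = GAT(num_features=dataset.num_features, num_classes=dataset.num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

for epoch in range(200):
    loss = train()
    val_loss, val_acc, test_loss, test_acc = test()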
Real-World Application: Molecular Property Prediction
One of the most exciting applications of GNNs is in chemistry for predicting molecular properties. Let's use the MUTAG dataset, which contains molecular graphs where the task is to predict whether a molecule is mutagenic:
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
# Load MUTAG dataset
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
# Split into training and test sets
torch.manual_seed(12345)
dataset = dataset.shuffle()
train_dataset = dataset[:150]
test_dataset = dataset[150:]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)
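Unlike the Cora example, this is a graph-level task, so the DataLoader batches many small graphs into one large disconnected graph and records which graph each node came from in a batch vector; the pooling layer in the model below relies on it. You can inspect a mini-batch like this:

# Peek at one mini-batch: nodes from up to 32 molecules stacked together
first_batch = next(iter(train_loader))
print(first_batch.num_graphs)   # 32
print(first_batch.batch[:10])   # graph index of the first ten nodes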
# Graph-level GNN for graph classification
class GNNForGraphClassification(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GNNForGraphClassification, self).__init__()
        self.conv1 = GCNConv(num_node_features, 32)
        self.conv2 = GCNConv(32, 64)
        self.fc = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        # batch maps each node to the index of its graph within the mini-batch
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # Node embeddings
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index)
        # Global pooling: average all node features for each graph
        x = torch_geometric.nn.global_mean_pool(x, batch)
        # Fully connected layer for classification
        x = self.fc(x)
        return F.log_softmax(x, dim=1)
# Initialize model, optimizer, and loss function
# Use a distinct name so the Cora GCN stays bound to `model`
# for the visualization section later on
mol_model = GNNForGraphClassification(num_node_features=dataset.num_node_features,
                                      num_classes=dataset.num_classes)
optimizer = torch.optim.Adam(mol_model.parameters(), lr=0.01)
# The model returns log-probabilities, so pair it with NLLLoss
# (CrossEntropyLoss expects raw logits and would apply log_softmax twice)
criterion = torch.nn.NLLLoss()
def train_molecular_model():
    mol_model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        output = mol_model(data)
        loss = criterion(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    # Average loss per graph over the whole training set
    return total_loss / len(train_dataset)
def test_molecular_model(loader):
    mol_model.eval()
    correct = 0
    for data in loader:
        with torch.no_grad():
            pred = mol_model(data).max(dim=1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)
# Training loop
for epoch in range(50):
    loss = train_molecular_model()
    train_acc = test_molecular_model(train_loader)
    test_acc = test_molecular_model(test_loader)
    if epoch % 5 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
Example output:
Epoch: 000, Loss: 0.6824, Train Acc: 0.6333, Test Acc: 0.5417
Epoch: 005, Loss: 0.3218, Train Acc: 0.8533, Test Acc: 0.7083
Epoch: 010, Loss: 0.2190, Train Acc: 0.9333, Test Acc: 0.7500
...
Epoch: 045, Loss: 0.0650, Train Acc: 0.9867, Test Acc: 0.8333
Visualizing Node Embeddings
Let's visualize the node embeddings learned by our GCN model on the Cora dataset:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# Get node embeddings from the trained model
def get_embeddings(model, data):
    model.eval()
    with torch.no_grad():
        # Take the output of the first convolutional layer as the embedding
        x, edge_index = data.x, data.edge_index
        embeddings = model.conv1(x, edge_index)
    return embeddings.numpy()
# Get node classes for coloring
node_classes = data.y.numpy()
# Get embeddings from our model
embeddings = get_embeddings(model, data)
# Use t-SNE for dimensionality reduction
tsne = TSNE(n_components=2, random_state=42)
node_embeddings_2d = tsne.fit_transform(embeddings)
# Plot the embeddings
plt.figure(figsize=(10, 8))
scatter = plt.scatter(node_embeddings_2d[:, 0], node_embeddings_2d[:, 1],
                      c=node_classes, cmap='tab10', s=50, alpha=0.8)
plt.colorbar(scatter)
plt.title('t-SNE Visualization of Node Embeddings')
plt.xlabel('t-SNE Dimension 1')
plt.ylabel('t-SNE Dimension 2')
plt.show()
Summary
In this tutorial, we've explored Graph Neural Networks with PyTorch:
- We introduced the concept of GNNs and why they're essential for graph data
- We learned how to represent graph data in PyTorch Geometric
- We implemented several GNN architectures:
  - Graph Convolutional Networks (GCN)
  - Graph Attention Networks (GAT)
  - GraphSAGE
- We applied GNNs to real-world problems:
  - Node classification with the Cora citation network
  - Graph classification for molecular property prediction
- We visualized the learned node embeddings
GNNs are incredibly powerful for processing graph-structured data and are increasingly being used in various domains like social network analysis, drug discovery, traffic prediction, and recommendation systems.
Additional Resources and Exercises
Resources for Further Learning
- PyTorch Geometric Documentation
- Stanford CS224W: Machine Learning with Graphs
- Graph Neural Networks: A Review of Methods and Applications
Practice Exercises
- Node Classification: Use a different GNN architecture (GAT or GraphSAGE) on the Cora dataset and compare the results with our GCN model.
- Graph Classification: Try to improve the molecular property prediction model by:
  - Using a different GNN architecture
  - Adding more layers or changing the hidden dimensions
  - Implementing different pooling strategies (max pooling, attention-based pooling); a starter sketch follows this list
- Graph Generation: Research and implement a Graph Variational Autoencoder (GVAE) for generating new molecular structures.
- Link Prediction: Implement a link prediction task on a social network dataset to predict future connections between users.
- Custom Dataset: Create a custom graph dataset from your field of interest and apply GNNs to solve a specific problem.
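For the pooling part of the second exercise, here is a minimal sketch (reusing GNNForGraphClassification from the MUTAG example and dropping the dropout for brevity) that swaps the mean-pooling readout for max pooling; an attention-based readout would slot into the same line:

from torch_geometric.nn import global_max_pool

# A hypothetical variant of the MUTAG model with a max-pooling readout
class GNNWithMaxPool(GNNForGraphClassification):
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        # Keep the strongest activation per feature instead of the average
        x = global_max_pool(x, batch)
        return F.log_softmax(self.fc(x), dim=1)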
Remember that GNNs are a rapidly evolving field with new architectures and applications emerging regularly. Stay curious and keep exploring!
If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)