
PyTorch Graph Neural Networks

Introduction

Graph Neural Networks (GNNs) are a powerful class of deep learning models designed specifically for graph-structured data. Unlike traditional neural networks that process data with a regular Euclidean structure (such as images or sequences), GNNs operate directly on graphs, which consist of nodes (vertices) and edges (connections between nodes).

In this tutorial, we'll explore how to implement GNNs using PyTorch, particularly with the PyTorch Geometric (PyG) library, which provides ready-to-use implementations of popular GNN architectures.

Graphs are everywhere in real life:

  • Social networks (users as nodes, friendships as edges)
  • Molecules (atoms as nodes, bonds as edges)
  • Citation networks (papers as nodes, citations as edges)
  • Recommendation systems (users and items as nodes, interactions as edges)

Prerequisites

Before diving into GNNs, you should have:

  • Basic knowledge of PyTorch
  • Understanding of neural networks fundamentals
  • Familiarity with Python programming

Setting Up the Environment

First, let's install the necessary libraries:

bash
pip install torch torch_geometric
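
Depending on your PyTorch and CUDA versions, some optional PyG extensions (such as torch_scatter and torch_sparse) ship as separate wheels. The wheel index URL below is only an example for a CPU-only PyTorch 2.1 build; adjust the tags to match your own installation:

bash
# Optional extensions; the torch/CUDA tags must match your setup
pip install torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.1.0+cpu.html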

Let's verify the installation and import the required libraries:

python
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")

Expected output (your exact versions may differ):

PyTorch version: 1.10.0
PyTorch Geometric version: 2.0.4

Understanding Graph Data Structure

In PyTorch Geometric, a graph is represented by a Data object with several important attributes:

  • x: Node features (shape: [num_nodes, num_node_features])
  • edge_index: Graph connectivity in COO format (shape: [2, num_edges])
  • edge_attr: Edge features (shape: [num_edges, num_edge_features])
  • y: Target labels (shape: [num_nodes] for node-level or [1] for graph-level tasks)

Let's create a simple graph with three nodes and three edges:

python
# Create node features: 3 nodes, each with 4 features
x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]], dtype=torch.float)

# Define the edges (in COO format)
# Each column represents an edge: [source_node, target_node]
edge_index = torch.tensor([[0, 1, 1],
                           [1, 0, 2]], dtype=torch.long)

# Create the graph
graph = Data(x=x, edge_index=edge_index)

print(graph)
print(f"Number of nodes: {graph.num_nodes}")
print(f"Number of edges: {graph.num_edges}")
print(f"Node feature dimensions: {graph.num_node_features}")

Expected output:

Data(x=[3, 4], edge_index=[2, 3])
Number of nodes: 3
Number of edges: 3
Node feature dimensions: 4
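
One detail worth noting: edge_index stores directed edges, so an undirected graph is represented by listing each edge in both directions. PyG's to_undirected utility adds any missing reverse edges; a quick sketch on the graph above:

python
from torch_geometric.utils import to_undirected

# Adds the missing reverse edge (2, 1) and de-duplicates the rest
undirected_edge_index = to_undirected(edge_index)
print(undirected_edge_index)
# tensor([[0, 1, 1, 2],
#         [1, 0, 2, 1]])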

Graph Neural Network Basics

The core idea behind GNNs is message passing: nodes exchange information with their neighbors to update their representations. After multiple rounds of message passing, each node's representation captures both its own features and the structural information of its neighborhood.

The message passing framework consists of three main steps:

  1. Message: Compute messages from source nodes to target nodes
  2. Aggregate: Combine messages from neighboring nodes
  3. Update: Update node representations based on aggregated messages
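
To make these three steps concrete, here is a minimal sketch of a custom layer built on PyG's MessagePassing base class. The layer itself (SimpleConv, with mean aggregation) is just an illustration, not a standard architecture:

python
import torch
from torch_geometric.nn import MessagePassing

class SimpleConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean')  # Step 2: aggregate messages by averaging
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # propagate() runs message -> aggregate -> update over all edges
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # Step 1: each source node j sends its linearly transformed features
        return self.lin(x_j)

    def update(self, aggr_out):
        # Step 3: the new node representation is the aggregated message
        return torch.relu(aggr_out)

# e.g. SimpleConv(4, 8)(x, edge_index) maps the toy graph's 4-feature nodes to 8 dims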

Implementing a Simple Graph Convolutional Network (GCN)

Let's implement a basic GCN for node classification:

python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        # First Graph Convolutional layer
        self.conv1 = GCNConv(num_node_features, 16)
        # Second Graph Convolutional layer
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        # First layer with ReLU activation
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)

        # Second layer
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)
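
As a quick sanity check, we can push the three-node toy graph from earlier through an untrained model (4 input features, and an arbitrary choice of 2 classes for illustration):

python
toy_model = GCN(num_node_features=4, num_classes=2)
out = toy_model(graph)   # `graph` is the Data object built earlier
print(out.shape)         # torch.Size([3, 2]): one log-probability row per node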

Working with a Real Dataset: Cora Citation Network

Let's use the Cora dataset, a citation network where nodes represent scientific papers and edges represent citations between papers:

python
from torch_geometric.datasets import Planetoid

# Load the Cora dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0] # Get the first graph

print(f"Dataset: {dataset}:")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.num_edges}")
print(f"Number of features: {data.num_features}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Has isolated nodes: {data.has_isolated_nodes()}")
print(f"Has self-loops: {data.has_self_loops()}")
print(f"Is undirected: {data.is_undirected()}")

Expected output:

Dataset: Cora():
Number of graphs: 1
Number of nodes: 2708
Number of edges: 10556
Number of features: 1433
Number of classes: 7
Has isolated nodes: False
Has self-loops: False
Is undirected: True

Now let's train our GCN model on the Cora dataset:

python
import torch.optim as optim

# Initialize the model
model = GCN(num_node_features=dataset.num_features,
            num_classes=dataset.num_classes)

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Training function
def train():
    model.train()
    optimizer.zero_grad()
    # Forward pass
    out = model(data)
    # Calculate loss (only on training nodes)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    # Backward pass
    loss.backward()
    optimizer.step()
    return loss.item()

# Evaluation function
def test():
    model.eval()
    with torch.no_grad():
        out = model(data)
        # For validation data
        val_loss = F.nll_loss(out[data.val_mask], data.y[data.val_mask]).item()
        val_acc = (out[data.val_mask].argmax(dim=1) == data.y[data.val_mask]).float().mean().item()

        # For test data
        test_loss = F.nll_loss(out[data.test_mask], data.y[data.test_mask]).item()
        test_acc = (out[data.test_mask].argmax(dim=1) == data.y[data.test_mask]).float().mean().item()

    return val_loss, val_acc, test_loss, test_acc

# Train the model
best_val_acc = 0
best_test_acc = 0

for epoch in range(200):
    loss = train()
    val_loss, val_acc, test_loss, test_acc = test()

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc

    if epoch % 10 == 0:
        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

print(f'Best validation accuracy: {best_val_acc:.4f}')
print(f'Best test accuracy: {best_test_acc:.4f}')

Example output:

Epoch 000, Loss: 1.9457, Val Acc: 0.2220, Test Acc: 0.2130
Epoch 010, Loss: 0.6734, Val Acc: 0.5860, Test Acc: 0.5620
Epoch 020, Loss: 0.2451, Val Acc: 0.7480, Test Acc: 0.7670
...
Epoch 190, Loss: 0.0358, Val Acc: 0.8140, Test Acc: 0.8230
Best validation accuracy: 0.8140
Best test accuracy: 0.8230

More Advanced GNN Architectures

1. Graph Attention Networks (GAT)

GAT uses attention mechanisms to weigh the importance of neighbor nodes:

python
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super(GAT, self).__init__()
        # 8 attention heads with 8 channels each; head outputs are concatenated
        self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
        # 8 * 8 = 64 concatenated features in; a single averaged head out
        self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)
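
The Cora training loop from the GCN section works unchanged with this model; only the construction (and, typically, the optimizer settings) differ. In the sketch below, lr=0.005 is the setting used in the GAT paper for Cora, not a value from this tutorial's earlier runs:

python
model = GAT(num_features=dataset.num_features, num_classes=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)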

2. GraphSAGE

GraphSAGE is designed for inductive learning on large graphs:

python
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(num_features, 16)
        self.conv2 = SAGEConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)
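
By default, SAGEConv aggregates neighbor features with a mean; its aggr argument swaps in other aggregators. A small sketch (the 'max' choice here is purely illustrative):

python
# Element-wise max over neighbor features instead of the default mean
conv = SAGEConv(16, 7, aggr='max')  # 16 -> 7 channels (e.g. Cora's 7 classes)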

Real-World Application: Molecular Property Prediction

One of the most exciting applications of GNNs is in chemistry for predicting molecular properties. Let's use the MUTAG dataset, which contains molecular graphs where the task is to predict whether a molecule is mutagenic:

python
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

# Load MUTAG dataset
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')

# Split into training and test sets
torch.manual_seed(12345)
dataset = dataset.shuffle()
train_dataset = dataset[:150]
test_dataset = dataset[150:]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# Graph level GNN for graph classification
class GNNForGraphClassification(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GNNForGraphClassification, self).__init__()
        self.conv1 = GCNConv(num_node_features, 32)
        self.conv2 = GCNConv(32, 64)
        self.fc = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Node embeddings
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index)

        # Global pooling: average all node features for each graph
        x = torch_geometric.nn.global_mean_pool(x, batch)

        # Fully connected layer for classification
        x = self.fc(x)

        return F.log_softmax(x, dim=1)

# Initialize model, optimizer, and loss function
model = GNNForGraphClassification(num_node_features=dataset.num_node_features,
                                  num_classes=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.NLLLoss()  # the model already returns log-probabilities

def train_molecular_model():
    model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_dataset)

def test_molecular_model(loader):
    model.eval()
    correct = 0
    for data in loader:
        with torch.no_grad():
            pred = model(data).max(dim=1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)

# Training loop
for epoch in range(50):
    loss = train_molecular_model()
    train_acc = test_molecular_model(train_loader)
    test_acc = test_molecular_model(test_loader)

    if epoch % 5 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

Example output:

Epoch: 000, Loss: 0.6824, Train Acc: 0.6333, Test Acc: 0.5417
Epoch: 005, Loss: 0.3218, Train Acc: 0.8533, Test Acc: 0.7083
Epoch: 010, Loss: 0.2190, Train Acc: 0.9333, Test Acc: 0.7500
...
Epoch: 045, Loss: 0.0650, Train Acc: 0.9867, Test Acc: 0.8333
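
Mean pooling is only one possible readout. A common variant concatenates mean- and max-pooled node features per graph; here is a minimal sketch using PyG's global_mean_pool and global_max_pool (note the classifier's input width must then be doubled):

python
import torch
from torch_geometric.nn import global_mean_pool, global_max_pool

def readout(x, batch):
    # Per-graph mean and max of the node embeddings, concatenated;
    # the output is twice as wide as the node embedding
    return torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch)], dim=1)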

Visualizing Node Embeddings

Let's visualize the node embeddings learned by our GCN model on the Cora dataset. Note that the molecular example above reassigned model and dataset, so if you ran it, re-run the Cora training section first:

python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Get node embeddings from the trained model
def get_embeddings(model, data):
    model.eval()
    with torch.no_grad():
        # Get the output of the first convolutional layer
        x, edge_index = data.x, data.edge_index
        embeddings = model.conv1(x, edge_index)
        embeddings = embeddings.detach().numpy()
    return embeddings

# Get node classes for coloring
node_classes = data.y.numpy()

# Get embeddings from our model
embeddings = get_embeddings(model, data)

# Use t-SNE for dimensionality reduction
tsne = TSNE(n_components=2, random_state=42)
node_embeddings_2d = tsne.fit_transform(embeddings)

# Plot the embeddings
plt.figure(figsize=(10, 8))
scatter = plt.scatter(node_embeddings_2d[:, 0], node_embeddings_2d[:, 1],
                      c=node_classes, cmap='tab10', s=50, alpha=0.8)
plt.colorbar(scatter)
plt.title('t-SNE Visualization of Node Embeddings')
plt.xlabel('t-SNE Dimension 1')
plt.ylabel('t-SNE Dimension 2')
plt.show()

Summary

In this tutorial, we've explored Graph Neural Networks with PyTorch:

  1. We introduced the concept of GNNs and why they're essential for graph data
  2. We learned how to represent graph data in PyTorch Geometric
  3. We implemented several GNN architectures:
    • Graph Convolutional Networks (GCN)
    • Graph Attention Networks (GAT)
    • GraphSAGE
  4. We applied GNNs to real-world problems:
    • Node classification with the Cora citation network
    • Graph classification for molecular property prediction
  5. We visualized the learned node embeddings

GNNs are incredibly powerful for processing graph-structured data and are increasingly being used in various domains like social network analysis, drug discovery, traffic prediction, and recommendation systems.

Practice Exercises

  1. Node Classification: Use a different GNN architecture (GAT or GraphSAGE) on the Cora dataset and compare the results with our GCN model.

  2. Graph Classification: Try to improve the molecular property prediction model by:

    • Using a different GNN architecture
    • Adding more layers or changing the hidden dimensions
    • Implementing different pooling strategies (max pooling, attention-based pooling)
  3. Graph Generation: Research and implement a Variational Graph Autoencoder (VGAE) for generating new molecular structures.

  4. Link Prediction: Implement a link prediction task on a social network dataset to predict future connections between users.

  5. Custom Dataset: Create a custom graph dataset from your field of interest and apply GNNs to solve a specific problem.

Remember that GNNs are a rapidly evolving field with new architectures and applications emerging regularly. Stay curious and keep exploring!


