PyTorch Geometric

Introduction

PyTorch Geometric (PyG) is a library built upon PyTorch that extends its capabilities to handle graph-structured data efficiently. It provides methods for deep learning on graphs and other irregular structures, like point clouds and manifolds.

Graphs are mathematical structures made up of nodes (vertices) connected by edges. They can represent a wide variety of real-world data, such as:

  • Social networks (users as nodes, friendships as edges)
  • Molecular structures (atoms as nodes, bonds as edges)
  • Citation networks (papers as nodes, citations as edges)
  • Knowledge graphs (entities as nodes, relationships as edges)

PyTorch Geometric simplifies working with these complex data structures by providing:

  • Efficient data handling for graph structures
  • Graph neural network (GNN) layers
  • Various datasets for benchmarking
  • Utilities for graph manipulation and transformation

Let's dive in and explore how to use PyTorch Geometric in your projects!

Installation

Before we start using PyG, we need to install it:

```bash
# Install PyTorch first (if not already installed)
pip install torch torchvision

# Install PyTorch Geometric
pip install torch-geometric

# Install additional dependencies (optional)
pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html
```

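Replace ${TORCH_VERSION} and ${CUDA_VERSION} with your installed versions (e.g., 1.13.0 and cu117, or cpu for a CPU-only build). These extension packages are optional: PyG 2.3 and later run without them, but they provide faster or additional operations.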
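If you are unsure which values to use, both placeholders can be derived from the installed PyTorch. The sketch below assumes the wheel index keeps the torch-<version>+<variant>.html naming used in the command above; the helper name pyg_wheel_url is hypothetical and not part of PyG.

```python
def pyg_wheel_url(torch_version, cuda_version):
    """Build the wheel index URL for `pip install ... -f <url>` (hypothetical helper).

    torch_version: e.g. "1.13.0" or "1.13.0+cu117" (the local "+..." suffix is dropped)
    cuda_version:  e.g. "11.7" as reported by torch.version.cuda, or None for CPU builds
    """
    base_version = torch_version.split("+")[0]
    variant = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://data.pyg.org/whl/torch-{base_version}+{variant}.html"


if __name__ == "__main__":
    try:
        import torch
        print(pyg_wheel_url(torch.__version__, torch.version.cuda))
    except ImportError:
        print("PyTorch is not installed yet")
```

Check that the resulting page exists on data.pyg.org before relying on it; not every PyTorch/CUDA combination has prebuilt wheels.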

Core Concepts

1. Data Structures

The fundamental data structure in PyG is the Data class, which represents a single graph:

```python
from torch_geometric.data import Data
import torch

# Creating a simple directed graph with 3 nodes and 2 edges (0 -> 1 and 1 -> 2)
# Each inner list is one (source, target) pair; .t() turns the pairs into shape [2, num_edges]
edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long).t().contiguous()
x = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float)  # Node features

data = Data(x=x, edge_index=edge_index)

print(data)
# Output: Data(x=[3, 1], edge_index=[2, 2])
```

In this example:

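  • x is a node feature matrix of shape [num_nodes, num_node_features]
  • edge_index is the graph connectivity in COO (coordinate) format: a tensor of shape [2, num_edges] with dtype torch.long
    • The first row contains the source nodes
    • The second row contains the target nodes

PyG treats every edge as directed, so the example above describes the directed path 0 → 1 → 2. An undirected edge must be stored twice, once in each direction.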
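To make the [2, num_edges] layout concrete, here is a dependency-free sketch of the conversion that `.t()` performs on a list of (source, target) pairs, plus the both-directions trick for undirected graphs. The helper edge_list_to_edge_index is hypothetical; in real PyG code you keep working with tensors, and torch_geometric.utils.to_undirected adds the reverse edges for you.

```python
def edge_list_to_edge_index(edges, undirected=False):
    """Convert (source, target) pairs into the [2, num_edges] COO layout (hypothetical helper).

    Returns two parallel lists: row 0 holds source nodes, row 1 target nodes.
    With undirected=True, every non-self-loop edge is also added in reverse,
    which is how PyG represents an undirected edge.
    """
    sources, targets = [], []
    for src, dst in edges:
        sources.append(src)
        targets.append(dst)
        if undirected and src != dst:
            sources.append(dst)
            targets.append(src)
    return [sources, targets]


# The directed path 0 -> 1 -> 2 from the Data example
print(edge_list_to_edge_index([(0, 1), (1, 2)]))
# [[0, 1], [1, 2]]

# The same path treated as undirected: 2 edges become 4 index pairs
print(edge_list_to_edge_index([(0, 1), (1, 2)], undirected=True))
# [[0, 1, 1, 2], [1, 0, 2, 1]]
```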

2. Loading Datasets

PyG comes with many built-in datasets:

```python
from torch_geometric.datasets import Planetoid

# Load the Cora dataset (a citation network)
dataset = Planetoid(root='/tmp/Cora', name='Cora')

print(f"Dataset: {dataset}")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of features: {dataset.num_features}")
print(f"Number of classes: {dataset.num_classes}")

# Access the graph
data = dataset[0]
print(data)

# Output:
# Dataset: Cora()
# Number of graphs: 1
# Number of features: 1433
# Number of classes: 7
# Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
```

3. Graph Neural Networks

Let's implement a simple Graph Convolutional Network (GCN) using PyG:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super().__init__()
        # First graph convolutional layer: input features -> hidden representation
        self.conv1 = GCNConv(num_features, hidden_channels)
        # Second graph convolutional layer: hidden representation -> class scores
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        # Apply first layer
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)

        # Apply second layer (raw logits; F.cross_entropy applies the softmax)
        x = self.conv2(x, edge_index)

        return x
```
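GCNConv hides the propagation rule itself. In the GCN formulation that GCNConv follows by default, each layer adds self-loops, multiplies node features by a weight matrix, and lets every node sum the transformed features of its incoming neighbors (itself included), each scaled by 1/sqrt(deg(i) * deg(j)), where degrees count the added self-loop. The pure-Python sketch below reproduces that rule for tiny inputs; it omits the bias GCNConv also adds, and the name gcn_layer is hypothetical.

```python
import math

def gcn_layer(x, edge_index, weight):
    """One GCN propagation step without bias (hypothetical helper).

    x:          node feature vectors, shape [num_nodes][in_dim]
    edge_index: [sources, targets] in COO layout
    weight:     weight matrix as rows, shape [in_dim][out_dim]
    """
    num_nodes = len(x)
    out_dim = len(weight[0])

    # 1. Add a self-loop to every node
    sources = list(edge_index[0]) + list(range(num_nodes))
    targets = list(edge_index[1]) + list(range(num_nodes))

    # 2. Degree of each node, counted on the target side (incoming edges + self-loop)
    deg = [0] * num_nodes
    for dst in targets:
        deg[dst] += 1

    # 3. Linear transform: h = x @ weight
    h = [[sum(xi[k] * weight[k][j] for k in range(len(xi))) for j in range(out_dim)]
         for xi in x]

    # 4. Each target sums its sources' features, scaled by 1 / sqrt(deg[src] * deg[dst])
    out = [[0.0] * out_dim for _ in range(num_nodes)]
    for src, dst in zip(sources, targets):
        norm = 1.0 / math.sqrt(deg[src] * deg[dst])
        for j in range(out_dim):
            out[dst][j] += norm * h[src][j]
    return out


# Nodes 0 and 1 share an undirected edge (stored in both directions); node 2 is isolated
x = [[1.0], [3.0], [5.0]]
print(gcn_layer(x, [[0, 1], [1, 0]], [[1.0]]))
# [[2.0], [2.0], [5.0]]
```

Nodes 0 and 1 end up averaging each other's features, while the isolated node keeps its own value because its only "neighbor" is its self-loop.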

Let's train this model on the Cora dataset:

```python
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import Planetoid

# Load dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

# Initialize model (GCN class from the previous snippet)
model = GCN(dataset.num_features, hidden_channels=16, num_classes=dataset.num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Training function
def train():
    model.train()
    optimizer.zero_grad()
    # Forward pass over the full graph
    out = model(data.x, data.edge_index)
    # Compute loss only on training nodes
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    # Backward pass
    loss.backward()
    optimizer.step()
    return loss.item()

# Evaluation function (no gradients needed)
@torch.no_grad()
def test():
    model.eval()
    out = model(data.x, data.edge_index)

    # Calculate accuracy on the test set
    pred = out.argmax(dim=1)
    test_correct = pred[data.test_mask] == data.y[data.test_mask]
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())

    return test_acc

# Training loop
# (Test accuracy is printed here only for illustration; use data.val_mask for model selection)
for epoch in range(200):
    loss = train()
    if epoch % 20 == 0:
        test_acc = test()
        print(f'Epoch: {epoch}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')

# Final evaluation
test_acc = test()
print(f'Final Test Accuracy: {test_acc:.4f}')

# Output example (exact values vary from run to run):
# Epoch: 0, Loss: 1.9456, Test Accuracy: 0.0950
# Epoch: 20, Loss: 1.3468, Test Accuracy: 0.6900
# Epoch: 40, Loss: 0.8952, Test Accuracy: 0.7600
# ...
# Epoch: 180, Loss: 0.4125, Test Accuracy: 0.8160
# Final Test Accuracy: 0.8190
```
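The accuracy in test() only looks at nodes whose mask entry is True: it takes the highest-scoring class per node, compares it with the label, and divides the number of matches by the number of masked nodes. A dependency-free sketch of the same computation, using a hypothetical masked_accuracy helper and made-up toy scores:

```python
def masked_accuracy(logits, labels, mask):
    """Accuracy over the nodes selected by a boolean mask (hypothetical helper).

    logits: per-node class scores; labels: true class per node;
    mask: True for nodes that belong to the split being evaluated.
    """
    correct = total = 0
    for scores, label, selected in zip(logits, labels, mask):
        if not selected:
            continue
        pred = max(range(len(scores)), key=lambda c: scores[c])  # argmax, first max wins
        correct += int(pred == label)
        total += 1
    if total == 0:
        raise ValueError("mask selects no nodes")
    return correct / total


logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
labels = [1, 1, 1, 0]
test_mask = [False, True, True, True]
print(masked_accuracy(logits, labels, test_mask))
# 0.6666666666666666
```

Passing the validation mask instead gives validation accuracy, which is the number to watch while tuning hyperparameters; the test mask is best kept for the final report.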

Advanced Features

1. Graph Pooling

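Graph pooling operations shrink a graph by combining its nodes. Hierarchical methods such as TopKPooling and SAGPooling drop or merge nodes between layers, while global pooling (readout) functions such as global_mean_pool, global_max_pool, and global_add_pool collapse each graph into a single vector, which is what graph-level tasks like graph classification need. The model below uses global mean pooling: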

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GNN(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        # Node embeddings
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)

        # Graph-level embeddings: average the node embeddings of each graph
        # (batch[i] is the index of the graph that node i belongs to)
        x = global_mean_pool(x, batch)

        # Final classification
        x = self.lin(x)

        return x
```
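global_mean_pool relies on the batch vector, which assigns each node to the graph it came from. Conceptually it averages the node embeddings that share a batch index, producing one row per graph. A pure-Python sketch of that reduction (mean_pool is a hypothetical name; the PyG function works on tensors and is vectorized):

```python
def mean_pool(node_features, batch):
    """Average node feature vectors per graph (hypothetical helper).

    node_features: [num_nodes][dim]; batch[i]: index of the graph node i belongs to.
    Returns [num_graphs][dim], one averaged vector per graph.
    """
    num_graphs = max(batch) + 1
    dim = len(node_features[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for features, graph_id in zip(node_features, batch):
        counts[graph_id] += 1
        for j in range(dim):
            sums[graph_id][j] += features[j]
    # max(..., 1) keeps a graph index with no nodes at a zero vector instead of dividing by 0
    return [[s / max(counts[g], 1) for s in sums[g]] for g in range(num_graphs)]


# Two graphs in one mini-batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
x = [[1.0, 0.0], [2.0, 0.0], [3.0, 3.0], [4.0, 1.0], [6.0, 1.0]]
batch = [0, 0, 0, 1, 1]
print(mean_pool(x, batch))
# [[2.0, 1.0], [5.0, 1.0]]
```

Swapping the average for a sum or an element-wise maximum gives the behavior of global_add_pool and global_max_pool respectively.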

2. Graph Visualization

You can visualize small graphs with NetworkX and Matplotlib after converting them with to_networkx:

```python
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx

def visualize_graph(data, node_colors=None):
    G = to_networkx(data, to_undirected=True)
    plt.figure(figsize=(8, 8))
    plt.axis('off')

    # Default color if not provided
    if node_colors is None:
        node_colors = '#1f78b4'

    nx.draw_networkx(
        G,
        pos=nx.spring_layout(G, seed=42),
        with_labels=True,
        node_color=node_colors,
        cmap="Set2",
        node_size=100,
        edge_color='#888',
        width=0.5,
    )
    plt.show()

# Example usage on a small graph: drawing all 2,708 Cora nodes with labels
# is slow and unreadable, so use the 34-node Karate Club graph instead
data = KarateClub()[0]
visualize_graph(data, node_colors=data.y.tolist())  # color nodes by class
```

3. Handling Heterogeneous Graphs

Heterogeneous graphs have different types of nodes and edges. PyG provides tools to work with them:

```python
import torch
from torch_geometric.data import HeteroData

# Create a heterogeneous graph (user-item interaction network)
data = HeteroData()

# Add user nodes
data['user'].x = torch.randn(5, 16)  # 5 users with 16 features each

# Add item nodes
data['item'].x = torch.randn(10, 8)  # 10 items with 8 features each

# Add user-to-item edges (e.g., purchases); indices are local to each node type
data['user', 'purchases', 'item'].edge_index = torch.tensor([
    [0, 1, 2, 3, 4],  # User indices
    [0, 2, 4, 5, 7],  # Item indices
])

# Add item-to-item edges (e.g., similar items)
data['item', 'similar', 'item'].edge_index = torch.tensor([
    [0, 1, 2, 3],  # Source item indices
    [1, 2, 3, 4],  # Target item indices
])

print(data)
# Output: HeteroData(
#   user={ x=[5, 16] },
#   item={ x=[10, 8] },
#   (user, purchases, item)={ edge_index=[2, 5] },
#   (item, similar, item)={ edge_index=[2, 4] }
# )
```
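In a HeteroData object each edge type is a (source type, relation, destination type) triplet, and its edge_index uses indices local to each node type: for ('user', 'purchases', 'item'), row 0 indexes users and row 1 indexes items. A hypothetical, dependency-free check of that invariant, with plain dicts standing in for the HeteroData stores:

```python
def check_hetero_edges(num_nodes, edge_indices):
    """Return a list of problems found in per-relation edge indices (hypothetical helper).

    num_nodes:    {node_type: number of nodes of that type}
    edge_indices: {(src_type, relation, dst_type): [sources, targets]}
    """
    problems = []
    for (src_type, relation, dst_type), (sources, targets) in edge_indices.items():
        if len(sources) != len(targets):
            problems.append(f"{relation}: rows have different lengths")
        for i in sources:
            if not 0 <= i < num_nodes[src_type]:
                problems.append(f"{relation}: source {i} is not a valid {src_type} index")
        for j in targets:
            if not 0 <= j < num_nodes[dst_type]:
                problems.append(f"{relation}: target {j} is not a valid {dst_type} index")
    return problems


# Mirrors the user/item graph built with HeteroData
num_nodes = {'user': 5, 'item': 10}
edge_indices = {
    ('user', 'purchases', 'item'): [[0, 1, 2, 3, 4], [0, 2, 4, 5, 7]],
    ('item', 'similar', 'item'): [[0, 1, 2, 3], [1, 2, 3, 4]],
}
print(check_hetero_edges(num_nodes, edge_indices))
# []
```

Item index 7 is valid for the purchases relation even though there are only 5 users, because each node type has its own index space.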

Real-World Applications

Let's explore some practical applications of PyTorch Geometric:

1. Molecular Property Prediction

Predicting molecular properties is crucial for drug discovery. We can use GNNs to analyze molecular graphs:

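```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

# Load the dataset (ESOL - water solubility)
dataset = MoleculeNet(root='/tmp/ESOL', name='ESOL')

# Shuffle before splitting so the split does not follow the file order
dataset = dataset.shuffle()
train_dataset = dataset[:1000]
test_dataset = dataset[1000:]

# Create data loaders (each mini-batch merges 32 molecules into one graph)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Model for regression task
class MoleculeGNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.lin = torch.nn.Linear(64, 1)  # For regression output

    def forward(self, x, edge_index, batch):
        # MoleculeNet stores atom features as integer codes; GCNConv needs floats
        x = x.float()
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        # Global mean pooling (nodes -> graph)
        x = global_mean_pool(x, batch)

        # Regression head
        x = self.lin(x)

        return x

# One training epoch with mean squared error loss
model = MoleculeGNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train_epoch():
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = F.mse_loss(out, batch.y)  # both have shape [num_graphs, 1]
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    return total_loss / len(train_loader.dataset)
```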
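The batch argument that MoleculeGNN.forward receives is produced by the DataLoader, which merges the molecules of a mini-batch into one large graph with disconnected components: node features are stacked, each graph's edge indices are shifted by the number of nodes that precede it, and a batch vector records which graph each node came from. A dependency-free sketch of that collation step (collate_graphs is a hypothetical name; PyG's implementation works on tensors and handles many more attributes):

```python
def collate_graphs(graphs):
    """Merge small graphs into one disconnected graph (hypothetical helper).

    graphs: list of (num_nodes, [sources, targets]) pairs.
    Returns (edge_index, batch) for the merged graph.
    """
    sources, targets, batch = [], [], []
    offset = 0
    for graph_id, (num_nodes, (src, dst)) in enumerate(graphs):
        # Shift indices so they point at this graph's nodes inside the merged graph
        sources.extend(s + offset for s in src)
        targets.extend(t + offset for t in dst)
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return [sources, targets], batch


# A 3-node path and a 2-node graph, both with edges stored in both directions
graph_a = (3, [[0, 1, 1, 2], [1, 0, 2, 1]])
graph_b = (2, [[0, 1], [1, 0]])
edge_index, batch = collate_graphs([graph_a, graph_b])
print(edge_index)  # [[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]]
print(batch)       # [0, 0, 0, 1, 1]
```

Because no edge crosses between components, message passing never mixes information from different molecules, and global_mean_pool then uses the batch vector to turn node embeddings back into one vector per molecule.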

2. Social Network Analysis

Here's how you can use PyG to assign the members of a social network to communities:

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torch_geometric.datasets import KarateClub
from torch_geometric.nn import GCNConv

# Load the famous Zachary's Karate Club dataset
dataset = KarateClub()
data = dataset[0]

# Define a GNN model
class CommunityDetector(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Initialize and train model
# PyG's KarateClub labels every node with one of 4 communities (dataset.num_classes == 4)
model = CommunityDetector(dataset.num_features, 16, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    # Supervised on every node's label; data.train_mask (one node per class)
    # can be used instead for a semi-supervised variant
    loss = F.cross_entropy(out, data.y)
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(100):
    loss = train()
    if epoch % 10 == 0:
        print(f'Epoch: {epoch}, Loss: {loss:.4f}')

# Get node embeddings (the per-class output scores)
model.eval()
with torch.no_grad():
    node_embeddings = model(data.x, data.edge_index)

# Visualize with t-SNE (the default perplexity of 30 is below the 34 nodes, as required)
tsne = TSNE(n_components=2, random_state=42)
node_embeddings_2d = tsne.fit_transform(node_embeddings.cpu().numpy())

plt.figure(figsize=(10, 8))
scatter = plt.scatter(
    node_embeddings_2d[:, 0], node_embeddings_2d[:, 1],
    c=data.y.cpu().numpy(), cmap='Set1', s=100, alpha=0.8
)
plt.colorbar(scatter)
plt.title("Community Visualization in Zachary's Karate Club")
plt.show()
```

Summary

PyTorch Geometric provides a powerful set of tools for working with graph-structured data:

  1. Data Structures: Specialized classes like Data and HeteroData for efficient graph representation
  2. Graph Neural Networks: Pre-implemented layers like GCNConv, SAGEConv (GraphSAGE), GATConv, and more
  3. Dataset Handling: Built-in datasets and utilities for processing and transforming graphs
  4. Pooling Operations: Various methods to aggregate node information into graph-level representations
  5. Advanced Features: Support for heterogeneous graphs, batching, and sampling techniques

With these capabilities, PyG enables you to solve complex problems involving relational data, from molecule property prediction to social network analysis.

Exercises

  1. Basic: Create a simple graph with 5 nodes and 6 edges using PyG's Data class. Visualize it using networkx.

  2. Intermediate: Implement a Graph Attention Network (GAT) for node classification on the Cora dataset.

  3. Advanced: Create a link prediction model for a social network that predicts potential friendships between users.

  4. Challenge: Design and implement a graph neural network for traffic prediction on a road network, where nodes are intersections and edges are road segments.

Happy graph learning with PyTorch Geometric!


