[1]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
1.11.0

4. Scaling Graph Neural Networks

So far, we have trained Graph Neural Networks for node classification tasks solely in a full-batch fashion. In particular, that means that every node’s hidden representation was computed in parallel and was available to re-use in the next layer.

However, once we want to operate on bigger graphs, this scheme is no longer feasible since memory consumption explodes. For example, a graph with around 10 million nodes and a hidden feature dimensionality of 128 already consumes about 5GB of GPU memory for each layer.
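
The 5GB estimate can be reproduced with back-of-the-envelope arithmetic. The sketch below assumes hidden features are stored as 32-bit floats (4 bytes each), which the text above does not state explicitly, and it counts only the hidden representations of a single layer, not gradients or other buffers.

```python
# Rough memory estimate for the hidden node representations of one layer.
# Assumption: each feature value is a 32-bit float (4 bytes).
num_nodes = 10_000_000
hidden_dim = 128
bytes_per_float = 4

layer_bytes = num_nodes * hidden_dim * bytes_per_float
layer_gb = layer_bytes / 1e9       # decimal gigabytes
layer_gib = layer_bytes / 2**30    # binary gibibytes

print(f'{layer_gb:.2f} GB ({layer_gib:.2f} GiB) per layer')
```

Since every layer keeps its own activations around for the backward pass, this cost is paid once per layer in full-batch training.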

Hence, there has recently been some effort to let Graph Neural Networks scale to bigger graphs. One of those approaches is known as Cluster-GCN (Chiang et al. (2019)), which is based on pre-partitioning the graph into subgraphs on which one can operate in a mini-batch fashion.

To showcase, let’s load the PubMed graph from the Planetoid node classification benchmark suite (Yang et al. (2016)):

[2]:
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

dataset = Planetoid(root='data/Planetoid', name='PubMed', transform=NormalizeFeatures())

print()
print(f'Dataset: {dataset}:')
print('==================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('===============================================================================================================')

# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.3f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.test.index

Dataset: PubMed():
==================
Number of graphs: 1
Number of features: 500
Number of classes: 3

Data(x=[19717, 500], edge_index=[2, 88648], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717])
===============================================================================================================
Number of nodes: 19717
Number of edges: 88648
Average node degree: 4.50
Number of training nodes: 60
Training node label rate: 0.003
Has isolated nodes: False
Has self-loops: False
Is undirected: True
Processing...
Done!

As can be seen, this graph has 19,717 nodes. While this number of nodes should fit into GPU memory with ease, it’s nonetheless a good example to showcase how one can scale GNNs up within PyTorch Geometric.

Cluster-GCN (Chiang et al. (2019)) works by first partitioning the graph into subgraphs using a graph partitioning algorithm. GNNs are then restricted to convolve only inside their specific subgraphs, which avoids the problem of neighborhood explosion.

[Figure: a graph partitioned into subgraphs]
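
To make the effect of partitioning concrete, here is a minimal plain-Python sketch. It is not how PyG implements partitioning: the toy graph, the hand-written cluster assignment, and the helper `edges_within` are all hypothetical and only illustrate that a subgraph keeps the edges whose two endpoints fall into the same cluster.

```python
# Toy undirected path graph 0-1-2-3-4-5, stored as directed pairs in both directions.
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2),
         (3, 4), (4, 3), (4, 5), (5, 4)]

# Hypothetical cluster assignment: nodes 0-2 form cluster 0, nodes 3-5 form cluster 1.
cluster = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}

def edges_within(edges, cluster, part):
    """Return the edges whose source and target both belong to cluster `part`."""
    return [(u, v) for u, v in edges if cluster[u] == part and cluster[v] == part]

sub0 = edges_within(edges, cluster, 0)
sub1 = edges_within(edges, cluster, 1)
# Edges that cross a cluster boundary belong to neither subgraph.
dropped = [(u, v) for u, v in edges if cluster[u] != cluster[v]]

print(sub0)     # edges kept in cluster 0
print(sub1)     # edges kept in cluster 1
print(dropped)  # between-cluster edges lost by the partitioning
```

A GNN that convolves only inside `sub0` or `sub1` never needs features of nodes outside its cluster, at the price of ignoring the link between nodes 2 and 3.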

However, after the graph is partitioned, some links are removed, which may limit the model’s performance due to biased estimation. To address this issue, Cluster-GCN also incorporates between-cluster links inside a mini-batch, which results in the following stochastic partitioning scheme:

[Figure: the stochastic partitioning scheme, with colors marking the adjacency information kept in each batch]

Here, colors represent the adjacency information that is maintained per batch (which is potentially different for every epoch).
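
The stochastic scheme can be sketched in plain Python as well. The example below is an illustration under simplifying assumptions, not PyG's implementation: the toy graph, the cluster assignment, and the helper `batch_edges` are hypothetical. It shows that when several clusters are merged into one mini-batch, the links between those clusters are kept again, and that reshuffling the clusters changes which links are kept.

```python
import random

# Toy path graph 0-1-2-...-7, stored as directed edge pairs in both directions.
edges = []
for u in range(7):
    edges += [(u, u + 1), (u + 1, u)]

# Hypothetical partitioning into 4 clusters of 2 consecutive nodes each.
cluster = {node: node // 2 for node in range(8)}

def batch_edges(edges, cluster, parts):
    """Keep every edge whose endpoints both lie in one of the chosen clusters."""
    parts = set(parts)
    return [(u, v) for u, v in edges if cluster[u] in parts and cluster[v] in parts]

# Adjacent clusters 0 and 1: the link between nodes 1 and 2 is kept.
adjacent = batch_edges(edges, cluster, [0, 1])
# Non-adjacent clusters 0 and 2: only within-cluster edges remain.
separate = batch_edges(edges, cluster, [0, 2])
print(len(adjacent), len(separate))

# Each epoch, shuffle the clusters and group them into batches of 2, so the
# retained between-cluster links can differ from epoch to epoch.
rng = random.Random(0)
for epoch in range(2):
    order = list(range(4))
    rng.shuffle(order)
    batches = [order[i:i + 2] for i in range(0, len(order), 2)]
    print(epoch, batches)
```

Every cluster still appears in exactly one batch per epoch, so each node is visited once, but its visible neighborhood depends on which clusters it is grouped with.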

PyTorch Geometric provides a two-stage implementation of the Cluster-GCN algorithm:

1. `ClusterData <https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.ClusterData>`__ converts a ``Data`` object into a dataset of subgraphs containing ``num_parts`` partitions.
2. Given a user-defined ``batch_size``, `ClusterLoader <https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.ClusterLoader>`__ implements the stochastic partitioning scheme in order to create mini-batches.

The procedure to craft mini-batches then looks as follows:

[3]:
from torch_geometric.loader import ClusterData, ClusterLoader

torch.manual_seed(12345)

# 1. Create subgraphs.
cluster_data = ClusterData(data, num_parts=128)

# 2. Stochastic partitioning scheme.
train_loader = ClusterLoader(cluster_data, batch_size=32, shuffle=True)

print()
total_num_nodes = 0
for step, sub_data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of nodes in the current batch: {sub_data.num_nodes}')
    print(sub_data)
    print()
    total_num_nodes += sub_data.num_nodes

print(f'Iterated over {total_num_nodes} of {data.num_nodes} nodes!')

Step 1:
=======
Number of nodes in the current batch: 4928
Data(x=[4928, 500], y=[4928], train_mask=[4928], val_mask=[4928], test_mask=[4928], edge_index=[2, 16174])

Step 2:
=======
Number of nodes in the current batch: 4937
Data(x=[4937, 500], y=[4937], train_mask=[4937], val_mask=[4937], test_mask=[4937], edge_index=[2, 17832])

Step 3:
=======
Number of nodes in the current batch: 4927
Data(x=[4927, 500], y=[4927], train_mask=[4927], val_mask=[4927], test_mask=[4927], edge_index=[2, 14712])

Step 4:
=======
Number of nodes in the current batch: 4925
Data(x=[4925, 500], y=[4925], train_mask=[4925], val_mask=[4925], test_mask=[4925], edge_index=[2, 18006])

Iterated over 19717 of 19717 nodes!
Computing METIS partitioning...
Done!

Here, we partition the initial graph into 128 partitions, and use a ``batch_size`` of 32 subgraphs to form mini-batches (leaving us with 4 batches per epoch). As one can see, after a single epoch, each node has been seen exactly once.
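
The batch count and the node coverage can be double-checked with a few lines of arithmetic. This is only a sanity check on the numbers printed above; the per-batch node counts are copied from that output.

```python
import math

num_parts = 128
batch_size = 32

# Number of mini-batches per epoch; the last batch would be smaller
# if batch_size did not divide num_parts evenly.
batches_per_epoch = math.ceil(num_parts / batch_size)

# Node counts per batch, copied from the loader output above.
batch_nodes = [4928, 4937, 4927, 4925]

print(batches_per_epoch)  # 4
print(sum(batch_nodes))   # 19717
```

The batch sizes are close to each other but not identical, since the partitions themselves are only approximately balanced in this run.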

The great thing about Cluster-GCN is that it does not complicate the GNN model implementation. Here, we can make use of exactly the same architecture introduced in the second chapter of this tutorial.

[4]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GCN(hidden_channels=16)
print(model)
GCN(
  (conv1): GCNConv(500, 16)
  (conv2): GCNConv(16, 3)
)

Training of this Graph Neural Network is then quite similar to training GNNs for the task of graph classification. Instead of operating on the graph in a full-batch fashion, we now iterate over the mini-batches and perform one optimization step per batch:

[5]:
model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()

    # Iterate over each mini-batch.
    for sub_data in train_loader:
        # Perform a single forward pass on the current subgraph.
        out = model(sub_data.x, sub_data.edge_index)
        # Compute the loss solely based on the training nodes.
        loss = criterion(out[sub_data.train_mask], sub_data.y[sub_data.train_mask])
        # Derive gradients.
        loss.backward()
        # Update parameters based on gradients.
        optimizer.step()
        # Clear gradients.
        optimizer.zero_grad()

@torch.no_grad()  # Evaluation does not need gradient tracking.
def test():
    model.eval()
    # Evaluate on the full graph rather than on mini-batches.
    out = model(data.x, data.edge_index)
    # Use the class with highest probability.
    pred = out.argmax(dim=1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        # Check against ground-truth labels.
        correct = pred[mask] == data.y[mask]
        # Derive ratio of correct predictions.
        accs.append(int(correct.sum()) / int(mask.sum()))
    return accs

for epoch in range(1, 51):
    train()  # train() updates the model in place and returns nothing.
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
Epoch: 001, Train: 0.3333, Val Acc: 0.4160, Test Acc: 0.4070
Epoch: 002, Train: 0.7167, Val Acc: 0.5720, Test Acc: 0.5770
Epoch: 003, Train: 0.6500, Val Acc: 0.5300, Test Acc: 0.5180
Epoch: 004, Train: 0.9000, Val Acc: 0.7000, Test Acc: 0.6840
Epoch: 005, Train: 0.9333, Val Acc: 0.7400, Test Acc: 0.7090
Epoch: 006, Train: 0.9333, Val Acc: 0.7440, Test Acc: 0.7260
Epoch: 007, Train: 0.9333, Val Acc: 0.7320, Test Acc: 0.7160
Epoch: 008, Train: 0.9500, Val Acc: 0.7160, Test Acc: 0.6990
Epoch: 009, Train: 0.9500, Val Acc: 0.7480, Test Acc: 0.7200
Epoch: 010, Train: 0.9500, Val Acc: 0.7420, Test Acc: 0.7130
Epoch: 011, Train: 0.9500, Val Acc: 0.7380, Test Acc: 0.7170
Epoch: 012, Train: 0.9500, Val Acc: 0.7560, Test Acc: 0.7180
Epoch: 013, Train: 0.9667, Val Acc: 0.7640, Test Acc: 0.7250
Epoch: 014, Train: 0.9667, Val Acc: 0.7600, Test Acc: 0.7390
Epoch: 015, Train: 0.9667, Val Acc: 0.7620, Test Acc: 0.7450
Epoch: 016, Train: 0.9667, Val Acc: 0.7680, Test Acc: 0.7500
Epoch: 017, Train: 0.9667, Val Acc: 0.7740, Test Acc: 0.7490
Epoch: 018, Train: 0.9667, Val Acc: 0.7760, Test Acc: 0.7490
Epoch: 019, Train: 0.9667, Val Acc: 0.7820, Test Acc: 0.7730
Epoch: 020, Train: 0.9667, Val Acc: 0.7820, Test Acc: 0.7710
Epoch: 021, Train: 0.9667, Val Acc: 0.7840, Test Acc: 0.7720
Epoch: 022, Train: 0.9667, Val Acc: 0.7880, Test Acc: 0.7730
Epoch: 023, Train: 0.9667, Val Acc: 0.7920, Test Acc: 0.7810
Epoch: 024, Train: 0.9667, Val Acc: 0.7940, Test Acc: 0.7830
Epoch: 025, Train: 0.9667, Val Acc: 0.7900, Test Acc: 0.7780
Epoch: 026, Train: 0.9667, Val Acc: 0.7940, Test Acc: 0.7810
Epoch: 027, Train: 0.9667, Val Acc: 0.7920, Test Acc: 0.7770
Epoch: 028, Train: 0.9833, Val Acc: 0.8040, Test Acc: 0.7840
Epoch: 029, Train: 0.9833, Val Acc: 0.7980, Test Acc: 0.7830
Epoch: 030, Train: 0.9833, Val Acc: 0.8060, Test Acc: 0.7860
Epoch: 031, Train: 0.9833, Val Acc: 0.8100, Test Acc: 0.7870
Epoch: 032, Train: 0.9833, Val Acc: 0.8000, Test Acc: 0.7810
Epoch: 033, Train: 0.9833, Val Acc: 0.7820, Test Acc: 0.7770
Epoch: 034, Train: 0.9833, Val Acc: 0.7840, Test Acc: 0.7770
Epoch: 035, Train: 0.9833, Val Acc: 0.8020, Test Acc: 0.7860
Epoch: 036, Train: 0.9833, Val Acc: 0.8040, Test Acc: 0.7950
Epoch: 037, Train: 0.9833, Val Acc: 0.8020, Test Acc: 0.7930
Epoch: 038, Train: 0.9833, Val Acc: 0.8040, Test Acc: 0.7930
Epoch: 039, Train: 0.9833, Val Acc: 0.8040, Test Acc: 0.7860
Epoch: 040, Train: 0.9833, Val Acc: 0.8040, Test Acc: 0.7830
Epoch: 041, Train: 0.9833, Val Acc: 0.8060, Test Acc: 0.7930
Epoch: 042, Train: 0.9833, Val Acc: 0.8000, Test Acc: 0.7870
Epoch: 043, Train: 0.9833, Val Acc: 0.7960, Test Acc: 0.7810
Epoch: 044, Train: 0.9833, Val Acc: 0.7960, Test Acc: 0.7920
Epoch: 045, Train: 0.9833, Val Acc: 0.7940, Test Acc: 0.7910
Epoch: 046, Train: 0.9833, Val Acc: 0.7860, Test Acc: 0.7830
Epoch: 047, Train: 0.9833, Val Acc: 0.7980, Test Acc: 0.7950
Epoch: 048, Train: 0.9833, Val Acc: 0.8000, Test Acc: 0.7980
Epoch: 049, Train: 0.9833, Val Acc: 0.8000, Test Acc: 0.7830
Epoch: 050, Train: 0.9833, Val Acc: 0.8000, Test Acc: 0.7970
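
In the training loop above, the loss of each mini-batch is computed only over the training nodes that happen to fall into that batch. With just 60 training nodes spread over 4 batches, a batch could in principle contain no training nodes at all, in which case a mean over zero nodes is undefined. The following plain-Python sketch illustrates a masked mean cross-entropy and one possible way to handle an empty mask. It is not PyTorch or PyG code, and `masked_cross_entropy` is a hypothetical helper written for this illustration.

```python
import math

def masked_cross_entropy(logits, labels, mask):
    """Mean cross-entropy over the rows where mask is True, or None if no row is selected."""
    losses = []
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue
        # Numerically stable log-sum-exp of the logits in this row.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        # Negative log-probability of the true class.
        losses.append(log_sum_exp - row[label])
    if not losses:
        return None  # No training node in this batch: skip the update.
    return sum(losses) / len(losses)

logits = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 5.0, 0.0]]
labels = [0, 0, 2]

# Only the first node is a training node; uniform logits over 3 classes give log(3).
loss = masked_cross_entropy(logits, labels, [True, False, False])
print(f'{loss:.4f}')

# A batch without any training node yields no loss value.
empty = masked_cross_entropy(logits, labels, [False, False, False])
print(empty)
```

In a real training loop, the equivalent guard would be to skip the backward pass and optimizer step for batches whose training mask is empty.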

Conclusion

In this chapter, you have seen a method for scaling GNNs to large graphs that would otherwise not fit into GPU memory.

This also concludes the hands-on tutorial on deep graph learning with PyTorch Geometric. If you want to learn more about GNNs or PyTorch Geometric, feel free to check out PyG’s documentation, its list of implemented methods as well as its provided examples, which cover additional topics such as link prediction, graph attention, mesh or point cloud convolutions and other methods for scaling up GNNs.

Happy hacking!