[1]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
1.11.0
3. Graph Classification with Graph Neural Networks
In this tutorial session we will take a closer look at how to apply Graph Neural Networks (GNNs) to the task of graph classification. Graph classification refers to the problem of classifying entire graphs (in contrast to individual nodes), given a dataset of graphs, based on some structural graph properties. Here, we want to embed entire graphs, and we want to embed them in such a way that they are linearly separable for the task at hand.
The most common task for graph classification is molecular property prediction, in which molecules are represented as graphs, and the task may be to infer whether a molecule inhibits HIV replication or not.
The TU Dortmund University has collected a wide range of graph classification datasets, known as the TUDatasets, which are also accessible via torch_geometric.datasets.TUDataset (https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.TUDataset) in PyTorch Geometric. Let’s load and inspect one of the smaller ones, the MUTAG dataset:
[2]:
import torch
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='data/TUDataset', name='MUTAG')
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset[0] # Get the first graph object.
print()
print(data)
print('=============================================================')
# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip
Dataset: MUTAG(188):
====================
Number of graphs: 188
Number of features: 7
Number of classes: 2
Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
=============================================================
Number of nodes: 17
Number of edges: 38
Average node degree: 2.24
Has isolated nodes: False
Has self-loops: False
Is undirected: True
Extracting data/TUDataset/MUTAG/MUTAG.zip
Processing...
Done!
This dataset provides 188 different graphs, and the task is to classify each graph into one out of two classes.
By inspecting the first graph object of the dataset, we can see that it comes with 17 nodes (with 7-dimensional feature vectors) and 38 edges (leading to an average node degree of 2.24). It also comes with exactly one graph label (y=[1]) and, unlike the previous datasets, provides additional 4-dimensional edge features (edge_attr=[38, 4]). However, for the sake of simplicity, we will not make use of those.
PyTorch Geometric provides some useful utilities for working with graph datasets. For example, we can shuffle the dataset and use the first 150 graphs for training, while using the remaining 38 for testing:
[3]:
torch.manual_seed(12345)
dataset = dataset.shuffle()
train_dataset = dataset[:150]
test_dataset = dataset[150:]
print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')
Number of training graphs: 150
Number of test graphs: 38
Mini-batching of graphs
Since graphs in graph classification datasets are usually small, a good idea is to batch the graphs before feeding them into a Graph Neural Network, which makes much better use of the GPU. In the image or language domain, this is typically achieved by rescaling or padding each example to a common shape, and the examples are then grouped along an additional dimension. The length of this dimension equals the number of examples grouped in a mini-batch and is typically referred to as the batch_size.
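The following is a minimal sketch of the padding approach applied to graphs, written in plain PyTorch. It is meant only to show where the wasted memory comes from. The node counts 17, 12 and 28 are hypothetical (17 matches the first MUTAG graph), and the feature values are random.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical node-feature matrices for three graphs with 7 features per node.
num_nodes = [17, 12, 28]
xs = [torch.randn(n, 7) for n in num_nodes]

# "Image-style" batching: pad every graph to the largest node count and stack
# along a new leading dimension -> shape [batch_size, max_nodes, features].
padded = pad_sequence(xs, batch_first=True)

# A boolean mask is then needed to tell real nodes apart from padding.
max_nodes = max(num_nodes)
mask = torch.arange(max_nodes).unsqueeze(0) < torch.tensor(num_nodes).unsqueeze(1)

# "Graph-style" batching: simply concatenate along the node dimension.
concatenated = torch.cat(xs, dim=0)

print(padded.shape)        # torch.Size([3, 28, 7])
print(concatenated.shape)  # torch.Size([57, 7])
print(int((~mask).sum()))  # 27 padded (wasted) node slots
```

If dense adjacency matrices were padded in the same way, each graph would need a 28 x 28 block, so the waste grows quadratically with the size of the largest graph in the batch.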
However, for GNNs the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption. Therefore, PyTorch Geometric opts for a different approach to achieve parallelization across a number of examples: adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension.
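For a mini-batch of graphs \(1, \ldots, n\) with adjacency matrices \(\mathbf{A}_i\), node feature matrices \(\mathbf{X}_i\) and targets \(\mathbf{Y}_i\), the stacking can be written as follows. The symbols are chosen here for illustration, and all off-diagonal blocks of \(\mathbf{A}\) are zero:

```latex
\mathbf{A} = \begin{bmatrix} \mathbf{A}_1 & & \\ & \ddots & \\ & & \mathbf{A}_n \end{bmatrix}, \qquad
\mathbf{X} = \begin{bmatrix} \mathbf{X}_1 \\ \vdots \\ \mathbf{X}_n \end{bmatrix}, \qquad
\mathbf{Y} = \begin{bmatrix} \mathbf{Y}_1 \\ \vdots \\ \mathbf{Y}_n \end{bmatrix}
```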
This procedure has some crucial advantages over other batching procedures:
GNN operators that rely on a message passing scheme do not need to be modified since messages are not exchanged between two nodes that belong to different graphs.
There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries, i.e., the edges.
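Both advantages can be checked by hand. The sketch below batches two hypothetical toy graphs in plain PyTorch by shifting the node indices of the second graph, and then confirms that the dense view of the result equals the block-diagonal stacking. The helper `dense_adjacency` is a hypothetical name introduced for illustration. The index shifting mirrors, in simplified form, what PyTorch Geometric does when it collates graphs.

```python
import torch

def dense_adjacency(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Hypothetical helper: dense adjacency matrix from a [2, num_edges] edge index."""
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    return adj

# Two toy undirected graphs (every edge is stored in both directions).
edge_index_1 = torch.tensor([[0, 1, 1, 2],
                             [1, 0, 2, 1]])  # path 0-1-2, 3 nodes
edge_index_2 = torch.tensor([[0, 1],
                             [1, 0]])        # single edge 0-1, 2 nodes
num_nodes_1, num_nodes_2 = 3, 2

# Shift the second graph's node indices by the number of preceding nodes,
# then concatenate along the edge dimension.
edge_index = torch.cat([edge_index_1, edge_index_2 + num_nodes_1], dim=1)
num_nodes = num_nodes_1 + num_nodes_2

# The batched graph's dense view is exactly the block-diagonal stacking, so no
# edge (and therefore no message) connects nodes of different graphs.
A = dense_adjacency(edge_index, num_nodes)
A_blocks = torch.block_diag(dense_adjacency(edge_index_1, num_nodes_1),
                            dense_adjacency(edge_index_2, num_nodes_2))
print(torch.equal(A, A_blocks))  # True
print(edge_index.tolist())       # [[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]]
```

Only the 6 stored edges are needed to represent the batch. The dense 5 x 5 matrix exists here only to make the block structure visible.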
PyTorch Geometric automatically takes care of batching multiple graphs into a single giant graph with the help of the torch_geometric.loader.DataLoader class (formerly torch_geometric.data.DataLoader, https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.DataLoader):
[9]:
from torch_geometric.loader import DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()
Step 1:
=======
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 2682], x=[1211, 7], edge_attr=[2682, 4], y=[64], batch=[1211], ptr=[65])
Step 2:
=======
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 2484], x=[1128, 7], edge_attr=[2484, 4], y=[64], batch=[1128], ptr=[65])
Step 3:
=======
Number of graphs in the current batch: 22
DataBatch(edge_index=[2, 828], x=[375, 7], edge_attr=[828, 4], y=[22], batch=[375], ptr=[23])
Here, we opt for a batch_size of 64, leading to 3 (randomly shuffled) mini-batches, containing all \(2 \cdot 64+22 = 150\) graphs.
Furthermore, each Batch object is equipped with a batch vector, which maps each node to the index of its respective graph in the batch, i.e., \(\mathrm{batch} = [0, \ldots, 0, 1, \ldots, 1, 2, \ldots]\).
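Given only the number of nodes per graph, the batch vector can be reproduced in plain PyTorch. The sketch below uses three hypothetical graphs with 3, 2 and 4 nodes. It also builds the matching offset vector, which corresponds to the ptr entry (num_graphs + 1 values, e.g. ptr=[65] for 64 graphs) shown in the DataBatch printouts.

```python
import torch

# Node counts of three hypothetical graphs in one mini-batch.
num_nodes_per_graph = torch.tensor([3, 2, 4])
num_graphs = num_nodes_per_graph.numel()

# Graph index i is repeated once for every node of graph i.
batch = torch.repeat_interleave(torch.arange(num_graphs), num_nodes_per_graph)

# Start offset of each graph's nodes, followed by the total node count.
ptr = torch.cat([torch.zeros(1, dtype=torch.long), num_nodes_per_graph.cumsum(dim=0)])

print(batch.tolist())  # [0, 0, 0, 1, 1, 2, 2, 2, 2]
print(ptr.tolist())    # [0, 3, 5, 9]
```

The nodes of graph i occupy the slice ptr[i]:ptr[i + 1] of the concatenated node feature matrix.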
Training a Graph Neural Network (GNN)
Training a GNN for graph classification usually follows a simple recipe:
Embed each node by performing multiple rounds of message passing
Aggregate node embeddings into a unified graph embedding (readout layer)
Train a final classifier on the graph embedding
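There exist multiple readout layers in the literature, but the most common one is to simply take the average of the node embeddings: \(\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathbf{x}^{(L)}_v\), where \(\mathcal{V}\) is the node set of graph \(\mathcal{G}\) and \(\mathbf{x}^{(L)}_v\) is the embedding of node \(v\) after the last of \(L\) message passing layers.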
PyTorch Geometric provides this functionality via torch_geometric.nn.global_mean_pool (https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.global_mean_pool). It takes the node embeddings of all nodes in the mini-batch together with the assignment vector batch, and returns a tensor of shape [batch_size, hidden_channels] that holds one embedding per graph in the batch.
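To see what the mean readout computes, here is an illustrative re-implementation in plain PyTorch. The function name `mean_pool` is hypothetical and is not part of PyTorch Geometric. If PyTorch Geometric is installed, global_mean_pool(x, batch) should return the same values for these inputs.

```python
import torch

def mean_pool(x: torch.Tensor, batch: torch.Tensor, num_graphs: int) -> torch.Tensor:
    """Hypothetical per-graph average of node embeddings -> [num_graphs, channels]."""
    # Sum the embeddings of all nodes that share the same graph index.
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype).index_add_(0, batch, x)
    # Divide by the number of nodes per graph (clamped to avoid division by zero).
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype)
    return sums / counts.unsqueeze(1)

# Five nodes with 2-dimensional embeddings: nodes 0-2 form graph 0, nodes 3-4 graph 1.
x = torch.tensor([[1., 0.], [2., 0.], [3., 3.], [4., 2.], [6., 4.]])
batch = torch.tensor([0, 0, 0, 1, 1])

out = mean_pool(x, batch, num_graphs=2)
print(out.tolist())  # [[2.0, 1.0], [5.0, 3.0]]
```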
The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training:
[11]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        return x

model = GCN(hidden_channels=64)
print(model)
GCN(
(conv1): GCNConv(7, 64)
(conv2): GCNConv(64, 64)
(conv3): GCNConv(64, 64)
(lin): Linear(in_features=64, out_features=2, bias=True)
)
Here, we again make use of GCNConv (https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv) with the \(\mathrm{ReLU}(x) = \max(x, 0)\) activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer.
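As a rough intuition for what a single GCNConv layer computes, the sketch below applies the symmetric-normalized propagation rule \(\mathbf{X}' = \hat{\mathbf{D}}^{-1/2} \hat{\mathbf{A}} \hat{\mathbf{D}}^{-1/2} \mathbf{X} \mathbf{W}\) with \(\hat{\mathbf{A}} = \mathbf{A} + \mathbf{I}\) on dense matrices. It is a simplified illustration under stated assumptions: it has no bias and uses a dense adjacency matrix. The function name `gcn_layer_dense` is hypothetical.

```python
import torch

def gcn_layer_dense(x: torch.Tensor, adj: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical dense GCN step: D^-1/2 (A + I) D^-1/2 X W, without bias."""
    a_hat = adj + torch.eye(adj.size(0))  # add self-loops
    deg = a_hat.sum(dim=1)                # node degrees including the self-loop
    d_inv_sqrt = deg.pow(-0.5)
    # Scale entry (i, j) by 1 / sqrt(deg_i * deg_j).
    norm_adj = d_inv_sqrt.unsqueeze(1) * a_hat * d_inv_sqrt.unsqueeze(0)
    return norm_adj @ x @ weight

# Two nodes joined by one undirected edge, one feature each, identity weight.
adj = torch.tensor([[0., 1.], [1., 0.]])
x = torch.tensor([[1.], [3.]])
weight = torch.eye(1)

out = gcn_layer_dense(x, adj, weight)
print(out)  # both nodes end up with approximately (1 + 3) / 2 = 2.0
```

Each layer mixes a node's features only with those of its direct neighbors (and itself). With three such layers, as in the model above, every node embedding therefore depends on its 3-hop neighborhood.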
Let’s train our network for a few epochs to see how well it performs on the training as well as test set:
[12]:
model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

@torch.no_grad()  # Gradients are not needed for evaluation.
def test(loader):
    model.eval()

    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.

for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
Epoch: 001, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 002, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 003, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 004, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 005, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 006, Train Acc: 0.6533, Test Acc: 0.7368
Epoch: 007, Train Acc: 0.7133, Test Acc: 0.7632
Epoch: 008, Train Acc: 0.6867, Test Acc: 0.7632
Epoch: 009, Train Acc: 0.7267, Test Acc: 0.7632
Epoch: 010, Train Acc: 0.7200, Test Acc: 0.7895
Epoch: 011, Train Acc: 0.7267, Test Acc: 0.7632
Epoch: 012, Train Acc: 0.7133, Test Acc: 0.7895
Epoch: 013, Train Acc: 0.7200, Test Acc: 0.7895
Epoch: 014, Train Acc: 0.7200, Test Acc: 0.7895
Epoch: 015, Train Acc: 0.7333, Test Acc: 0.7632
Epoch: 016, Train Acc: 0.7333, Test Acc: 0.7895
Epoch: 017, Train Acc: 0.7200, Test Acc: 0.7895
Epoch: 018, Train Acc: 0.7467, Test Acc: 0.7368
Epoch: 019, Train Acc: 0.7467, Test Acc: 0.7368
Epoch: 020, Train Acc: 0.7200, Test Acc: 0.8684
Epoch: 021, Train Acc: 0.7400, Test Acc: 0.7895
Epoch: 022, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 023, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 024, Train Acc: 0.7600, Test Acc: 0.7895
Epoch: 025, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 026, Train Acc: 0.7467, Test Acc: 0.7632
Epoch: 027, Train Acc: 0.7600, Test Acc: 0.7895
Epoch: 028, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 029, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 030, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 031, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 032, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 033, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 034, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 035, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 036, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 037, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 038, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 039, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 040, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 041, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 042, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 043, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 044, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 045, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 046, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 047, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 048, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 049, Train Acc: 0.7467, Test Acc: 0.7368
Epoch: 050, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 051, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 052, Train Acc: 0.7467, Test Acc: 0.7368
Epoch: 053, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 054, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 055, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 056, Train Acc: 0.7467, Test Acc: 0.7368
Epoch: 057, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 058, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 059, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 060, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 061, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 062, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 063, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 064, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 065, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 066, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 067, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 068, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 069, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 070, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 071, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 072, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 073, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 074, Train Acc: 0.7467, Test Acc: 0.7632
Epoch: 075, Train Acc: 0.7867, Test Acc: 0.7895
Epoch: 076, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 077, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 078, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 079, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 080, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 081, Train Acc: 0.7933, Test Acc: 0.8158
Epoch: 082, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 083, Train Acc: 0.7467, Test Acc: 0.7368
Epoch: 084, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 085, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 086, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 087, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 088, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 089, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 090, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 091, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 092, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 093, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 094, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 095, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 096, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 097, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 098, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 099, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 100, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 101, Train Acc: 0.7933, Test Acc: 0.7895
Epoch: 102, Train Acc: 0.7933, Test Acc: 0.7895
Epoch: 103, Train Acc: 0.7600, Test Acc: 0.7368
Epoch: 104, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 105, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 106, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 107, Train Acc: 0.7467, Test Acc: 0.7368
Epoch: 108, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 109, Train Acc: 0.8000, Test Acc: 0.8158
Epoch: 110, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 111, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 112, Train Acc: 0.7733, Test Acc: 0.7105
Epoch: 113, Train Acc: 0.7867, Test Acc: 0.7895
Epoch: 114, Train Acc: 0.8200, Test Acc: 0.7105
Epoch: 115, Train Acc: 0.7867, Test Acc: 0.7895
Epoch: 116, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 117, Train Acc: 0.7533, Test Acc: 0.7368
Epoch: 118, Train Acc: 0.7667, Test Acc: 0.7368
Epoch: 119, Train Acc: 0.8000, Test Acc: 0.8158
Epoch: 120, Train Acc: 0.7867, Test Acc: 0.7895
Epoch: 121, Train Acc: 0.7667, Test Acc: 0.7105
Epoch: 122, Train Acc: 0.7933, Test Acc: 0.7632
Epoch: 123, Train Acc: 0.7867, Test Acc: 0.7895
Epoch: 124, Train Acc: 0.7733, Test Acc: 0.7368
Epoch: 125, Train Acc: 0.7867, Test Acc: 0.7895
Epoch: 126, Train Acc: 0.7600, Test Acc: 0.7105
Epoch: 127, Train Acc: 0.7667, Test Acc: 0.7368
Epoch: 128, Train Acc: 0.7800, Test Acc: 0.7632
Epoch: 129, Train Acc: 0.7667, Test Acc: 0.7105
Epoch: 130, Train Acc: 0.7600, Test Acc: 0.7368
Epoch: 131, Train Acc: 0.7933, Test Acc: 0.7895
Epoch: 132, Train Acc: 0.7667, Test Acc: 0.7105
Epoch: 133, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 134, Train Acc: 0.8133, Test Acc: 0.7105
Epoch: 135, Train Acc: 0.7800, Test Acc: 0.7632
Epoch: 136, Train Acc: 0.7467, Test Acc: 0.7632
Epoch: 137, Train Acc: 0.7533, Test Acc: 0.7368
Epoch: 138, Train Acc: 0.7867, Test Acc: 0.7895
Epoch: 139, Train Acc: 0.7933, Test Acc: 0.7895
Epoch: 140, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 141, Train Acc: 0.7733, Test Acc: 0.7368
Epoch: 142, Train Acc: 0.7533, Test Acc: 0.7105
Epoch: 143, Train Acc: 0.7600, Test Acc: 0.7368
Epoch: 144, Train Acc: 0.7933, Test Acc: 0.7632
Epoch: 145, Train Acc: 0.8000, Test Acc: 0.8158
Epoch: 146, Train Acc: 0.7933, Test Acc: 0.7895
Epoch: 147, Train Acc: 0.7867, Test Acc: 0.7632
Epoch: 148, Train Acc: 0.7733, Test Acc: 0.7368
Epoch: 149, Train Acc: 0.8067, Test Acc: 0.7895
Epoch: 150, Train Acc: 0.7800, Test Acc: 0.7632
Epoch: 151, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 152, Train Acc: 0.7733, Test Acc: 0.7368
Epoch: 153, Train Acc: 0.7933, Test Acc: 0.7632
Epoch: 154, Train Acc: 0.8000, Test Acc: 0.7632
Epoch: 155, Train Acc: 0.7600, Test Acc: 0.7368
Epoch: 156, Train Acc: 0.7667, Test Acc: 0.7368
Epoch: 157, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 158, Train Acc: 0.7800, Test Acc: 0.7368
Epoch: 159, Train Acc: 0.7667, Test Acc: 0.7368
Epoch: 160, Train Acc: 0.8200, Test Acc: 0.6842
Epoch: 161, Train Acc: 0.7867, Test Acc: 0.7368
Epoch: 162, Train Acc: 0.7733, Test Acc: 0.7368
Epoch: 163, Train Acc: 0.7733, Test Acc: 0.7105
Epoch: 164, Train Acc: 0.8067, Test Acc: 0.7632
Epoch: 165, Train Acc: 0.8000, Test Acc: 0.7368
Epoch: 166, Train Acc: 0.7933, Test Acc: 0.7368
Epoch: 167, Train Acc: 0.7933, Test Acc: 0.7105
Epoch: 168, Train Acc: 0.8000, Test Acc: 0.6316
Epoch: 169, Train Acc: 0.7733, Test Acc: 0.7105
Epoch: 170, Train Acc: 0.7733, Test Acc: 0.7895
As one can see, our model reaches around 76% test accuracy. The fluctuations in accuracy can be attributed to the small dataset (only 38 test graphs), and they usually diminish once GNNs are applied to larger datasets.
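A bit of arithmetic shows why the test curve is so jumpy. With 38 test graphs, every logged test accuracy is k/38 for some integer k, so a single graph changes the score by about 2.6 percentage points. The standard error below uses the usual binomial approximation, which is an assumption made here for illustration and not an analysis from the tutorial.

```python
import math

num_test = 38    # size of the test split used above
num_train = 150  # size of the training split used above

# Logged test accuracies are multiples of 1/38.
for k in (24, 28, 29, 30, 31, 33):
    print(f'{k}/{num_test} = {k / num_test:.4f}')
print(f'one graph = {1 / num_test:.4f}')

# Rough binomial standard error of an accuracy p measured on n graphs.
def accuracy_std_error(p: float, n: int) -> float:
    return math.sqrt(p * (1.0 - p) / n)

print(f'std. error at p=0.76, n=38: {accuracy_std_error(0.76, num_test):.3f}')

# The accuracies logged in the first epochs (0.6467 train, 0.7368 test).
print(f'97/{num_train} = {97 / num_train:.4f}, 28/{num_test} = {28 / num_test:.4f}')
```

The printed fractions (0.6316, 0.7368, 0.7632, 0.7895, 0.8158, 0.8684) are exactly the values that appear in the log. A standard error of roughly 7 percentage points means that differences of a few test graphs between epochs, or between models, should be read with caution. The values from the first epochs would also be consistent with an untrained model that predicts the same class for every graph.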
(Optional) Exercise
Can we do better than this? As multiple papers have pointed out (Xu et al. (2018), Morris et al. (2018)), applying neighborhood normalization decreases the expressivity of GNNs in distinguishing certain graph structures. An alternative formulation (Morris et al. (2018)) omits neighborhood normalization completely and adds a simple skip-connection to the GNN layer in order to preserve central node information: \(\mathbf{x}_v^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_v^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{w \in \mathcal{N}(v)} \mathbf{x}_w^{(\ell)}\).
This layer is implemented under the name GraphConv (https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GraphConv) in PyTorch Geometric.
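To make the update rule tangible, here is a hypothetical plain-PyTorch layer that follows it: a linear map of the central node plus a linear map of the unnormalized sum over its neighbors. The class name `SumSkipLayer` is made up for illustration. It is a sketch and not PyG's implementation, which offers further options such as a different aggregation scheme and edge weights. This sketch also adds a bias term to the neighbor branch.

```python
import torch
from torch import nn

class SumSkipLayer(nn.Module):
    """Hypothetical layer: x_v' = W1 x_v + W2 * sum_{w in N(v)} x_w (+ bias)."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin_root = nn.Linear(in_channels, out_channels, bias=False)  # W1, central node
        self.lin_neigh = nn.Linear(in_channels, out_channels)             # W2, neighbors

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index
        # Unnormalized sum of neighbor features; messages flow from src to dst.
        aggr = torch.zeros_like(x).index_add_(0, dst, x[src])
        return self.lin_root(x) + self.lin_neigh(aggr)

# Set both weights to 1 and the bias to 0, so the output is x_v + sum of neighbors.
layer = SumSkipLayer(1, 1)
with torch.no_grad():
    layer.lin_root.weight.fill_(1.0)
    layer.lin_neigh.weight.fill_(1.0)
    layer.lin_neigh.bias.zero_()

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # undirected path 0-1-2
x = torch.tensor([[1.], [2.], [4.]])

out = layer(x, edge_index)
print(out.tolist())  # [[3.0], [7.0], [6.0]]
```

Because the neighbor term is a plain sum, a node with more neighbors receives a larger aggregate, so degree information stays visible to the model. A degree-normalized aggregation such as GCN's would partly hide it.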
As an exercise, you are invited to complete the following code so that it makes use of PyG’s GraphConv rather than GCNConv. This should bring you close to 82% test accuracy.
[13]:
from torch_geometric.nn import GraphConv
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GNN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GraphConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings (no neighborhood normalization, with skip-connection)
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        return x

model = GNN(hidden_channels=64)
print(model)
GNN(
(conv1): GraphConv(7, 64)
(conv2): GraphConv(64, 64)
(conv3): GraphConv(64, 64)
(lin): Linear(in_features=64, out_features=2, bias=True)
)
[14]:
model = GNN(hidden_channels=64)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(1, 201):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
GNN(
(conv1): GraphConv(7, 64)
(conv2): GraphConv(64, 64)
(conv3): GraphConv(64, 64)
(lin): Linear(in_features=64, out_features=2, bias=True)
)
Epoch: 001, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 002, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 003, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 004, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 005, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 006, Train Acc: 0.6800, Test Acc: 0.7632
Epoch: 007, Train Acc: 0.7667, Test Acc: 0.8158
Epoch: 008, Train Acc: 0.7200, Test Acc: 0.8158
Epoch: 009, Train Acc: 0.7533, Test Acc: 0.8158
Epoch: 010, Train Acc: 0.7933, Test Acc: 0.8421
Epoch: 011, Train Acc: 0.7533, Test Acc: 0.8158
Epoch: 012, Train Acc: 0.7667, Test Acc: 0.8684
Epoch: 013, Train Acc: 0.7733, Test Acc: 0.8421
Epoch: 014, Train Acc: 0.8000, Test Acc: 0.8684
Epoch: 015, Train Acc: 0.7867, Test Acc: 0.7632
Epoch: 016, Train Acc: 0.7800, Test Acc: 0.8421
Epoch: 017, Train Acc: 0.8067, Test Acc: 0.8421
Epoch: 018, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 019, Train Acc: 0.7933, Test Acc: 0.8158
Epoch: 020, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 021, Train Acc: 0.8000, Test Acc: 0.8158
Epoch: 022, Train Acc: 0.8067, Test Acc: 0.8158
Epoch: 023, Train Acc: 0.7867, Test Acc: 0.7632
Epoch: 024, Train Acc: 0.8067, Test Acc: 0.8158
Epoch: 025, Train Acc: 0.8067, Test Acc: 0.8158
Epoch: 026, Train Acc: 0.8000, Test Acc: 0.7368
Epoch: 027, Train Acc: 0.8533, Test Acc: 0.7895
Epoch: 028, Train Acc: 0.8533, Test Acc: 0.7895
Epoch: 029, Train Acc: 0.8200, Test Acc: 0.8158
Epoch: 030, Train Acc: 0.8533, Test Acc: 0.7895
Epoch: 031, Train Acc: 0.8600, Test Acc: 0.7895
Epoch: 032, Train Acc: 0.8200, Test Acc: 0.7895
Epoch: 033, Train Acc: 0.8533, Test Acc: 0.7895
Epoch: 034, Train Acc: 0.8200, Test Acc: 0.7895
Epoch: 035, Train Acc: 0.8733, Test Acc: 0.7632
Epoch: 036, Train Acc: 0.8867, Test Acc: 0.7895
Epoch: 037, Train Acc: 0.8267, Test Acc: 0.8421
Epoch: 038, Train Acc: 0.8667, Test Acc: 0.7368
Epoch: 039, Train Acc: 0.8200, Test Acc: 0.8421
Epoch: 040, Train Acc: 0.8533, Test Acc: 0.7895
Epoch: 041, Train Acc: 0.8667, Test Acc: 0.8158
Epoch: 042, Train Acc: 0.8267, Test Acc: 0.7895
Epoch: 043, Train Acc: 0.8800, Test Acc: 0.7632
Epoch: 044, Train Acc: 0.9000, Test Acc: 0.7895
Epoch: 045, Train Acc: 0.8200, Test Acc: 0.8684
Epoch: 046, Train Acc: 0.9000, Test Acc: 0.7632
Epoch: 047, Train Acc: 0.8933, Test Acc: 0.7895
Epoch: 048, Train Acc: 0.9067, Test Acc: 0.7895
Epoch: 049, Train Acc: 0.9000, Test Acc: 0.7895
Epoch: 050, Train Acc: 0.8933, Test Acc: 0.7895
Epoch: 051, Train Acc: 0.8867, Test Acc: 0.7632
Epoch: 052, Train Acc: 0.9067, Test Acc: 0.8158
Epoch: 053, Train Acc: 0.9133, Test Acc: 0.8158
Epoch: 054, Train Acc: 0.9200, Test Acc: 0.7895
Epoch: 055, Train Acc: 0.9000, Test Acc: 0.8421
Epoch: 056, Train Acc: 0.9133, Test Acc: 0.7895
Epoch: 057, Train Acc: 0.8800, Test Acc: 0.7632
Epoch: 058, Train Acc: 0.9067, Test Acc: 0.8158
Epoch: 059, Train Acc: 0.9133, Test Acc: 0.8158
Epoch: 060, Train Acc: 0.9200, Test Acc: 0.7632
Epoch: 061, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 062, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 063, Train Acc: 0.9333, Test Acc: 0.8421
Epoch: 064, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 065, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 066, Train Acc: 0.9200, Test Acc: 0.8158
Epoch: 067, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 068, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 069, Train Acc: 0.9200, Test Acc: 0.7895
Epoch: 070, Train Acc: 0.9200, Test Acc: 0.8421
Epoch: 071, Train Acc: 0.9200, Test Acc: 0.8421
Epoch: 072, Train Acc: 0.9533, Test Acc: 0.8421
Epoch: 073, Train Acc: 0.9267, Test Acc: 0.8158
Epoch: 074, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 075, Train Acc: 0.9400, Test Acc: 0.8684
Epoch: 076, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 077, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 078, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 079, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 080, Train Acc: 0.9333, Test Acc: 0.8421
Epoch: 081, Train Acc: 0.9333, Test Acc: 0.8421
Epoch: 082, Train Acc: 0.9333, Test Acc: 0.8421
Epoch: 083, Train Acc: 0.9467, Test Acc: 0.8421
Epoch: 084, Train Acc: 0.9400, Test Acc: 0.8421
Epoch: 085, Train Acc: 0.9533, Test Acc: 0.8158
Epoch: 086, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 087, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 088, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 089, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 090, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 091, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 092, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 093, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 094, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 095, Train Acc: 0.9267, Test Acc: 0.7895
Epoch: 096, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 097, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 098, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 099, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 100, Train Acc: 0.9467, Test Acc: 0.8421
Epoch: 101, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 102, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 103, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 104, Train Acc: 0.9067, Test Acc: 0.7368
Epoch: 105, Train Acc: 0.9200, Test Acc: 0.7632
Epoch: 106, Train Acc: 0.9467, Test Acc: 0.7632
Epoch: 107, Train Acc: 0.9200, Test Acc: 0.7632
Epoch: 108, Train Acc: 0.9133, Test Acc: 0.8158
Epoch: 109, Train Acc: 0.9333, Test Acc: 0.7632
Epoch: 110, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 111, Train Acc: 0.9533, Test Acc: 0.7895
Epoch: 112, Train Acc: 0.9333, Test Acc: 0.7632
Epoch: 113, Train Acc: 0.9200, Test Acc: 0.7632
Epoch: 114, Train Acc: 0.9400, Test Acc: 0.7632
Epoch: 115, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 116, Train Acc: 0.9333, Test Acc: 0.7895
Epoch: 117, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 118, Train Acc: 0.9467, Test Acc: 0.8421
Epoch: 119, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 120, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 121, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 122, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 123, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 124, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 125, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 126, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 127, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 128, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 129, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 130, Train Acc: 0.9467, Test Acc: 0.8421
Epoch: 131, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 132, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 133, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 134, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 135, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 136, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 137, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 138, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 139, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 140, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 141, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 142, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 143, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 144, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 145, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 146, Train Acc: 0.9333, Test Acc: 0.7895
Epoch: 147, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 148, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 149, Train Acc: 0.9267, Test Acc: 0.8158
Epoch: 150, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 151, Train Acc: 0.9400, Test Acc: 0.7632
Epoch: 152, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 153, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 154, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 155, Train Acc: 0.9467, Test Acc: 0.8421
Epoch: 156, Train Acc: 0.9467, Test Acc: 0.8421
Epoch: 157, Train Acc: 0.9467, Test Acc: 0.8421
Epoch: 158, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 159, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 160, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 161, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 162, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 163, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 164, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 165, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 166, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 167, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 168, Train Acc: 0.9400, Test Acc: 0.8421
Epoch: 169, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 170, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 171, Train Acc: 0.9333, Test Acc: 0.7368
Epoch: 172, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 173, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 174, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 175, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 176, Train Acc: 0.9333, Test Acc: 0.7895
Epoch: 177, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 178, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 179, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 180, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 181, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 182, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 183, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 184, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 185, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 186, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 187, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 188, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 189, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 190, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 191, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 192, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 193, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 194, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 195, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 196, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 197, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 198, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 199, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 200, Train Acc: 0.9467, Test Acc: 0.8421
Conclusion
In this chapter, you have learned how to apply GNNs to the task of graph classification. You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings.
In the next session, you will learn how you can utilize PyTorch Geometric to let Graph Neural Networks scale to single large graphs.