[1]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
1.12.1+cu113
Building wheel for torch-geometric (setup.py) ... done
7. Customizing Aggregations within Message Passing with torch_geometric.nn.aggr
Aggregation functions play an important role in the message passing framework and in the readout functions of GNNs. Many works in the GNN literature (Hamilton et al. (2017), Xu et al. (2018), Corso et al. (2020), Li et al. (2020)) demonstrate that the choice of aggregation function contributes significantly to the performance of GNN models. In particular, the performance of GNNs with different aggregation functions differs across tasks and datasets. Recent works also show that using multiple aggregations (Corso et al. (2020)) and learnable aggregations (Li et al. (2020)) can potentially yield substantial improvements. To facilitate experimentation with these different aggregation schemes and to unify the concept of aggregation within GNNs across both `MessagePassing <https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py>`__ and global readouts, we provide modular and re-usable aggregations in the newly defined torch_geometric.nn.aggr.* package.
Unifying these concepts also lets us perform optimizations and provide specialized implementations in a single place. With the new integration, the following usages are supported:
# Original interface with string type as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr="mean")

# Use a single aggregation module as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=MeanAggregation())

# Use a list of aggregation strings as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=['mean', 'max', 'sum', 'std', 'var'])

# Use a list of aggregation modules as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=[
            MeanAggregation(),
            MaxAggregation(),
            SumAggregation(),
            StdAggregation(),
            VarAggregation(),
        ])

# Use a list of mixed modules and strings as aggregation argument
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=[
            'mean',
            MaxAggregation(),
            'sum',
            StdAggregation(),
            'var',
        ])

# Define multiple learnable aggregations with keyword arguments
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=['softmax', 'softmax', 'softmax'],
                         aggr_kwargs=dict(aggrs_kwargs=[
                             dict(t=0.1, learn=True),
                             dict(t=1, learn=True),
                             dict(t=10, learn=True)]))

# Define multiple aggregations with `MultiAggregation` module
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=MultiAggregation([
            SoftmaxAggregation(t=0.1, learn=True),
            SoftmaxAggregation(t=1, learn=True),
            SoftmaxAggregation(t=10, learn=True)]))
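To make the list-of-aggregations case more concrete, here is a minimal pure-Python sketch of what several aggregations compute for each target node when their results are concatenated, which corresponds to the `mode=cat` shown in the printed `MultiAggregation` models later in this tutorial. It is not the PyG implementation: the names `group_by_index`, `REDUCERS` and `multi_aggregate` are hypothetical, and messages are scalars for brevity.

```python
from collections import defaultdict

def group_by_index(values, index):
    """Group message values by the target node each one is sent to."""
    groups = defaultdict(list)
    for value, target in zip(values, index):
        groups[target].append(value)
    return groups

# Hypothetical registry of simple reducers (scalar messages for brevity).
REDUCERS = {
    "mean": lambda xs: sum(xs) / len(xs),
    "max": max,
    "sum": sum,
}

def multi_aggregate(values, index, aggrs):
    """Apply every reducer per target node and concatenate the results."""
    groups = group_by_index(values, index)
    return {node: [REDUCERS[name](xs) for name in aggrs]
            for node, xs in sorted(groups.items())}

messages = [1.0, 3.0, 2.0, 4.0, 6.0]
targets = [0, 0, 1, 1, 1]  # messages 0-1 go to node 0, messages 2-4 to node 1
result = multi_aggregate(messages, targets, ["mean", "max", "sum"])
print(result)
```

Each node ends up with one value per aggregation, so concatenating k aggregations multiplies the aggregated feature width by k.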
In this tutorial, we explore the new aggregation package with SAGEConv (Hamilton et al. (2017)) and ClusterLoader (Chiang et al. (2019)), and showcase it on the PubMed graph from the Planetoid node classification benchmark suite (Yang et al. (2016)).
Loading the dataset
[2]:
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
dataset = Planetoid(root='data/Planetoid', name='PubMed', transform=NormalizeFeatures())
print()
print(f'Dataset: {dataset}:')
print('==================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset[0] # Get the first graph object.
print()
print(data)
print('===============================================================================================================')
from torch_geometric.loader import ClusterData, ClusterLoader
torch.manual_seed(12345)
cluster_data = ClusterData(data, num_parts=128) # 1. Create subgraphs.
train_loader = ClusterLoader(cluster_data, batch_size=32, shuffle=True) # 2. Stochastic partitioning scheme.
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.test.index
Processing...
Done!
Computing METIS partitioning...
Dataset: PubMed():
==================
Number of graphs: 1
Number of features: 500
Number of classes: 3
Data(x=[19717, 500], edge_index=[2, 88648], y=[19717], train_mask=[19717], val_mask=[19717], test_mask=[19717])
===============================================================================================================
Done!
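The two steps commented in the cell above (create subgraphs, then batch them stochastically) can be sketched in plain Python. This is a simplified illustration under the assumption that each mini-batch merges `batch_size` randomly chosen clusters, as described by Chiang et al. (2019); it only tracks cluster ids and does not reproduce the METIS partitioning or the subgraph construction. The function `cluster_batches` and its `seed` parameter are hypothetical.

```python
import random

def cluster_batches(num_parts, batch_size, seed=0):
    """Shuffle cluster ids and group them into mini-batches of clusters."""
    rng = random.Random(seed)
    cluster_ids = list(range(num_parts))
    rng.shuffle(cluster_ids)
    return [cluster_ids[i:i + batch_size]
            for i in range(0, num_parts, batch_size)]

# Same sizes as in the cell above: 128 partitions, 32 clusters per batch.
batches = cluster_batches(num_parts=128, batch_size=32)
print(len(batches), [len(b) for b in batches])
```

With these settings every epoch consists of four mini-batches, and each cluster is visited exactly once per epoch.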
Define train, test and run functions
[3]:
criterion = torch.nn.CrossEntropyLoss()

def train(model):
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    for sub_data in train_loader:  # Iterate over each mini-batch.
        optimizer.zero_grad()  # Clear gradients.
        out = model(sub_data.x, sub_data.edge_index)  # Perform a single forward pass.
        loss = criterion(out[sub_data.train_mask], sub_data.y[sub_data.train_mask])  # Compute the loss solely based on the training nodes.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

def test(model):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = pred[mask] == data.y[mask]  # Check against ground-truth labels.
        accs.append(int(correct.sum()) / int(mask.sum()))  # Derive ratio of correct predictions.
    return accs

def run(model, epochs=5):
    # range(1, epochs) performs epochs - 1 training passes (four by default).
    for epoch in range(1, epochs):
        train(model)  # train() does not return a value.
        train_acc, val_acc, test_acc = test(model)
        print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
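The accuracy computation in `test` restricts predictions and labels to one split at a time by means of a boolean mask. A minimal pure-Python sketch of that idea follows; the helper `masked_accuracy` and the toy data are hypothetical and use lists instead of tensors.

```python
def masked_accuracy(pred, target, mask):
    """Fraction of correct predictions among the entries selected by mask."""
    selected = [(p, t) for p, t, m in zip(pred, target, mask) if m]
    correct = sum(1 for p, t in selected if p == t)
    return correct / len(selected)

# Toy predictions and labels for five nodes, split into two masks.
pred = [0, 2, 1, 1, 0]
target = [0, 1, 1, 2, 0]
train_mask = [True, True, False, False, False]
test_mask = [False, False, True, True, True]

train_acc = masked_accuracy(pred, target, train_mask)
test_acc = masked_accuracy(pred, target, test_mask)
print(f'Train: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
```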
Training GNNs with torch_geometric.nn.aggr package
Define a GNN class
[4]:
import copy

import torch.nn.functional as F
from torch_geometric.nn import (
    SAGEConv,
    Aggregation,
    MeanAggregation,
    MaxAggregation,
    SumAggregation,
    StdAggregation,
    VarAggregation,
    MultiAggregation,
    SoftmaxAggregation,
)

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, aggr='mean', aggr_kwargs=None):
        super(GNN, self).__init__()
        torch.manual_seed(12345)
        if isinstance(aggr, list):
            num_aggrs = len(aggr)
        elif isinstance(aggr, str):
            num_aggrs = 1
        elif isinstance(aggr, MultiAggregation):
            num_aggrs = len(aggr.aggrs)
        elif isinstance(aggr, Aggregation):
            num_aggrs = 1
        else:
            raise KeyError(f"Unknown aggr: {aggr}")
        conv1_aggr, conv2_aggr = aggr, copy.deepcopy(aggr)
        self.conv1 = SAGEConv([dataset.num_node_features, dataset.num_node_features],
                              hidden_channels,
                              aggr=conv1_aggr,
                              aggr_kwargs=aggr_kwargs)
        self.conv2 = SAGEConv([hidden_channels, hidden_channels],
                              dataset.num_classes,
                              aggr=conv2_aggr,
                              aggr_kwargs=aggr_kwargs)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
Original interface with string type as aggregation argument
[5]:
model = GNN(16, aggr='mean')
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
run(model)
GNN(
(conv1): SAGEConv([500, 500], 16, aggr=mean)
(conv2): SAGEConv([16, 16], 3, aggr=mean)
)
Epoch: 001, Train: 0.3333, Val Acc: 0.3880, Test Acc: 0.4130
Epoch: 002, Train: 0.3333, Val Acc: 0.3880, Test Acc: 0.4130
Epoch: 003, Train: 0.3333, Val Acc: 0.3880, Test Acc: 0.4130
Epoch: 004, Train: 0.6667, Val Acc: 0.5060, Test Acc: 0.5430
Use a single aggregation module as aggregation argument
[6]:
model = GNN(16, aggr=MeanAggregation())
print(model)
run(model)
GNN(
(conv1): SAGEConv([500, 500], 16, aggr=MeanAggregation())
(conv2): SAGEConv([16, 16], 3, aggr=MeanAggregation())
)
Epoch: 001, Train: 0.3333, Val Acc: 0.3880, Test Acc: 0.4130
Epoch: 002, Train: 0.3333, Val Acc: 0.3880, Test Acc: 0.4130
Epoch: 003, Train: 0.3333, Val Acc: 0.3880, Test Acc: 0.4130
Epoch: 004, Train: 0.6667, Val Acc: 0.5060, Test Acc: 0.5430
Use a list of aggregation strings as aggregation argument
[7]:
model = GNN(16, aggr=['mean', 'max', 'sum', 'std', 'var'])
print(model)
run(model)
GNN(
(conv1): SAGEConv([500, 500], 16, aggr=['mean', 'max', 'sum', 'std', 'var'])
(conv2): SAGEConv([16, 16], 3, aggr=['mean', 'max', 'sum', 'std', 'var'])
)
Epoch: 001, Train: 0.5000, Val Acc: 0.3640, Test Acc: 0.3550
Epoch: 002, Train: 0.7833, Val Acc: 0.6120, Test Acc: 0.6160
Epoch: 003, Train: 0.8167, Val Acc: 0.5680, Test Acc: 0.5350
Epoch: 004, Train: 0.8667, Val Acc: 0.7120, Test Acc: 0.6940
Use a list of aggregation modules as aggregation argument
[8]:
model = GNN(16, aggr=[
    MeanAggregation(),
    MaxAggregation(),
    SumAggregation(),
    StdAggregation(),
    VarAggregation(),
])
print(model)
run(model)
GNN(
(conv1): SAGEConv([500, 500], 16, aggr=['MeanAggregation()', 'MaxAggregation()', 'SumAggregation()', 'StdAggregation()', 'VarAggregation()'])
(conv2): SAGEConv([16, 16], 3, aggr=['MeanAggregation()', 'MaxAggregation()', 'SumAggregation()', 'StdAggregation()', 'VarAggregation()'])
)
Epoch: 001, Train: 0.5000, Val Acc: 0.3640, Test Acc: 0.3550
Epoch: 002, Train: 0.7833, Val Acc: 0.6120, Test Acc: 0.6160
Epoch: 003, Train: 0.8167, Val Acc: 0.5680, Test Acc: 0.5350
Epoch: 004, Train: 0.8667, Val Acc: 0.7120, Test Acc: 0.6940
Use a list of mixed modules and strings as aggregation argument
[9]:
model = GNN(16, aggr=[
    'mean',
    MaxAggregation(),
    'sum',
    StdAggregation(),
    'var',
])
print(model)
run(model)
GNN(
(conv1): SAGEConv([500, 500], 16, aggr=['mean', 'MaxAggregation()', 'sum', 'StdAggregation()', 'var'])
(conv2): SAGEConv([16, 16], 3, aggr=['mean', 'MaxAggregation()', 'sum', 'StdAggregation()', 'var'])
)
Epoch: 001, Train: 0.5000, Val Acc: 0.3640, Test Acc: 0.3550
Epoch: 002, Train: 0.7833, Val Acc: 0.6120, Test Acc: 0.6160
Epoch: 003, Train: 0.8167, Val Acc: 0.5680, Test Acc: 0.5350
Epoch: 004, Train: 0.8667, Val Acc: 0.7120, Test Acc: 0.6940
Define multiple learnable aggregations with keyword arguments
[10]:
aggr = ['softmax', 'softmax', 'softmax']
aggrs_kwargs = [dict(t=0.1, learn=True),
                dict(t=1, learn=True),
                dict(t=10, learn=True)]
model = GNN(16, aggr=aggr, aggr_kwargs=dict(aggrs_kwargs=aggrs_kwargs))
print(model)
run(model)
GNN(
(conv1): SAGEConv([500, 500], 16, aggr=['softmax', 'softmax', 'softmax'])
(conv2): SAGEConv([16, 16], 3, aggr=['softmax', 'softmax', 'softmax'])
)
Epoch: 001, Train: 0.8500, Val Acc: 0.6980, Test Acc: 0.7010
Epoch: 002, Train: 0.9333, Val Acc: 0.6420, Test Acc: 0.6600
Epoch: 003, Train: 0.7500, Val Acc: 0.6260, Test Acc: 0.6520
Epoch: 004, Train: 0.9333, Val Acc: 0.7580, Test Acc: 0.7430
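The three `t` values used above act as initial temperatures. As a rough illustration, and assuming the softmax aggregation of Li et al. (2020) in which each message is weighted by a softmax over the messages scaled by `t`, the following pure-Python sketch aggregates the scalar messages of a single neighborhood. The function `softmax_aggregate` is a hypothetical helper, not the PyG module, and `t` is a fixed number here rather than a learned parameter.

```python
import math

def softmax_aggregate(values, t):
    """Weight each value by softmax(t * value) and sum the weighted values."""
    scaled = [t * v for v in values]
    shift = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return sum(v * e / total for v, e in zip(values, exps))

neighbors = [1.0, 2.0, 6.0]
for t in (0.0, 0.1, 1.0, 10.0):
    print(t, round(softmax_aggregate(neighbors, t), 4))
```

Under this assumption, `t = 0` reduces to the mean, and larger values of `t` move the result towards the maximum, so starting from several temperatures lets the model combine mean-like and max-like behavior.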
Define multiple aggregations with MultiAggregation module
[11]:
aggr = MultiAggregation([SoftmaxAggregation(t=0.1, learn=True),
                         SoftmaxAggregation(t=1, learn=True),
                         SoftmaxAggregation(t=10, learn=True)])
model = GNN(16, aggr=aggr)
print(model)
run(model)
GNN(
(conv1): SAGEConv([500, 500], 16, aggr=MultiAggregation([
SoftmaxAggregation(learn=True),
SoftmaxAggregation(learn=True),
SoftmaxAggregation(learn=True)
], mode=cat))
(conv2): SAGEConv([16, 16], 3, aggr=MultiAggregation([
SoftmaxAggregation(learn=True),
SoftmaxAggregation(learn=True),
SoftmaxAggregation(learn=True)
], mode=cat))
)
Epoch: 001, Train: 0.8500, Val Acc: 0.6980, Test Acc: 0.7010
Epoch: 002, Train: 0.9333, Val Acc: 0.6420, Test Acc: 0.6600
Epoch: 003, Train: 0.7500, Val Acc: 0.6260, Test Acc: 0.6520
Epoch: 004, Train: 0.9333, Val Acc: 0.7580, Test Acc: 0.7430
Conclusion
In this tutorial, you have been introduced to the new torch_geometric.nn.aggr package, which provides a flexible interface for experimenting with different aggregation functions in your message passing convolutions and unifies aggregation within GNNs across `MessagePassing <https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py>`__ and global readouts. This new abstraction also makes it easier to design new types of aggregation functions. You can now create your own aggregation function by subclassing the base Aggregation class:
class MyAggregation(Aggregation):
    def __init__(self, ...):
        ...

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        ...
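To illustrate the roles of the `forward` arguments, here is a pure-Python sketch of a custom "range" aggregation (maximum minus minimum per group). It rests on the following reading of the arguments, which should be checked against the PyG documentation: `index` assigns each element to a group, `ptr` gives the group boundaries for inputs that are already sorted by group, and `dim_size` is the number of output groups. The function `range_aggregate` is hypothetical, works on lists of scalars, and leaves out the `dim` argument.

```python
from typing import List, Optional

def range_aggregate(x: List[float],
                    index: Optional[List[int]] = None,
                    ptr: Optional[List[int]] = None,
                    dim_size: Optional[int] = None) -> List[float]:
    """Per-group (max - min); groups come from `index` or from `ptr`."""
    if index is None:
        if ptr is None:
            raise ValueError("either index or ptr is required")
        # Expand boundaries, e.g. ptr=[0, 2, 5] -> index=[0, 0, 1, 1, 1].
        index = [g for g in range(len(ptr) - 1)
                 for _ in range(ptr[g + 1] - ptr[g])]
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0
    groups: List[List[float]] = [[] for _ in range(dim_size)]
    for value, group in zip(x, index):
        groups[group].append(value)
    # Empty groups yield 0.0 (an arbitrary choice for this sketch).
    return [max(g) - min(g) if g else 0.0 for g in groups]

x = [1.0, 4.0, 2.0, 7.0, 3.0]
by_index = range_aggregate(x, index=[0, 0, 1, 1, 1], dim_size=3)
by_ptr = range_aggregate(x, ptr=[0, 2, 5])
print(by_index, by_ptr)
```

A real subclass would express the same logic with tensor operations so that it runs on batched, multi-dimensional features.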
Have fun!