[1]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# Helper function for visualization.
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
def visualize(h, color):
    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())

    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])

    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()
1.11.0
2. Node Classification with Graph Neural Networks
This tutorial will teach you how to apply Graph Neural Networks (GNNs) to the task of node classification. Here, we are given the ground-truth labels of only a small subset of nodes, and want to infer the labels for all the remaining nodes (transductive learning).
To demonstrate, we make use of the Cora dataset, which is a citation network where nodes represent documents. Each node is described by a 1433-dimensional bag-of-words feature vector. Two documents are connected if there exists a citation link between them. The task is to infer the category of each document (7 in total).
This dataset was first introduced by Yang et al. (2016) as one of the datasets of the Planetoid benchmark suite. We can again make use of PyTorch Geometric for easy access to this dataset via `torch_geometric.datasets.Planetoid <https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.Planetoid>`__:
[2]:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
dataset = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())
print()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset[0] # Get the first graph object.
print()
print(data)
print('===========================================================================================================')
# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
Dataset: Cora():
======================
Number of graphs: 1
Number of features: 1433
Number of classes: 7
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
===========================================================================================================
Number of nodes: 2708
Number of edges: 10556
Average node degree: 3.90
Number of training nodes: 140
Training node label rate: 0.05
Has isolated nodes: False
Has self-loops: False
Is undirected: True
Overall, this dataset is quite similar to the previously used `KarateClub <https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.KarateClub>`__ network. We can see that the Cora network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9. For training on this dataset, we are given the ground-truth categories of 140 nodes (20 for each class). This results in a training node label rate of only 5%.
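Because ``edge_index`` stores every undirected citation link once per direction, the 10,556 edges correspond to 5,278 undirected links, and the "average node degree" is simply the number of directed entries divided by the number of nodes. A minimal plain-Python sketch of that arithmetic, using a hypothetical 4-node path graph in the same two-row COO layout as ``edge_index``:

```python
# Hypothetical 4-node path graph 0-1-2-3 in the same two-row COO layout as
# PyG's edge_index: row 0 holds source nodes, row 1 holds target nodes, and
# every undirected edge appears once per direction.
edge_index = [
    [0, 1, 1, 2, 2, 3],  # sources
    [1, 0, 2, 1, 3, 2],  # targets
]
num_nodes = 4
num_edges = len(edge_index[0])   # counts directed entries, like data.num_edges
num_undirected = num_edges // 2  # 3 undirected links

# Average degree = directed entries / nodes = 2 * undirected links / nodes.
avg_degree = num_edges / num_nodes
print(f'Toy average node degree: {avg_degree:.2f}')  # 1.50

# The same arithmetic applied to the Cora statistics printed above.
cora_undirected = 10556 // 2
cora_avg_degree = 10556 / 2708
cora_label_rate = 140 / 2708
print(f'Cora undirected links: {cora_undirected}')    # 5278
print(f'Cora average degree: {cora_avg_degree:.2f}')  # 3.90
print(f'Cora label rate: {cora_label_rate:.2f}')      # 0.05
```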
In contrast to KarateClub, this graph holds the additional attributes ``val_mask`` and ``test_mask``, which denote which nodes should be used for validation and testing. Furthermore, we make use of data transformations via ``transform=NormalizeFeatures()``. Transforms can be used to modify your input data before feeding it into a neural network, e.g., for normalization or data augmentation. Here, we row-normalize the bag-of-words input feature vectors.
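Row normalization here means dividing each node's feature vector by the sum of its entries, so that every non-empty row sums to 1. The plain-Python sketch below illustrates the idea on a hypothetical toy matrix; it is not PyG's implementation of ``NormalizeFeatures``, which operates on tensors, and the helper name ``row_normalize`` is illustrative.

```python
def row_normalize(features):
    """Scale each row so its entries sum to 1; all-zero rows are left as-is."""
    normalized = []
    for row in features:
        total = sum(row)
        # Guard against division by zero for nodes without any words.
        normalized.append([v / total for v in row] if total > 0 else list(row))
    return normalized

# Hypothetical 3-node bag-of-words matrix over a 4-word vocabulary.
x = [
    [1, 0, 1, 0],
    [0, 2, 1, 1],
    [0, 0, 0, 0],
]
x_norm = row_normalize(x)
print(x_norm)  # [[0.5, 0.0, 0.5, 0.0], [0.0, 0.5, 0.25, 0.25], [0, 0, 0, 0]]
```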
We can further see that this network is undirected, and that there exist no isolated nodes (each document has at least one citation).
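Both properties can be read directly off the COO edge list. A minimal plain-Python sketch, assuming ``edge_index`` is given as two equal-length lists of node indices; the function names are illustrative, and in PyG you would call ``data.is_undirected()`` and ``data.has_isolated_nodes()`` as in the cell above.

```python
def is_undirected(edge_index):
    """True if every directed edge (u, v) has a matching reverse edge (v, u)."""
    edges = set(zip(edge_index[0], edge_index[1]))
    return all((v, u) in edges for (u, v) in edges)

def has_isolated_nodes(edge_index, num_nodes):
    """True if some node takes part in no edge (self-loops are ignored)."""
    seen = {n for u, v in zip(edge_index[0], edge_index[1]) if u != v for n in (u, v)}
    return len(seen) < num_nodes

# Hypothetical 4-node graph: nodes 0-1-2 form a path, node 3 has no edges.
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
print(is_undirected(edge_index))          # True
print(has_isolated_nodes(edge_index, 4))  # True (node 3)
```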
Training a Multi-layer Perceptron Network (MLP)
In theory, we should be able to infer the category of a document solely based on its content, i.e. its bag-of-words feature representation, without taking any relational information into account.
Let’s verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes):
[3]:
import torch
from torch.nn import Linear
import torch.nn.functional as F
class MLP(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(12345)
        self.lin1 = Linear(dataset.num_features, hidden_channels)
        self.lin2 = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x):
        x = self.lin1(x)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x
model = MLP(hidden_channels=16)
print(model)
MLP(
(lin1): Linear(in_features=1433, out_features=16, bias=True)
(lin2): Linear(in_features=16, out_features=7, bias=True)
)
Our MLP is defined by two linear layers and enhanced by ReLU non-linearity and dropout. Here, we first reduce the 1433-dimensional feature vector to a low-dimensional embedding (hidden_channels=16), while the second linear layer acts as a classifier that should map each low-dimensional node embedding to one of the 7 classes.
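Because the same two weight matrices are shared by every node, the model size does not depend on the number of nodes. A quick plain-Python count of the parameters in the two ``Linear`` layers printed above (weight matrix plus bias); the helper name ``linear_params`` is illustrative:

```python
def linear_params(in_features, out_features, bias=True):
    """Number of trainable parameters in a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)

lin1 = linear_params(1433, 16)  # 1433 * 16 + 16 = 22944
lin2 = linear_params(16, 7)     # 16 * 7 + 7 = 119
print(lin1 + lin2)              # 23063
```

In the notebook itself, ``sum(p.numel() for p in model.parameters())`` should report the same total for this MLP.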
Let’s train our simple MLP by following a similar procedure as described in the first part of this tutorial. We again make use of the cross-entropy loss and the Adam optimizer. This time, we also define a ``test`` function to evaluate how well our final model performs on the test node set (whose labels have not been observed during training).
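Restricting the loss to ``data.train_mask`` means that only the 140 labeled nodes contribute to the gradients. The plain-Python sketch below spells out what cross-entropy computes on the masked rows, using hypothetical toy inputs (the helper name is illustrative). It also shows why both training logs start near 1.95: for uniform predictions over 7 classes, the loss equals ln 7 ≈ 1.9459.

```python
import math

def masked_cross_entropy(logits, labels, mask):
    """Mean cross-entropy over the nodes where mask is True."""
    losses = []
    for row, y, keep in zip(logits, labels, mask):
        if not keep:
            continue  # Nodes outside the training mask contribute nothing.
        # Numerically stable log-sum-exp: subtract the row maximum first.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[y])  # -log softmax(row)[y]
    return sum(losses) / len(losses)

# With all-equal logits over 7 classes, every prediction is uniform,
# so the loss equals ln(7) regardless of the labels.
logits = [[0.0] * 7 for _ in range(4)]
labels = [0, 3, 6, 2]
mask = [True, False, True, True]
print(round(masked_cross_entropy(logits, labels, mask), 4))  # 1.9459
```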
[4]:
# from IPython.display import Javascript # Restrict height of output cell.
# display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
model = MLP(hidden_channels=16)
criterion = torch.nn.CrossEntropyLoss() # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) # Define optimizer.
def train():
    model.train()
    optimizer.zero_grad() # Clear gradients.
    out = model(data.x) # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask]) # Compute the loss solely based on the training nodes.
    loss.backward() # Derive gradients.
    optimizer.step() # Update parameters based on gradients.
    return loss

def test():
    model.eval()
    out = model(data.x)
    pred = out.argmax(dim=1) # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask] # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum()) # Derive ratio of correct predictions.
    return test_acc

for epoch in range(1, 201):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
Epoch: 001, Loss: 1.9609
Epoch: 002, Loss: 1.9558
Epoch: 003, Loss: 1.9485
Epoch: 004, Loss: 1.9413
Epoch: 005, Loss: 1.9297
Epoch: 006, Loss: 1.9266
Epoch: 007, Loss: 1.9124
Epoch: 008, Loss: 1.9024
Epoch: 009, Loss: 1.8985
Epoch: 010, Loss: 1.8849
Epoch: 011, Loss: 1.8753
Epoch: 012, Loss: 1.8684
Epoch: 013, Loss: 1.8483
Epoch: 014, Loss: 1.8465
Epoch: 015, Loss: 1.8115
Epoch: 016, Loss: 1.8143
Epoch: 017, Loss: 1.7900
Epoch: 018, Loss: 1.7985
Epoch: 019, Loss: 1.7639
Epoch: 020, Loss: 1.7383
Epoch: 021, Loss: 1.7229
Epoch: 022, Loss: 1.7452
Epoch: 023, Loss: 1.7000
Epoch: 024, Loss: 1.6855
Epoch: 025, Loss: 1.6691
Epoch: 026, Loss: 1.6479
Epoch: 027, Loss: 1.6284
Epoch: 028, Loss: 1.5960
Epoch: 029, Loss: 1.5562
Epoch: 030, Loss: 1.5396
Epoch: 031, Loss: 1.5488
Epoch: 032, Loss: 1.5224
Epoch: 033, Loss: 1.4665
Epoch: 034, Loss: 1.4885
Epoch: 035, Loss: 1.4399
Epoch: 036, Loss: 1.4035
Epoch: 037, Loss: 1.4265
Epoch: 038, Loss: 1.3742
Epoch: 039, Loss: 1.3742
Epoch: 040, Loss: 1.3391
Epoch: 041, Loss: 1.3089
Epoch: 042, Loss: 1.2701
Epoch: 043, Loss: 1.2487
Epoch: 044, Loss: 1.2711
Epoch: 045, Loss: 1.2219
Epoch: 046, Loss: 1.1674
Epoch: 047, Loss: 1.1749
Epoch: 048, Loss: 1.1392
Epoch: 049, Loss: 1.1493
Epoch: 050, Loss: 1.1060
Epoch: 051, Loss: 1.0598
Epoch: 052, Loss: 1.0722
Epoch: 053, Loss: 1.0351
Epoch: 054, Loss: 1.0588
Epoch: 055, Loss: 0.9974
Epoch: 056, Loss: 0.9684
Epoch: 057, Loss: 0.8953
Epoch: 058, Loss: 0.9257
Epoch: 059, Loss: 0.9519
Epoch: 060, Loss: 0.9173
Epoch: 061, Loss: 0.8916
Epoch: 062, Loss: 0.8910
Epoch: 063, Loss: 0.8902
Epoch: 064, Loss: 0.9588
Epoch: 065, Loss: 0.8639
Epoch: 066, Loss: 0.8241
Epoch: 067, Loss: 0.8078
Epoch: 068, Loss: 0.7704
Epoch: 069, Loss: 0.7910
Epoch: 070, Loss: 0.7889
Epoch: 071, Loss: 0.7019
Epoch: 072, Loss: 0.7265
Epoch: 073, Loss: 0.8248
Epoch: 074, Loss: 0.7720
Epoch: 075, Loss: 0.7473
Epoch: 076, Loss: 0.7014
Epoch: 077, Loss: 0.7245
Epoch: 078, Loss: 0.7311
Epoch: 079, Loss: 0.6861
Epoch: 080, Loss: 0.6899
Epoch: 081, Loss: 0.6589
Epoch: 082, Loss: 0.7592
Epoch: 083, Loss: 0.6846
Epoch: 084, Loss: 0.6654
Epoch: 085, Loss: 0.6847
Epoch: 086, Loss: 0.6332
Epoch: 087, Loss: 0.6185
Epoch: 088, Loss: 0.5642
Epoch: 089, Loss: 0.6276
Epoch: 090, Loss: 0.6685
Epoch: 091, Loss: 0.5877
Epoch: 092, Loss: 0.6406
Epoch: 093, Loss: 0.6497
Epoch: 094, Loss: 0.6774
Epoch: 095, Loss: 0.6194
Epoch: 096, Loss: 0.5723
Epoch: 097, Loss: 0.5686
Epoch: 098, Loss: 0.6565
Epoch: 099, Loss: 0.6382
Epoch: 100, Loss: 0.6106
Epoch: 101, Loss: 0.6437
Epoch: 102, Loss: 0.5263
Epoch: 103, Loss: 0.5579
Epoch: 104, Loss: 0.6602
Epoch: 105, Loss: 0.5510
Epoch: 106, Loss: 0.5922
Epoch: 107, Loss: 0.5349
Epoch: 108, Loss: 0.6588
Epoch: 109, Loss: 0.5956
Epoch: 110, Loss: 0.5888
Epoch: 111, Loss: 0.5614
Epoch: 112, Loss: 0.5674
Epoch: 113, Loss: 0.5975
Epoch: 114, Loss: 0.5377
Epoch: 115, Loss: 0.5275
Epoch: 116, Loss: 0.5749
Epoch: 117, Loss: 0.4727
Epoch: 118, Loss: 0.5000
Epoch: 119, Loss: 0.5141
Epoch: 120, Loss: 0.5277
Epoch: 121, Loss: 0.5901
Epoch: 122, Loss: 0.4717
Epoch: 123, Loss: 0.5386
Epoch: 124, Loss: 0.5428
Epoch: 125, Loss: 0.5836
Epoch: 126, Loss: 0.5231
Epoch: 127, Loss: 0.5148
Epoch: 128, Loss: 0.5199
Epoch: 129, Loss: 0.5273
Epoch: 130, Loss: 0.5053
Epoch: 131, Loss: 0.5465
Epoch: 132, Loss: 0.5200
Epoch: 133, Loss: 0.5447
Epoch: 134, Loss: 0.5609
Epoch: 135, Loss: 0.5323
Epoch: 136, Loss: 0.4876
Epoch: 137, Loss: 0.5709
Epoch: 138, Loss: 0.4879
Epoch: 139, Loss: 0.4570
Epoch: 140, Loss: 0.5623
Epoch: 141, Loss: 0.6071
Epoch: 142, Loss: 0.5546
Epoch: 143, Loss: 0.4775
Epoch: 144, Loss: 0.5089
Epoch: 145, Loss: 0.6208
Epoch: 146, Loss: 0.5192
Epoch: 147, Loss: 0.4429
Epoch: 148, Loss: 0.5048
Epoch: 149, Loss: 0.5230
Epoch: 150, Loss: 0.4934
Epoch: 151, Loss: 0.4385
Epoch: 152, Loss: 0.5256
Epoch: 153, Loss: 0.4857
Epoch: 154, Loss: 0.4702
Epoch: 155, Loss: 0.4916
Epoch: 156, Loss: 0.4609
Epoch: 157, Loss: 0.4786
Epoch: 158, Loss: 0.4301
Epoch: 159, Loss: 0.4984
Epoch: 160, Loss: 0.4170
Epoch: 161, Loss: 0.4495
Epoch: 162, Loss: 0.4629
Epoch: 163, Loss: 0.4615
Epoch: 164, Loss: 0.4384
Epoch: 165, Loss: 0.4843
Epoch: 166, Loss: 0.4879
Epoch: 167, Loss: 0.4943
Epoch: 168, Loss: 0.4831
Epoch: 169, Loss: 0.4456
Epoch: 170, Loss: 0.4678
Epoch: 171, Loss: 0.4779
Epoch: 172, Loss: 0.4993
Epoch: 173, Loss: 0.4267
Epoch: 174, Loss: 0.4000
Epoch: 175, Loss: 0.4422
Epoch: 176, Loss: 0.4319
Epoch: 177, Loss: 0.4658
Epoch: 178, Loss: 0.4786
Epoch: 179, Loss: 0.4253
Epoch: 180, Loss: 0.4808
Epoch: 181, Loss: 0.4401
Epoch: 182, Loss: 0.3968
Epoch: 183, Loss: 0.4517
Epoch: 184, Loss: 0.4749
Epoch: 185, Loss: 0.3948
Epoch: 186, Loss: 0.4846
Epoch: 187, Loss: 0.4696
Epoch: 188, Loss: 0.4537
Epoch: 189, Loss: 0.4547
Epoch: 190, Loss: 0.4602
Epoch: 191, Loss: 0.4548
Epoch: 192, Loss: 0.4516
Epoch: 193, Loss: 0.4540
Epoch: 194, Loss: 0.4303
Epoch: 195, Loss: 0.4404
Epoch: 196, Loss: 0.4452
Epoch: 197, Loss: 0.4560
Epoch: 198, Loss: 0.4428
Epoch: 199, Loss: 0.4487
Epoch: 200, Loss: 0.4827
After training the model, we can call the ``test`` function to see how well our model performs on nodes whose labels it has not seen. Here, we are interested in the accuracy of the model, i.e., the ratio of correctly classified nodes:
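The ``test`` function boils down to an argmax followed by a masked comparison. A plain-Python sketch with hypothetical toy logits, mirroring ``pred[data.test_mask] == data.y[data.test_mask]`` (the helper name is illustrative):

```python
def masked_accuracy(logits, labels, mask):
    """Fraction of masked nodes whose argmax prediction matches the label."""
    correct = total = 0
    for row, y, keep in zip(logits, labels, mask):
        if keep:
            pred = max(range(len(row)), key=row.__getitem__)  # argmax
            correct += int(pred == y)
            total += 1
    return correct / total

logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
labels = [1, 1, 1, 0]
mask = [True, True, True, False]  # the last node is excluded
print(masked_accuracy(logits, labels, mask))  # 2 of 3 masked nodes correct
```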
[5]:
test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')
Test Accuracy: 0.5740
As one can see, our MLP performs rather poorly, with only about 57% test accuracy. But why does the MLP not perform better? The main reason is that this model suffers from heavy overfitting due to having access to only a small number of training nodes, and therefore generalizes poorly to unseen node representations.
It also fails to incorporate an important inductive bias into the model: the category of a document is very likely related to the categories of the papers it cites. That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model.
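One common way to quantify this bias is the edge homophily ratio: the fraction of edges that connect two nodes of the same class. A value close to 1 means neighbors tend to share labels, which is exactly the signal an MLP ignores. A minimal plain-Python sketch; the function name and the toy graph are illustrative assumptions.

```python
def edge_homophily(edge_index, labels):
    """Fraction of (directed) edges whose two endpoints carry the same label."""
    src, dst = edge_index
    same = sum(labels[u] == labels[v] for u, v in zip(src, dst))
    return same / len(src)

# Hypothetical path graph 0-1-2-3 with labels [0, 0, 1, 1];
# edges are stored in both directions, as in PyG.
edge_index = [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]
labels = [0, 0, 1, 1]
print(edge_homophily(edge_index, labels))  # 4 of 6 edges connect same-label nodes
```

In the notebook, the same ratio could be computed with tensors as ``(data.y[data.edge_index[0]] == data.y[data.edge_index[1]]).float().mean()``.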
Training a Graph Neural Network (GNN)
We can easily convert our MLP to a GNN by swapping the torch.nn.Linear layers with PyG’s GNN operators.
Following up on the first part of this tutorial, we replace the linear layers by the `GCNConv <https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv>`__ module. To recap, the GCN layer (Kipf and Welling (2017)) is defined as
\[\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)}\]
where \(\mathbf{W}^{(\ell + 1)}\) denotes a trainable weight matrix of shape [num_output_features, num_input_features] and \(c_{w,v}\) refers to a fixed normalization coefficient for each edge. In contrast, a single Linear layer is defined as
\[\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)}\]
which does not make use of neighboring node information.
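For ``GCNConv``, the normalization coefficient is \(c_{w,v} = \sqrt{\deg(w)}\sqrt{\deg(v)}\), where degrees are counted after adding a self-loop to every node (the symmetric normalization of Kipf and Welling). The plain-Python sketch below computes one aggregation step on a hypothetical 3-node path graph; it is meant to make the formula concrete, not to reproduce PyG's sparse implementation, and the helper name ``gcn_aggregate`` is illustrative.

```python
import math

def gcn_aggregate(x, edge_index, num_nodes):
    """One GCN aggregation step (weight matrix omitted):
    out_v = sum over w in N(v) plus {v} of x_w / c_{w,v},
    with c_{w,v} = sqrt(deg(w)) * sqrt(deg(v)) and degrees counting a self-loop.
    Assumes an undirected edge list without duplicate edges.
    """
    # Incoming neighbors of every node, plus a self-loop.
    neighbors = [{v} for v in range(num_nodes)]
    for u, v in zip(edge_index[0], edge_index[1]):
        neighbors[v].add(u)  # message flows from source u to target v
    deg = [len(nbrs) for nbrs in neighbors]

    out = []
    for v in range(num_nodes):
        agg = [0.0] * len(x[v])
        for w in neighbors[v]:
            c = math.sqrt(deg[w]) * math.sqrt(deg[v])  # c_{w,v}
            for k, value in enumerate(x[w]):
                agg[k] += value / c
        out.append(agg)
    return out

# Hypothetical path graph 0-1-2 with one scalar feature per node.
x = [[1.0], [2.0], [3.0]]
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
print(gcn_aggregate(x, edge_index, num_nodes=3))
```

Because \(\mathbf{W}^{(\ell + 1)}\) acts linearly, multiplying by it before or after this aggregation gives the same result, which is why the sketch leaves it out. Keeping only the self term \(w = v\) with \(c_{v,v} = 1\) reduces it to the Linear layer.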
[6]:
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(1234567)
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
model = GCN(hidden_channels=16)
print(model)
GCN(
(conv1): GCNConv(1433, 16)
(conv2): GCNConv(16, 7)
)
Let’s visualize the node embeddings of our untrained GCN network. For visualization, we make use of TSNE to embed our 7-dimensional node embeddings onto a 2D plane.
[7]:
model = GCN(hidden_channels=16)
model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)
We certainly can do better by training our model. The training and testing procedure is once again the same, but this time we make use of the node features x and the graph connectivity edge_index as input to our GCN model.
[8]:
# from IPython.display import Javascript # Restrict height of output cell.
# display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
def train():
    model.train()
    optimizer.zero_grad() # Clear gradients.
    out = model(data.x, data.edge_index) # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask]) # Compute the loss solely based on the training nodes.
    loss.backward() # Derive gradients.
    optimizer.step() # Update parameters based on gradients.
    return loss

def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1) # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask] # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum()) # Derive ratio of correct predictions.
    return test_acc

for epoch in range(1, 101):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
Epoch: 001, Loss: 1.9465
Epoch: 002, Loss: 1.9419
Epoch: 003, Loss: 1.9363
Epoch: 004, Loss: 1.9290
Epoch: 005, Loss: 1.9199
Epoch: 006, Loss: 1.9140
Epoch: 007, Loss: 1.9079
Epoch: 008, Loss: 1.8992
Epoch: 009, Loss: 1.8876
Epoch: 010, Loss: 1.8764
Epoch: 011, Loss: 1.8656
Epoch: 012, Loss: 1.8626
Epoch: 013, Loss: 1.8460
Epoch: 014, Loss: 1.8329
Epoch: 015, Loss: 1.8225
Epoch: 016, Loss: 1.8167
Epoch: 017, Loss: 1.7995
Epoch: 018, Loss: 1.7878
Epoch: 019, Loss: 1.7716
Epoch: 020, Loss: 1.7568
Epoch: 021, Loss: 1.7563
Epoch: 022, Loss: 1.7342
Epoch: 023, Loss: 1.7092
Epoch: 024, Loss: 1.7015
Epoch: 025, Loss: 1.6671
Epoch: 026, Loss: 1.6757
Epoch: 027, Loss: 1.6609
Epoch: 028, Loss: 1.6355
Epoch: 029, Loss: 1.6339
Epoch: 030, Loss: 1.6102
Epoch: 031, Loss: 1.5964
Epoch: 032, Loss: 1.5721
Epoch: 033, Loss: 1.5570
Epoch: 034, Loss: 1.5445
Epoch: 035, Loss: 1.5093
Epoch: 036, Loss: 1.4889
Epoch: 037, Loss: 1.4776
Epoch: 038, Loss: 1.4704
Epoch: 039, Loss: 1.4263
Epoch: 040, Loss: 1.3972
Epoch: 041, Loss: 1.3873
Epoch: 042, Loss: 1.3479
Epoch: 043, Loss: 1.3485
Epoch: 044, Loss: 1.3739
Epoch: 045, Loss: 1.3343
Epoch: 046, Loss: 1.3277
Epoch: 047, Loss: 1.2770
Epoch: 048, Loss: 1.2651
Epoch: 049, Loss: 1.2347
Epoch: 050, Loss: 1.2543
Epoch: 051, Loss: 1.1622
Epoch: 052, Loss: 1.1483
Epoch: 053, Loss: 1.1535
Epoch: 054, Loss: 1.1912
Epoch: 055, Loss: 1.0880
Epoch: 056, Loss: 1.1374
Epoch: 057, Loss: 1.0657
Epoch: 058, Loss: 1.0748
Epoch: 059, Loss: 1.0654
Epoch: 060, Loss: 1.0201
Epoch: 061, Loss: 0.9967
Epoch: 062, Loss: 1.0499
Epoch: 063, Loss: 1.0116
Epoch: 064, Loss: 0.9945
Epoch: 065, Loss: 0.9499
Epoch: 066, Loss: 0.9465
Epoch: 067, Loss: 0.9633
Epoch: 068, Loss: 0.9137
Epoch: 069, Loss: 0.9168
Epoch: 070, Loss: 0.8818
Epoch: 071, Loss: 0.8984
Epoch: 072, Loss: 0.8301
Epoch: 073, Loss: 0.8664
Epoch: 074, Loss: 0.8560
Epoch: 075, Loss: 0.8457
Epoch: 076, Loss: 0.8306
Epoch: 077, Loss: 0.8333
Epoch: 078, Loss: 0.8155
Epoch: 079, Loss: 0.7878
Epoch: 080, Loss: 0.8277
Epoch: 081, Loss: 0.7880
Epoch: 082, Loss: 0.7829
Epoch: 083, Loss: 0.7829
Epoch: 084, Loss: 0.7633
Epoch: 085, Loss: 0.7862
Epoch: 086, Loss: 0.7497
Epoch: 087, Loss: 0.7760
Epoch: 088, Loss: 0.7419
Epoch: 089, Loss: 0.6595
Epoch: 090, Loss: 0.6746
Epoch: 091, Loss: 0.7432
Epoch: 092, Loss: 0.6609
Epoch: 093, Loss: 0.6607
Epoch: 094, Loss: 0.6884
Epoch: 095, Loss: 0.6596
Epoch: 096, Loss: 0.6456
Epoch: 097, Loss: 0.6383
Epoch: 098, Loss: 0.7031
Epoch: 099, Loss: 0.6437
Epoch: 100, Loss: 0.6375
After training the model, we can check its test accuracy:
[9]:
test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')
Test Accuracy: 0.8110
There it is! By simply swapping the linear layers with GNN layers, we reach 81.1% test accuracy! This is in stark contrast to the 57.4% test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance.
We can also verify that once again by looking at the output embeddings of our trained model, which now produces a far better clustering of nodes of the same category.
[10]:
model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)
Conclusion
In this chapter, you have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model’s performance. In the next section, we will look into how GNNs can be used for the task of graph classification.
(Optional) Exercises
1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The ``Cora`` dataset provides a validation node set as ``data.val_mask``, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to 82% accuracy.
2. How does ``GCN`` behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all?
3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all ``GCNConv`` instances with `GATConv <https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GATConv>`__ layers that make use of attention? Try to write a 2-layer ``GAT`` model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a ``dropout`` ratio of ``0.6`` inside and outside each ``GATConv`` call, and uses a ``hidden_channels`` dimensionality of ``8`` per head.
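The model-selection rule from the first exercise only needs the per-epoch validation and test accuracies. A plain-Python sketch (the helper name ``select_by_validation`` is hypothetical) that picks the epoch with the highest validation accuracy and reports the test accuracy at that epoch:

```python
def select_by_validation(history, first_epoch=1):
    """Return (epoch, val_acc, test_acc) for the epoch with the highest
    validation accuracy; ties go to the earliest epoch."""
    best = None
    for epoch, (val_acc, test_acc) in enumerate(history, start=first_epoch):
        if best is None or val_acc > best[1]:
            best = (epoch, val_acc, test_acc)
    return best

# (val, test) pairs copied from epochs 24-28 of the GAT training log below.
history = [(0.8060, 0.8220), (0.8040, 0.8220), (0.8080, 0.8220),
           (0.8000, 0.8210), (0.7960, 0.8180)]
print(select_by_validation(history, first_epoch=24))  # (26, 0.808, 0.822)
```

Across the full GAT log, the highest validation accuracy is 0.8080 at epoch 26, where test accuracy is 0.8220, in line with the 82% target. Inside the training loop, the same rule would mean saving a copy of the weights (e.g., ``copy.deepcopy(model.state_dict())``) whenever validation accuracy improves, and restoring them with ``model.load_state_dict(...)`` before the final call to ``test``.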
[11]:
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, heads):
        super().__init__()
        torch.manual_seed(1234567)
        # First layer: `heads` attention heads whose outputs are concatenated.
        # GATConv's own `dropout` argument (attention dropout) is left at its default of 0 here.
        self.conv1 = GATConv(in_channels=dataset.num_features, out_channels=hidden_channels, heads=heads)
        # Second layer: a single head mapping the hidden_channels * heads concatenated features to the classes.
        self.conv2 = GATConv(in_channels=hidden_channels*heads, out_channels=dataset.num_classes, heads=1)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GAT(hidden_channels=8, heads=8)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()
    optimizer.zero_grad() # Clear gradients.
    out = model(data.x, data.edge_index) # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask]) # Compute the loss solely based on the training nodes.
    loss.backward() # Derive gradients.
    optimizer.step() # Update parameters based on gradients.
    return loss

def test(mask):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1) # Use the class with highest probability.
    correct = pred[mask] == data.y[mask] # Check against ground-truth labels.
    acc = int(correct.sum()) / int(mask.sum()) # Derive ratio of correct predictions.
    return acc

for epoch in range(1, 201):
    loss = train()
    val_acc = test(data.val_mask)
    test_acc = test(data.test_mask)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
GAT(
(conv1): GATConv(1433, 8, heads=8)
(conv2): GATConv(64, 7, heads=1)
)
Epoch: 001, Loss: 1.9438, Val: 0.2900, Test: 0.3510
Epoch: 002, Loss: 1.9346, Val: 0.4420, Test: 0.4780
Epoch: 003, Loss: 1.9258, Val: 0.5360, Test: 0.5710
Epoch: 004, Loss: 1.9147, Val: 0.6820, Test: 0.6910
Epoch: 005, Loss: 1.9020, Val: 0.7580, Test: 0.7700
Epoch: 006, Loss: 1.8962, Val: 0.7940, Test: 0.8080
Epoch: 007, Loss: 1.8831, Val: 0.7960, Test: 0.8070
Epoch: 008, Loss: 1.8730, Val: 0.7840, Test: 0.8040
Epoch: 009, Loss: 1.8633, Val: 0.7820, Test: 0.7980
Epoch: 010, Loss: 1.8394, Val: 0.7900, Test: 0.7930
Epoch: 011, Loss: 1.8355, Val: 0.7760, Test: 0.7900
Epoch: 012, Loss: 1.8170, Val: 0.7820, Test: 0.7810
Epoch: 013, Loss: 1.8091, Val: 0.7840, Test: 0.7830
Epoch: 014, Loss: 1.7960, Val: 0.7840, Test: 0.7840
Epoch: 015, Loss: 1.7795, Val: 0.7820, Test: 0.7780
Epoch: 016, Loss: 1.7599, Val: 0.7800, Test: 0.7780
Epoch: 017, Loss: 1.7490, Val: 0.7860, Test: 0.7810
Epoch: 018, Loss: 1.7229, Val: 0.7920, Test: 0.7890
Epoch: 019, Loss: 1.7195, Val: 0.7900, Test: 0.7990
Epoch: 020, Loss: 1.6972, Val: 0.7920, Test: 0.8090
Epoch: 021, Loss: 1.6714, Val: 0.7980, Test: 0.8090
Epoch: 022, Loss: 1.6731, Val: 0.8020, Test: 0.8130
Epoch: 023, Loss: 1.6381, Val: 0.8020, Test: 0.8220
Epoch: 024, Loss: 1.6291, Val: 0.8060, Test: 0.8220
Epoch: 025, Loss: 1.6115, Val: 0.8040, Test: 0.8220
Epoch: 026, Loss: 1.5753, Val: 0.8080, Test: 0.8220
Epoch: 027, Loss: 1.5604, Val: 0.8000, Test: 0.8210
Epoch: 028, Loss: 1.5455, Val: 0.7960, Test: 0.8180
Epoch: 029, Loss: 1.5080, Val: 0.7940, Test: 0.8170
Epoch: 030, Loss: 1.4835, Val: 0.7960, Test: 0.8160
Epoch: 031, Loss: 1.4696, Val: 0.7940, Test: 0.8080
Epoch: 032, Loss: 1.4851, Val: 0.7940, Test: 0.8080
Epoch: 033, Loss: 1.4255, Val: 0.8020, Test: 0.8110
Epoch: 034, Loss: 1.4121, Val: 0.8020, Test: 0.8090
Epoch: 035, Loss: 1.3797, Val: 0.7980, Test: 0.8090
Epoch: 036, Loss: 1.3539, Val: 0.7940, Test: 0.8070
Epoch: 037, Loss: 1.3523, Val: 0.7940, Test: 0.8040
Epoch: 038, Loss: 1.3044, Val: 0.7940, Test: 0.8030
Epoch: 039, Loss: 1.2950, Val: 0.7960, Test: 0.8030
Epoch: 040, Loss: 1.2726, Val: 0.8000, Test: 0.8080
Epoch: 041, Loss: 1.2666, Val: 0.8000, Test: 0.8070
Epoch: 042, Loss: 1.2070, Val: 0.8040, Test: 0.8090
Epoch: 043, Loss: 1.1997, Val: 0.8040, Test: 0.8080
Epoch: 044, Loss: 1.1373, Val: 0.8040, Test: 0.8060
Epoch: 045, Loss: 1.1208, Val: 0.8020, Test: 0.8020
Epoch: 046, Loss: 1.1086, Val: 0.8020, Test: 0.8000
Epoch: 047, Loss: 1.0837, Val: 0.8000, Test: 0.8010
Epoch: 048, Loss: 1.0452, Val: 0.8000, Test: 0.8010
Epoch: 049, Loss: 1.0323, Val: 0.8000, Test: 0.8000
Epoch: 050, Loss: 1.0636, Val: 0.8000, Test: 0.7980
Epoch: 051, Loss: 0.9951, Val: 0.8000, Test: 0.8010
Epoch: 052, Loss: 0.9425, Val: 0.8020, Test: 0.8020
Epoch: 053, Loss: 0.9659, Val: 0.8020, Test: 0.8040
Epoch: 054, Loss: 0.8752, Val: 0.7960, Test: 0.8060
Epoch: 055, Loss: 0.9092, Val: 0.7960, Test: 0.8050
Epoch: 056, Loss: 0.8657, Val: 0.7980, Test: 0.8040
Epoch: 057, Loss: 0.8549, Val: 0.8020, Test: 0.8020
Epoch: 058, Loss: 0.8361, Val: 0.8020, Test: 0.8010
Epoch: 059, Loss: 0.8102, Val: 0.8000, Test: 0.8010
Epoch: 060, Loss: 0.7850, Val: 0.8000, Test: 0.8000
Epoch: 061, Loss: 0.7942, Val: 0.8000, Test: 0.8000
Epoch: 062, Loss: 0.7536, Val: 0.7960, Test: 0.8010
Epoch: 063, Loss: 0.7126, Val: 0.7940, Test: 0.7950
Epoch: 064, Loss: 0.7691, Val: 0.7980, Test: 0.7940
Epoch: 065, Loss: 0.7238, Val: 0.7940, Test: 0.7960
Epoch: 066, Loss: 0.7133, Val: 0.7940, Test: 0.7970
Epoch: 067, Loss: 0.6969, Val: 0.7960, Test: 0.7970
Epoch: 068, Loss: 0.6833, Val: 0.7960, Test: 0.7980
Epoch: 069, Loss: 0.6852, Val: 0.7940, Test: 0.8000
Epoch: 070, Loss: 0.6366, Val: 0.7940, Test: 0.8020
Epoch: 071, Loss: 0.6740, Val: 0.7960, Test: 0.7990
Epoch: 072, Loss: 0.6289, Val: 0.7940, Test: 0.8000
Epoch: 073, Loss: 0.5892, Val: 0.7940, Test: 0.8020
Epoch: 074, Loss: 0.5902, Val: 0.7940, Test: 0.8050
Epoch: 075, Loss: 0.5980, Val: 0.7880, Test: 0.8000
Epoch: 076, Loss: 0.5735, Val: 0.7880, Test: 0.7980
Epoch: 077, Loss: 0.5576, Val: 0.7900, Test: 0.7940
Epoch: 078, Loss: 0.5164, Val: 0.7900, Test: 0.7900
Epoch: 079, Loss: 0.5654, Val: 0.7920, Test: 0.7890
Epoch: 080, Loss: 0.5270, Val: 0.7920, Test: 0.7900
Epoch: 081, Loss: 0.5703, Val: 0.7920, Test: 0.7890
Epoch: 082, Loss: 0.5218, Val: 0.7920, Test: 0.7920
Epoch: 083, Loss: 0.5286, Val: 0.7920, Test: 0.7960
Epoch: 084, Loss: 0.4852, Val: 0.7920, Test: 0.7970
Epoch: 085, Loss: 0.5530, Val: 0.7900, Test: 0.8000
Epoch: 086, Loss: 0.4803, Val: 0.7940, Test: 0.7970
Epoch: 087, Loss: 0.5127, Val: 0.7960, Test: 0.7990
Epoch: 088, Loss: 0.4354, Val: 0.7980, Test: 0.7990
Epoch: 089, Loss: 0.4546, Val: 0.7940, Test: 0.7990
Epoch: 090, Loss: 0.4660, Val: 0.7900, Test: 0.7970
Epoch: 091, Loss: 0.4545, Val: 0.7900, Test: 0.7970
Epoch: 092, Loss: 0.4153, Val: 0.7900, Test: 0.7950
Epoch: 093, Loss: 0.4121, Val: 0.7920, Test: 0.7910
Epoch: 094, Loss: 0.4513, Val: 0.7880, Test: 0.7890
Epoch: 095, Loss: 0.4283, Val: 0.7880, Test: 0.7850
Epoch: 096, Loss: 0.4737, Val: 0.7820, Test: 0.7830
Epoch: 097, Loss: 0.4406, Val: 0.7800, Test: 0.7810
Epoch: 098, Loss: 0.4497, Val: 0.7820, Test: 0.7820
Epoch: 099, Loss: 0.4351, Val: 0.7820, Test: 0.7870
Epoch: 100, Loss: 0.3983, Val: 0.7840, Test: 0.7830
Epoch: 101, Loss: 0.4242, Val: 0.7820, Test: 0.7850
Epoch: 102, Loss: 0.4349, Val: 0.7840, Test: 0.7850
Epoch: 103, Loss: 0.4158, Val: 0.7860, Test: 0.7860
Epoch: 104, Loss: 0.4362, Val: 0.7900, Test: 0.7900
Epoch: 105, Loss: 0.3947, Val: 0.7900, Test: 0.7880
Epoch: 106, Loss: 0.4101, Val: 0.7880, Test: 0.7910
Epoch: 107, Loss: 0.3671, Val: 0.7880, Test: 0.7930
Epoch: 108, Loss: 0.3361, Val: 0.7900, Test: 0.7930
Epoch: 109, Loss: 0.4105, Val: 0.7900, Test: 0.7910
Epoch: 110, Loss: 0.3741, Val: 0.7880, Test: 0.7930
Epoch: 111, Loss: 0.3524, Val: 0.7920, Test: 0.7900
Epoch: 112, Loss: 0.4059, Val: 0.7900, Test: 0.7870
Epoch: 113, Loss: 0.4113, Val: 0.7860, Test: 0.7890
Epoch: 114, Loss: 0.3838, Val: 0.7860, Test: 0.7920
Epoch: 115, Loss: 0.3759, Val: 0.7860, Test: 0.7890
Epoch: 116, Loss: 0.3638, Val: 0.7880, Test: 0.7920
Epoch: 117, Loss: 0.3622, Val: 0.7900, Test: 0.7940
Epoch: 118, Loss: 0.3510, Val: 0.7900, Test: 0.7930
Epoch: 119, Loss: 0.3935, Val: 0.7920, Test: 0.7920
Epoch: 120, Loss: 0.3665, Val: 0.7900, Test: 0.7900
Epoch: 121, Loss: 0.3999, Val: 0.7900, Test: 0.7890
Epoch: 122, Loss: 0.3409, Val: 0.7860, Test: 0.7940
Epoch: 123, Loss: 0.3486, Val: 0.7840, Test: 0.7950
Epoch: 124, Loss: 0.3968, Val: 0.7860, Test: 0.7930
Epoch: 125, Loss: 0.3143, Val: 0.7920, Test: 0.7940
Epoch: 126, Loss: 0.3322, Val: 0.7880, Test: 0.7930
Epoch: 127, Loss: 0.3189, Val: 0.7880, Test: 0.7950
Epoch: 128, Loss: 0.3637, Val: 0.7880, Test: 0.7950
Epoch: 129, Loss: 0.3584, Val: 0.7940, Test: 0.7950
Epoch: 130, Loss: 0.2702, Val: 0.8000, Test: 0.7950
Epoch: 131, Loss: 0.3392, Val: 0.7980, Test: 0.7920
Epoch: 132, Loss: 0.3290, Val: 0.7940, Test: 0.7930
Epoch: 133, Loss: 0.3390, Val: 0.7960, Test: 0.7940
Epoch: 134, Loss: 0.3138, Val: 0.7980, Test: 0.7950
Epoch: 135, Loss: 0.3293, Val: 0.7960, Test: 0.7960
Epoch: 136, Loss: 0.3340, Val: 0.7940, Test: 0.7930
Epoch: 137, Loss: 0.3592, Val: 0.7940, Test: 0.7920
Epoch: 138, Loss: 0.3079, Val: 0.7900, Test: 0.7910
Epoch: 139, Loss: 0.2867, Val: 0.7900, Test: 0.7880
Epoch: 140, Loss: 0.3322, Val: 0.7820, Test: 0.7810
Epoch: 141, Loss: 0.2991, Val: 0.7840, Test: 0.7770
Epoch: 142, Loss: 0.3162, Val: 0.7860, Test: 0.7750
Epoch: 143, Loss: 0.3420, Val: 0.7880, Test: 0.7720
Epoch: 144, Loss: 0.3244, Val: 0.7860, Test: 0.7740
Epoch: 145, Loss: 0.3210, Val: 0.7840, Test: 0.7750
Epoch: 146, Loss: 0.3395, Val: 0.7840, Test: 0.7820
Epoch: 147, Loss: 0.2587, Val: 0.7800, Test: 0.7800
Epoch: 148, Loss: 0.3193, Val: 0.7820, Test: 0.7830
Epoch: 149, Loss: 0.2918, Val: 0.7900, Test: 0.7880
Epoch: 150, Loss: 0.2832, Val: 0.7960, Test: 0.7870
Epoch: 151, Loss: 0.2953, Val: 0.7920, Test: 0.7910
Epoch: 152, Loss: 0.2941, Val: 0.7960, Test: 0.7930
Epoch: 153, Loss: 0.2611, Val: 0.7960, Test: 0.7970
Epoch: 154, Loss: 0.3169, Val: 0.7920, Test: 0.7920
Epoch: 155, Loss: 0.2715, Val: 0.7940, Test: 0.7940
Epoch: 156, Loss: 0.3041, Val: 0.7940, Test: 0.7950
Epoch: 157, Loss: 0.2705, Val: 0.7940, Test: 0.7980
Epoch: 158, Loss: 0.2728, Val: 0.7920, Test: 0.7980
Epoch: 159, Loss: 0.2477, Val: 0.7920, Test: 0.7980
Epoch: 160, Loss: 0.2964, Val: 0.7940, Test: 0.7940
Epoch: 161, Loss: 0.3096, Val: 0.7940, Test: 0.7910
Epoch: 162, Loss: 0.2729, Val: 0.7920, Test: 0.7860
Epoch: 163, Loss: 0.2814, Val: 0.7880, Test: 0.7830
Epoch: 164, Loss: 0.3201, Val: 0.7920, Test: 0.7790
Epoch: 165, Loss: 0.2720, Val: 0.7960, Test: 0.7750
Epoch: 166, Loss: 0.2687, Val: 0.7940, Test: 0.7780
Epoch: 167, Loss: 0.2735, Val: 0.7940, Test: 0.7800
Epoch: 168, Loss: 0.3116, Val: 0.7940, Test: 0.7860
Epoch: 169, Loss: 0.2686, Val: 0.7940, Test: 0.7840
Epoch: 170, Loss: 0.2784, Val: 0.7960, Test: 0.7930
Epoch: 171, Loss: 0.2805, Val: 0.7980, Test: 0.7910
Epoch: 172, Loss: 0.2421, Val: 0.8020, Test: 0.7910
Epoch: 173, Loss: 0.2504, Val: 0.8000, Test: 0.7910
Epoch: 174, Loss: 0.2660, Val: 0.8000, Test: 0.7890
Epoch: 175, Loss: 0.2919, Val: 0.7960, Test: 0.7880
Epoch: 176, Loss: 0.2753, Val: 0.7960, Test: 0.7830
Epoch: 177, Loss: 0.2646, Val: 0.7940, Test: 0.7840
Epoch: 178, Loss: 0.3110, Val: 0.7820, Test: 0.7830
Epoch: 179, Loss: 0.2319, Val: 0.7760, Test: 0.7720
Epoch: 180, Loss: 0.2857, Val: 0.7760, Test: 0.7690
Epoch: 181, Loss: 0.2903, Val: 0.7740, Test: 0.7710
Epoch: 182, Loss: 0.2559, Val: 0.7760, Test: 0.7740
Epoch: 183, Loss: 0.2490, Val: 0.7780, Test: 0.7790
Epoch: 184, Loss: 0.2599, Val: 0.7840, Test: 0.7880
Epoch: 185, Loss: 0.2934, Val: 0.7880, Test: 0.7890
Epoch: 186, Loss: 0.2629, Val: 0.7940, Test: 0.7930
Epoch: 187, Loss: 0.2701, Val: 0.7960, Test: 0.7920
Epoch: 188, Loss: 0.2640, Val: 0.7960, Test: 0.7930
Epoch: 189, Loss: 0.2517, Val: 0.7960, Test: 0.7920
Epoch: 190, Loss: 0.2489, Val: 0.7980, Test: 0.7920
Epoch: 191, Loss: 0.2835, Val: 0.7980, Test: 0.7930
Epoch: 192, Loss: 0.3014, Val: 0.7960, Test: 0.7850
Epoch: 193, Loss: 0.2464, Val: 0.8040, Test: 0.7820
Epoch: 194, Loss: 0.2620, Val: 0.7900, Test: 0.7850
Epoch: 195, Loss: 0.2469, Val: 0.7840, Test: 0.7800
Epoch: 196, Loss: 0.2690, Val: 0.7860, Test: 0.7790
Epoch: 197, Loss: 0.2607, Val: 0.7860, Test: 0.7800
Epoch: 198, Loss: 0.2568, Val: 0.7860, Test: 0.7820
Epoch: 199, Loss: 0.2462, Val: 0.7880, Test: 0.7830
Epoch: 200, Loss: 0.2776, Val: 0.7880, Test: 0.7830
[12]:
test_acc = test(data.test_mask)
print(f'Test Accuracy: {test_acc:.4f}')
Test Accuracy: 0.7830
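The test accuracy above is taken from the model after the final (200th) epoch. Meanwhile, the log shows validation accuracy moving by a few points from one epoch to the next. One common alternative, which this tutorial does not use, is to report the test accuracy from the epoch with the highest validation accuracy. The sketch below shows that selection step in plain Python. It assumes the per-epoch `(val_acc, test_acc)` pairs have been collected into a dictionary keyed by epoch. The helper name `select_by_val` is hypothetical.

```python
from typing import Dict, Tuple


def select_by_val(history: Dict[int, Tuple[float, float]]) -> Tuple[int, float, float]:
    """Return (epoch, val_acc, test_acc) for the epoch with the best validation accuracy.

    `history` maps an epoch number to the (val_acc, test_acc) pair logged for it.
    Ties are broken in favour of the earliest epoch.
    """
    if not history:
        raise ValueError("history must contain at least one epoch")
    # max() returns the first maximal element, so iterating epochs in ascending
    # order makes the earliest epoch win a tie.
    best_epoch = max(sorted(history), key=lambda e: history[e][0])
    val_acc, test_acc = history[best_epoch]
    return best_epoch, val_acc, test_acc


# Values copied from epochs 191-195 of the GAT log above.
history = {
    191: (0.7980, 0.7930),
    192: (0.7960, 0.7850),
    193: (0.8040, 0.7820),
    194: (0.7900, 0.7850),
    195: (0.7840, 0.7800),
}
print(select_by_val(history))  # → (193, 0.804, 0.782)
```

In the notebook, the dictionary could be filled inside the training loop with `history[epoch] = (val_acc, test_acc)`. Only validation accuracy drives the choice, so the test split is still not used for model selection.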
[13]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
class GAT_v2(torch.nn.Module):
    def __init__(self, hidden_channels, heads):
        super().__init__()
        torch.manual_seed(1234567)
        # First layer: `heads` attention heads with `hidden_channels` features each; head outputs are concatenated.
        self.conv1 = GATv2Conv(in_channels=dataset.num_features, out_channels=hidden_channels, heads=heads)
        # Second layer: a single head maps the hidden_channels * heads concatenated features to class scores.
        self.conv2 = GATv2Conv(in_channels=hidden_channels * heads, out_channels=dataset.num_classes, heads=1)
    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)  # Dropout on input features (active only in training mode).
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
model = GAT_v2(hidden_channels=8, heads=8)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x, data.edge_index)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss
@torch.no_grad()  # Evaluation needs no gradient tracking.
def test(mask):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    correct = pred[mask] == data.y[mask]  # Check against ground-truth labels.
    acc = int(correct.sum()) / int(mask.sum())  # Derive ratio of correct predictions.
    return acc
for epoch in range(1, 201):
    loss = train()
    val_acc = test(data.val_mask)
    test_acc = test(data.test_mask)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
GAT_v2(
  (conv1): GATv2Conv(1433, 8, heads=8)
  (conv2): GATv2Conv(64, 7, heads=1)
)
Epoch: 001, Loss: 1.9472, Val: 0.1220, Test: 0.1300
Epoch: 002, Loss: 1.9381, Val: 0.1280, Test: 0.1370
Epoch: 003, Loss: 1.9234, Val: 0.4380, Test: 0.4450
Epoch: 004, Loss: 1.9143, Val: 0.4820, Test: 0.5370
Epoch: 005, Loss: 1.9025, Val: 0.3320, Test: 0.3730
Epoch: 006, Loss: 1.8954, Val: 0.3840, Test: 0.4250
Epoch: 007, Loss: 1.8826, Val: 0.4780, Test: 0.5190
Epoch: 008, Loss: 1.8754, Val: 0.5820, Test: 0.6150
Epoch: 009, Loss: 1.8618, Val: 0.6600, Test: 0.6700
Epoch: 010, Loss: 1.8585, Val: 0.7060, Test: 0.7110
Epoch: 011, Loss: 1.8381, Val: 0.7360, Test: 0.7420
Epoch: 012, Loss: 1.8346, Val: 0.7540, Test: 0.7440
Epoch: 013, Loss: 1.8177, Val: 0.7380, Test: 0.7350
Epoch: 014, Loss: 1.8000, Val: 0.7280, Test: 0.7190
Epoch: 015, Loss: 1.7972, Val: 0.7200, Test: 0.7120
Epoch: 016, Loss: 1.7728, Val: 0.7220, Test: 0.7140
Epoch: 017, Loss: 1.7559, Val: 0.7260, Test: 0.7220
Epoch: 018, Loss: 1.7445, Val: 0.7320, Test: 0.7300
Epoch: 019, Loss: 1.7293, Val: 0.7400, Test: 0.7350
Epoch: 020, Loss: 1.7291, Val: 0.7460, Test: 0.7460
Epoch: 021, Loss: 1.7004, Val: 0.7520, Test: 0.7500
Epoch: 022, Loss: 1.6808, Val: 0.7660, Test: 0.7590
Epoch: 023, Loss: 1.6744, Val: 0.7780, Test: 0.7610
Epoch: 024, Loss: 1.6331, Val: 0.7760, Test: 0.7690
Epoch: 025, Loss: 1.6354, Val: 0.7720, Test: 0.7660
Epoch: 026, Loss: 1.6235, Val: 0.7800, Test: 0.7700
Epoch: 027, Loss: 1.5871, Val: 0.7900, Test: 0.7700
Epoch: 028, Loss: 1.5617, Val: 0.7960, Test: 0.7740
Epoch: 029, Loss: 1.5511, Val: 0.7960, Test: 0.7790
Epoch: 030, Loss: 1.5275, Val: 0.7960, Test: 0.7860
Epoch: 031, Loss: 1.5415, Val: 0.7920, Test: 0.7900
Epoch: 032, Loss: 1.4719, Val: 0.7940, Test: 0.7980
Epoch: 033, Loss: 1.4659, Val: 0.7920, Test: 0.8030
Epoch: 034, Loss: 1.4580, Val: 0.7840, Test: 0.8030
Epoch: 035, Loss: 1.4258, Val: 0.7900, Test: 0.8010
Epoch: 036, Loss: 1.4077, Val: 0.7860, Test: 0.7990
Epoch: 037, Loss: 1.3707, Val: 0.7860, Test: 0.7970
Epoch: 038, Loss: 1.3294, Val: 0.7860, Test: 0.7960
Epoch: 039, Loss: 1.3085, Val: 0.7860, Test: 0.7940
Epoch: 040, Loss: 1.3085, Val: 0.7900, Test: 0.7920
Epoch: 041, Loss: 1.2793, Val: 0.7940, Test: 0.8010
Epoch: 042, Loss: 1.2449, Val: 0.8000, Test: 0.8000
Epoch: 043, Loss: 1.2524, Val: 0.7980, Test: 0.8040
Epoch: 044, Loss: 1.2002, Val: 0.7940, Test: 0.8070
Epoch: 045, Loss: 1.1785, Val: 0.7960, Test: 0.8080
Epoch: 046, Loss: 1.1478, Val: 0.7900, Test: 0.8070
Epoch: 047, Loss: 1.1369, Val: 0.7880, Test: 0.8080
Epoch: 048, Loss: 1.1039, Val: 0.7900, Test: 0.7990
Epoch: 049, Loss: 1.1030, Val: 0.7920, Test: 0.7950
Epoch: 050, Loss: 1.0686, Val: 0.7900, Test: 0.7940
Epoch: 051, Loss: 1.0383, Val: 0.7920, Test: 0.7940
Epoch: 052, Loss: 1.0288, Val: 0.7900, Test: 0.7950
Epoch: 053, Loss: 0.9649, Val: 0.7940, Test: 0.7940
Epoch: 054, Loss: 0.9623, Val: 0.7920, Test: 0.7980
Epoch: 055, Loss: 0.9680, Val: 0.7880, Test: 0.7970
Epoch: 056, Loss: 0.9343, Val: 0.7860, Test: 0.8000
Epoch: 057, Loss: 0.8953, Val: 0.7920, Test: 0.8050
Epoch: 058, Loss: 0.8795, Val: 0.7960, Test: 0.8090
Epoch: 059, Loss: 0.8758, Val: 0.7940, Test: 0.8070
Epoch: 060, Loss: 0.8783, Val: 0.7980, Test: 0.8100
Epoch: 061, Loss: 0.8738, Val: 0.8040, Test: 0.8090
Epoch: 062, Loss: 0.8105, Val: 0.8020, Test: 0.8090
Epoch: 063, Loss: 0.8277, Val: 0.8040, Test: 0.8040
Epoch: 064, Loss: 0.7839, Val: 0.8060, Test: 0.8030
Epoch: 065, Loss: 0.7221, Val: 0.8060, Test: 0.7990
Epoch: 066, Loss: 0.7639, Val: 0.7980, Test: 0.7980
Epoch: 067, Loss: 0.6967, Val: 0.7960, Test: 0.7910
Epoch: 068, Loss: 0.6948, Val: 0.7960, Test: 0.7880
Epoch: 069, Loss: 0.6691, Val: 0.7940, Test: 0.7900
Epoch: 070, Loss: 0.7316, Val: 0.8000, Test: 0.7920
Epoch: 071, Loss: 0.6764, Val: 0.8000, Test: 0.7960
Epoch: 072, Loss: 0.6630, Val: 0.8040, Test: 0.8020
Epoch: 073, Loss: 0.6540, Val: 0.8080, Test: 0.8130
Epoch: 074, Loss: 0.6433, Val: 0.8120, Test: 0.8120
Epoch: 075, Loss: 0.6229, Val: 0.8120, Test: 0.8070
Epoch: 076, Loss: 0.6312, Val: 0.8060, Test: 0.8040
Epoch: 077, Loss: 0.6430, Val: 0.8080, Test: 0.8020
Epoch: 078, Loss: 0.5752, Val: 0.8000, Test: 0.7990
Epoch: 079, Loss: 0.6435, Val: 0.7980, Test: 0.7980
Epoch: 080, Loss: 0.5409, Val: 0.7980, Test: 0.7950
Epoch: 081, Loss: 0.5735, Val: 0.8000, Test: 0.7950
Epoch: 082, Loss: 0.5712, Val: 0.7960, Test: 0.7910
Epoch: 083, Loss: 0.5645, Val: 0.7940, Test: 0.7870
Epoch: 084, Loss: 0.5142, Val: 0.7980, Test: 0.7840
Epoch: 085, Loss: 0.5235, Val: 0.8040, Test: 0.7940
Epoch: 086, Loss: 0.5485, Val: 0.8020, Test: 0.8010
Epoch: 087, Loss: 0.5068, Val: 0.8020, Test: 0.8000
Epoch: 088, Loss: 0.5075, Val: 0.8020, Test: 0.8030
Epoch: 089, Loss: 0.4967, Val: 0.8040, Test: 0.8020
Epoch: 090, Loss: 0.4808, Val: 0.8040, Test: 0.8050
Epoch: 091, Loss: 0.4404, Val: 0.8000, Test: 0.8060
Epoch: 092, Loss: 0.4921, Val: 0.7980, Test: 0.8050
Epoch: 093, Loss: 0.5197, Val: 0.7960, Test: 0.8030
Epoch: 094, Loss: 0.4774, Val: 0.8000, Test: 0.8090
Epoch: 095, Loss: 0.4975, Val: 0.7980, Test: 0.8080
Epoch: 096, Loss: 0.4504, Val: 0.7980, Test: 0.8090
Epoch: 097, Loss: 0.4226, Val: 0.7960, Test: 0.8080
Epoch: 098, Loss: 0.4008, Val: 0.7960, Test: 0.8060
Epoch: 099, Loss: 0.4363, Val: 0.7960, Test: 0.8050
Epoch: 100, Loss: 0.4269, Val: 0.8000, Test: 0.8040
Epoch: 101, Loss: 0.4105, Val: 0.7960, Test: 0.8040
Epoch: 102, Loss: 0.4178, Val: 0.7980, Test: 0.8060
Epoch: 103, Loss: 0.3883, Val: 0.8040, Test: 0.8070
Epoch: 104, Loss: 0.4228, Val: 0.8060, Test: 0.8100
Epoch: 105, Loss: 0.4246, Val: 0.7980, Test: 0.8060
Epoch: 106, Loss: 0.4173, Val: 0.8020, Test: 0.8050
Epoch: 107, Loss: 0.4335, Val: 0.8040, Test: 0.8070
Epoch: 108, Loss: 0.3500, Val: 0.8080, Test: 0.8030
Epoch: 109, Loss: 0.4456, Val: 0.8020, Test: 0.8010
Epoch: 110, Loss: 0.3970, Val: 0.8000, Test: 0.8020
Epoch: 111, Loss: 0.4349, Val: 0.7960, Test: 0.8030
Epoch: 112, Loss: 0.3850, Val: 0.7920, Test: 0.8030
Epoch: 113, Loss: 0.3786, Val: 0.7880, Test: 0.7970
Epoch: 114, Loss: 0.3994, Val: 0.7880, Test: 0.7980
Epoch: 115, Loss: 0.3965, Val: 0.7880, Test: 0.8040
Epoch: 116, Loss: 0.3805, Val: 0.7920, Test: 0.8040
Epoch: 117, Loss: 0.3554, Val: 0.7920, Test: 0.8080
Epoch: 118, Loss: 0.3712, Val: 0.7940, Test: 0.8020
Epoch: 119, Loss: 0.3310, Val: 0.8060, Test: 0.8050
Epoch: 120, Loss: 0.3296, Val: 0.8000, Test: 0.8010
Epoch: 121, Loss: 0.3474, Val: 0.8080, Test: 0.8020
Epoch: 122, Loss: 0.3399, Val: 0.8040, Test: 0.8010
Epoch: 123, Loss: 0.3353, Val: 0.8060, Test: 0.8010
Epoch: 124, Loss: 0.3393, Val: 0.8040, Test: 0.8050
Epoch: 125, Loss: 0.3357, Val: 0.7980, Test: 0.8070
Epoch: 126, Loss: 0.4175, Val: 0.7980, Test: 0.8060
Epoch: 127, Loss: 0.3438, Val: 0.7960, Test: 0.8080
Epoch: 128, Loss: 0.3358, Val: 0.7900, Test: 0.8050
Epoch: 129, Loss: 0.3292, Val: 0.7980, Test: 0.8060
Epoch: 130, Loss: 0.3711, Val: 0.7940, Test: 0.8050
Epoch: 131, Loss: 0.3156, Val: 0.7920, Test: 0.8100
Epoch: 132, Loss: 0.3652, Val: 0.7940, Test: 0.8160
Epoch: 133, Loss: 0.3209, Val: 0.8020, Test: 0.8210
Epoch: 134, Loss: 0.3396, Val: 0.8060, Test: 0.8170
Epoch: 135, Loss: 0.3317, Val: 0.8120, Test: 0.8170
Epoch: 136, Loss: 0.3441, Val: 0.8100, Test: 0.8180
Epoch: 137, Loss: 0.3631, Val: 0.8000, Test: 0.8150
Epoch: 138, Loss: 0.3062, Val: 0.7960, Test: 0.8150
Epoch: 139, Loss: 0.3003, Val: 0.7960, Test: 0.8170
Epoch: 140, Loss: 0.3097, Val: 0.8080, Test: 0.8150
Epoch: 141, Loss: 0.3198, Val: 0.8040, Test: 0.8110
Epoch: 142, Loss: 0.2911, Val: 0.8060, Test: 0.8070
Epoch: 143, Loss: 0.2895, Val: 0.7920, Test: 0.8030
Epoch: 144, Loss: 0.3078, Val: 0.7840, Test: 0.8000
Epoch: 145, Loss: 0.2731, Val: 0.7900, Test: 0.8010
Epoch: 146, Loss: 0.3194, Val: 0.7940, Test: 0.8050
Epoch: 147, Loss: 0.2963, Val: 0.7960, Test: 0.8100
Epoch: 148, Loss: 0.3204, Val: 0.7940, Test: 0.8090
Epoch: 149, Loss: 0.3559, Val: 0.7960, Test: 0.8070
Epoch: 150, Loss: 0.2554, Val: 0.8000, Test: 0.8140
Epoch: 151, Loss: 0.3017, Val: 0.7980, Test: 0.8180
Epoch: 152, Loss: 0.3079, Val: 0.7940, Test: 0.8170
Epoch: 153, Loss: 0.3062, Val: 0.7980, Test: 0.8160
Epoch: 154, Loss: 0.3042, Val: 0.8040, Test: 0.8190
Epoch: 155, Loss: 0.3038, Val: 0.8020, Test: 0.8190
Epoch: 156, Loss: 0.2866, Val: 0.8040, Test: 0.8240
Epoch: 157, Loss: 0.2984, Val: 0.8100, Test: 0.8200
Epoch: 158, Loss: 0.2709, Val: 0.8080, Test: 0.8180
Epoch: 159, Loss: 0.2997, Val: 0.8000, Test: 0.8130
Epoch: 160, Loss: 0.2584, Val: 0.7980, Test: 0.8050
Epoch: 161, Loss: 0.2640, Val: 0.7900, Test: 0.8050
Epoch: 162, Loss: 0.2844, Val: 0.7920, Test: 0.8050
Epoch: 163, Loss: 0.2547, Val: 0.7900, Test: 0.8030
Epoch: 164, Loss: 0.2942, Val: 0.7920, Test: 0.8080
Epoch: 165, Loss: 0.2942, Val: 0.8000, Test: 0.8060
Epoch: 166, Loss: 0.3095, Val: 0.8000, Test: 0.8120
Epoch: 167, Loss: 0.2733, Val: 0.8040, Test: 0.8130
Epoch: 168, Loss: 0.2665, Val: 0.8040, Test: 0.8170
Epoch: 169, Loss: 0.2579, Val: 0.8040, Test: 0.8180
Epoch: 170, Loss: 0.2518, Val: 0.8080, Test: 0.8200
Epoch: 171, Loss: 0.2893, Val: 0.8080, Test: 0.8140
Epoch: 172, Loss: 0.2103, Val: 0.8060, Test: 0.8150
Epoch: 173, Loss: 0.2444, Val: 0.7980, Test: 0.8130
Epoch: 174, Loss: 0.2643, Val: 0.7960, Test: 0.8100
Epoch: 175, Loss: 0.2651, Val: 0.7980, Test: 0.8100
Epoch: 176, Loss: 0.2647, Val: 0.7980, Test: 0.8120
Epoch: 177, Loss: 0.2929, Val: 0.7980, Test: 0.8120
Epoch: 178, Loss: 0.2585, Val: 0.7880, Test: 0.8070
Epoch: 179, Loss: 0.2346, Val: 0.7880, Test: 0.8070
Epoch: 180, Loss: 0.2326, Val: 0.7900, Test: 0.8070
Epoch: 181, Loss: 0.2451, Val: 0.7820, Test: 0.8050
Epoch: 182, Loss: 0.2817, Val: 0.7820, Test: 0.8080
Epoch: 183, Loss: 0.2990, Val: 0.7840, Test: 0.8100
Epoch: 184, Loss: 0.2523, Val: 0.7920, Test: 0.8150
Epoch: 185, Loss: 0.2750, Val: 0.7880, Test: 0.8100
Epoch: 186, Loss: 0.2590, Val: 0.7860, Test: 0.8060
Epoch: 187, Loss: 0.2113, Val: 0.7980, Test: 0.8080
Epoch: 188, Loss: 0.2691, Val: 0.7980, Test: 0.8040
Epoch: 189, Loss: 0.2314, Val: 0.8020, Test: 0.8110
Epoch: 190, Loss: 0.2723, Val: 0.8060, Test: 0.8120
Epoch: 191, Loss: 0.2595, Val: 0.7980, Test: 0.8150
Epoch: 192, Loss: 0.2368, Val: 0.7980, Test: 0.8160
Epoch: 193, Loss: 0.2448, Val: 0.8060, Test: 0.8170
Epoch: 194, Loss: 0.2593, Val: 0.8160, Test: 0.8220
Epoch: 195, Loss: 0.2828, Val: 0.8100, Test: 0.8250
Epoch: 196, Loss: 0.2723, Val: 0.8020, Test: 0.8230
Epoch: 197, Loss: 0.2406, Val: 0.7980, Test: 0.8190
Epoch: 198, Loss: 0.2319, Val: 0.7760, Test: 0.8140
Epoch: 199, Loss: 0.2186, Val: 0.7680, Test: 0.8020
Epoch: 200, Loss: 0.2683, Val: 0.7680, Test: 0.8020
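`GATv2Conv` differs from the original GAT layer in where the nonlinearity sits in the attention score. This change is described in Brody et al., "How Attentive are Graph Attention Networks?".

* GAT scores an edge as LeakyReLU(aᵀ[W h_i ‖ W h_j]). This score splits into a part that depends only on i and a part that depends only on j. Because LeakyReLU is monotonic, every node therefore ranks a shared set of neighbours in the same order.
* GATv2 computes aᵀ LeakyReLU(W[h_i ‖ h_j]), which removes that restriction.

The pure-Python sketch below compares the two scoring functions on a two-node toy graph. It assumes a single head and no bias. It also applies W directly to the concatenated features, and the weights are hand-picked, not learned. The self-loops mirror the ones `GATv2Conv` adds by default.

```python
import math
from functools import partial
from typing import Callable, Dict, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def leaky_relu(x: float, slope: float = 0.2) -> float:
    return x if x > 0 else slope * x


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def gat_score(h_i: Vector, h_j: Vector, W: Matrix, a: Vector) -> float:
    # GAT: e_ij = LeakyReLU(a^T [W h_i || W h_j]); the nonlinearity is applied last.
    return leaky_relu(dot(a, matvec(W, h_i) + matvec(W, h_j)))


def gatv2_score(h_i: Vector, h_j: Vector, W: Matrix, a: Vector) -> float:
    # GATv2: e_ij = a^T LeakyReLU(W [h_i || h_j]); the nonlinearity comes before a.
    return dot(a, [leaky_relu(z) for z in matvec(W, h_i + h_j)])


def attention(i: int, nbrs: List[int], h: Dict[int, Vector],
              score: Callable[[Vector, Vector], float]) -> Dict[int, float]:
    """Softmax-normalise node i's scores over its neighbours."""
    e = [score(h[i], h[j]) for j in nbrs]
    m = max(e)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in e]
    total = sum(exps)
    return {j: x / total for j, x in zip(nbrs, exps)}


# Toy graph: two connected nodes with 1-d features, self-loops included.
h = {0: [1.0], 1: [-1.0]}
nbrs = {0: [0, 1], 1: [0, 1]}

# Hand-picked, illustrative weights (not learned).
gat = partial(gat_score, W=[[1.0]], a=[1.0, 1.0])
gatv2 = partial(gatv2_score, W=[[1.0, 1.0], [-1.0, -1.0]], a=[1.0, 1.0])

for name, fn in [("GAT", gat), ("GATv2", gatv2)]:
    for i in h:
        alpha = attention(i, nbrs[i], h, fn)
        print(name, "node", i, "attends most to node", max(alpha, key=alpha.get))
```

With these weights, the GAT scores make both nodes attend most strongly to node 0. The GATv2 scores let node 1 prefer itself. This is the query-dependent ranking that the reordered nonlinearity allows.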
[14]:
test_acc = test(data.test_mask)
print(f'Test Accuracy: {test_acc:.4f}')
Test Accuracy: 0.8020
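`test(mask)` computes accuracy in three steps:

1. It turns each node's logits into a predicted class with `argmax`.
2. It keeps only the nodes selected by the boolean mask.
3. It divides the number of correct predictions by the number of masked nodes.

The sketch below performs the same computation on plain Python lists. The function name `masked_accuracy` and the toy logits and labels are made up for illustration.

```python
from typing import List


def masked_accuracy(logits: List[List[float]], labels: List[int], mask: List[bool]) -> float:
    """Accuracy over the nodes where mask is True, mirroring test(mask)."""
    # Predicted class per node: index of the largest logit.
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    selected = [i for i, keep in enumerate(mask) if keep]
    if not selected:
        raise ValueError("mask selects no nodes")
    correct = sum(preds[i] == labels[i] for i in selected)
    return correct / len(selected)


logits = [[2.0, 0.1, -1.0],   # node 0 -> class 0
          [0.3, 1.5, 0.2],    # node 1 -> class 1
          [0.0, 0.2, 0.9],    # node 2 -> class 2
          [1.0, 3.0, 0.5]]    # node 3 -> class 1
labels = [0, 1, 0, 1]
mask = [True, False, True, True]  # evaluate nodes 0, 2 and 3 only
print(masked_accuracy(logits, labels, mask))  # → 0.6666666666666666
```

Node 1 is predicted correctly but does not count, because the mask excludes it. This is how the same `test` function yields separate validation and test accuracies from a single forward pass over the whole graph.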