8. Node Classification with W&B
enable_wandb = True
if enable_wandb:
    !pip install -qqq wandb
    import wandb
    wandb.login()
# Install required packages.
import os
import pdb
import torch
import pandas
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# Helper function for visualization.
%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# Project node embeddings to 2D with t-SNE and plot them, colored by class.
def visualize(h, color):
    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
    plt.figure(figsize=(10,10))
    plt.xticks([])
    plt.yticks([])
    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()
# Log node embeddings to W&B as a table: the class label ("target") first, then one column per embedding dimension.
def embedding_to_wandb(h, color, key="embedding"):
    num_components = h.shape[-1]
    df = pandas.DataFrame(data=h.detach().cpu().numpy(),
                          columns=[f"c_{i}" for i in range(num_components)])
    df["target"] = color.detach().cpu().numpy().astype("str")
    cols = df.columns.tolist()
    df = df[cols[-1:] + cols[:-1]]
    wandb.log({key: df})
1.11.0
Node Classification with Graph Neural Networks
Previous: Introduction: Hands-on Graph Neural Networks
This tutorial will teach you how to apply Graph Neural Networks (GNNs) to the task of node classification. Here, we are given the ground-truth labels of only a small subset of nodes, and want to infer the labels for all the remaining nodes (transductive learning).
To demonstrate, we make use of the Cora dataset, which is a citation network where nodes represent documents. Each node is described by a 1433-dimensional bag-of-words feature vector. Two documents are connected if there exists a citation link between them. The task is to infer the category of each document (7 in total).
This dataset was first introduced by Yang et al. (2016) as one of the datasets of the Planetoid benchmark suite. We can again make use of PyTorch Geometric for easy access to this dataset via `torch_geometric.datasets.Planetoid <https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.Planetoid>`__:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
dataset = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())
print()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset[0] # Get the first graph object.
print()
print(data)
print('===========================================================================================================')
# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Dataset: Cora():
======================
Number of graphs: 1
Number of features: 1433
Number of classes: 7
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
===========================================================================================================
Number of nodes: 2708
Number of edges: 10556
Average node degree: 3.90
Number of training nodes: 140
Training node label rate: 0.05
Has isolated nodes: False
Has self-loops: False
Is undirected: True
Processing...
Done!
Overall, this dataset is quite similar to the previously used `KarateClub <https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.KarateClub>`__ network. We can see that the Cora network holds 2,708 nodes and 10,556 edges (each undirected citation link is stored in both directions, so this corresponds to 5,278 links), resulting in an average node degree of 3.9. For training, we are given the ground-truth categories of 140 nodes (20 for each class). This results in a training node label rate of only 5%.
In contrast to KarateClub, this graph holds the additional attributes val_mask and test_mask, which denote which nodes should be used for validation and testing. Furthermore, we make use of data transformations via ``transform=NormalizeFeatures()``. Transforms can be used to modify your input data before it is fed into a neural network, e.g., for normalization or data augmentation. Here, we row-normalize the bag-of-words input feature vectors.
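To make the effect of row normalization concrete, the NumPy sketch below divides each row of a toy bag-of-words matrix by its sum, so that every non-empty row sums to 1. It only approximates the idea behind NormalizeFeatures: the helper name row_normalize is hypothetical, and PyG's exact handling of edge cases (such as all-zero rows) may differ between versions.

```python
import numpy as np

def row_normalize(x: np.ndarray) -> np.ndarray:
    # Sum of each row (one row per node / document).
    row_sums = x.sum(axis=1, keepdims=True)
    # Avoid division by zero: all-zero rows are divided by 1 and stay all-zero.
    safe_sums = np.where(row_sums == 0, 1.0, row_sums)
    return x / safe_sums

# Toy binary bag-of-words matrix: 3 documents, 4 vocabulary words.
x = np.array([[1, 0, 1, 0],
              [1, 1, 1, 1],
              [0, 0, 0, 0]], dtype=float)
x_norm = row_normalize(x)
print(x_norm.sum(axis=1))  # [1. 1. 0.]
```

After normalization, a document's features describe the relative frequency of its words rather than its raw word count, so long and short documents become comparable.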
We can further see that this network is undirected, and that there are no isolated nodes (each document has at least one citation link).
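The properties reported above can be checked directly on an edge index of shape [2, num_edges]. Because an undirected graph stores every link in both directions, 10,556 / 2,708 ≈ 3.90 is exactly the average node degree. The helper graph_stats below is a hypothetical NumPy sketch on a toy graph, not PyG's own implementation of is_undirected() or has_isolated_nodes().

```python
import numpy as np

def graph_stats(edge_index: np.ndarray, num_nodes: int):
    """Return (is_undirected, has_isolated_nodes, average_degree) for a COO edge index of shape [2, E]."""
    src, dst = edge_index
    edges = set(zip(src.tolist(), dst.tolist()))
    # Undirected: every stored edge (u, v) must also appear as (v, u).
    is_undirected = all((v, u) in edges for (u, v) in edges)
    # A node is isolated if it never appears as a source or a target.
    touched = np.zeros(num_nodes, dtype=bool)
    touched[src] = True
    touched[dst] = True
    has_isolated = not touched.all()
    # Each undirected link is stored twice, so E / N equals the average degree.
    avg_degree = edge_index.shape[1] / num_nodes
    return is_undirected, has_isolated, avg_degree

# Toy path graph 0 - 1 - 2 stored in both directions, plus an isolated node 3.
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
print(graph_stats(edge_index, num_nodes=4))  # (True, True, 1.0)
```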
Training a Multi-layer Perceptron Network (MLP)
In theory, we should be able to infer the category of a document solely based on its content, i.e. its bag-of-words feature representation, without taking any relational information into account.
Let’s verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes):
import torch
from torch.nn import Linear
import torch.nn.functional as F
class MLP(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(12345)
        self.lin1 = Linear(dataset.num_features, hidden_channels)
        self.lin2 = Linear(hidden_channels, dataset.num_classes)
    def forward(self, x):
        x = self.lin1(x)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x
model = MLP(hidden_channels=16)
print(model)
MLP(
(lin1): Linear(in_features=1433, out_features=16, bias=True)
(lin2): Linear(in_features=16, out_features=7, bias=True)
)
(Optional) Log the dataset attributes to the W&B run summary.
if enable_wandb:
    wandb.init(project='node-classification')
    summary = dict()
    summary["data"] = dict()
    summary["data"]["num_features"] = dataset.num_features
    summary["data"]["num_classes"] = dataset.num_classes
    summary["data"]["num_nodes"] = data.num_nodes
    summary["data"]["num_edges"] = data.num_edges
    summary["data"]["has_isolated_nodes"] = data.has_isolated_nodes()
    summary["data"]["has_self_loops"] = data.has_self_loops()
    summary["data"]["is_undirected"] = data.is_undirected()
    summary["data"]["num_training_nodes"] = int(data.train_mask.sum())
    # Update the run summary in place; rebinding `wandb.summary` would not reach the run.
    wandb.summary.update(summary)
Our MLP is defined by two linear layers and enhanced by a ReLU non-linearity and dropout. Here, we first reduce the 1433-dimensional feature vector to a low-dimensional embedding (hidden_channels=16), while the second linear layer acts as a classifier that should map each low-dimensional node embedding to one of the 7 classes.
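As a quick sanity check of this architecture, the number of trainable parameters can be worked out by hand from the printed layer sizes: each Linear layer holds an [out_features, in_features] weight matrix plus one bias per output. The sketch below (with a hypothetical helper linear_params) shows that nearly all parameters sit in the first layer; the total should match sum(p.numel() for p in model.parameters()) for the MLP above.

```python
def linear_params(in_features: int, out_features: int) -> int:
    # Weight matrix [out_features, in_features] plus one bias per output feature.
    return in_features * out_features + out_features

num_features, hidden_channels, num_classes = 1433, 16, 7
lin1 = linear_params(num_features, hidden_channels)  # 1433 -> 16
lin2 = linear_params(hidden_channels, num_classes)   # 16 -> 7
print(lin1, lin2, lin1 + lin2)  # 22944 119 23063
```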
Let’s train our simple MLP by following a similar procedure as described in the first part of this tutorial. We again make use of the cross-entropy loss and the Adam optimizer. This time, we also define a ``test`` function to evaluate how well our final model performs on the test node set (whose labels have not been observed during training).
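Both the loss and the accuracy are restricted to a subset of nodes via boolean masks. The NumPy sketch below mirrors the logic of the test function on made-up logits and labels (masked_accuracy is a hypothetical name): take the arg-max class per node, then compare predictions and labels only where the mask is True.

```python
import numpy as np

def masked_accuracy(logits: np.ndarray, y: np.ndarray, mask: np.ndarray) -> float:
    pred = logits.argmax(axis=1)     # Predicted class per node.
    correct = pred[mask] == y[mask]  # Compare only the nodes selected by the mask.
    return int(correct.sum()) / int(mask.sum())

# Toy example: 4 nodes, 3 classes; node 0 is a "training" node and is excluded.
logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 1.5, 0.1],
                   [0.9, 0.8, 0.1],
                   [0.1, 0.2, 3.0]])
y = np.array([0, 1, 1, 2])
test_mask = np.array([False, True, True, True])
print(masked_accuracy(logits, y, test_mask))  # 2 of the 3 masked nodes are correct
```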
We also visualize the embeddings of the untrained model so that we can visually compare them with the embeddings obtained after training.
NOTE: In W&B mode, please set up the embedding projector from the settings panel of the logged table. More information can be found here: https://docs.wandb.ai/ref/app/features/panels/weave/embedding-projector
# from IPython.display import Javascript # Restrict height of output cell.
# display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
model = MLP(hidden_channels=16)
with torch.no_grad():
    out = model(data.x)
if enable_wandb:
    embedding_to_wandb(out, color=data.y, key="mlp/embedding/init")
else:
    visualize(out, data.y)
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)  # Define optimizer.
def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss
def test():
    model.eval()
    out = model(data.x)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
    return test_acc
for epoch in range(1, 201):
    loss = train()
    if enable_wandb:
        wandb.log({"mlp/loss": loss})
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
Epoch: 001, Loss: 1.9617
Epoch: 002, Loss: 1.9562
Epoch: 003, Loss: 1.9499
Epoch: 004, Loss: 1.9401
Epoch: 005, Loss: 1.9357
Epoch: 006, Loss: 1.9253
Epoch: 007, Loss: 1.9144
Epoch: 008, Loss: 1.9104
Epoch: 009, Loss: 1.8983
Epoch: 010, Loss: 1.8885
Epoch: 011, Loss: 1.8817
Epoch: 012, Loss: 1.8627
Epoch: 013, Loss: 1.8596
Epoch: 014, Loss: 1.8287
Epoch: 015, Loss: 1.8293
Epoch: 016, Loss: 1.8043
Epoch: 017, Loss: 1.8166
Epoch: 018, Loss: 1.7803
Epoch: 019, Loss: 1.7525
Epoch: 020, Loss: 1.7318
Epoch: 021, Loss: 1.7636
Epoch: 022, Loss: 1.7090
Epoch: 023, Loss: 1.7026
Epoch: 024, Loss: 1.6821
Epoch: 025, Loss: 1.6637
Epoch: 026, Loss: 1.6404
Epoch: 027, Loss: 1.6124
Epoch: 028, Loss: 1.5674
Epoch: 029, Loss: 1.5422
Epoch: 030, Loss: 1.5519
Epoch: 031, Loss: 1.5399
Epoch: 032, Loss: 1.4822
Epoch: 033, Loss: 1.4956
Epoch: 034, Loss: 1.4479
Epoch: 035, Loss: 1.4063
Epoch: 036, Loss: 1.4378
Epoch: 037, Loss: 1.3867
Epoch: 038, Loss: 1.3802
Epoch: 039, Loss: 1.3393
Epoch: 040, Loss: 1.2952
Epoch: 041, Loss: 1.2643
Epoch: 042, Loss: 1.2424
Epoch: 043, Loss: 1.2596
Epoch: 044, Loss: 1.1963
Epoch: 045, Loss: 1.2004
Epoch: 046, Loss: 1.1599
Epoch: 047, Loss: 1.1141
Epoch: 048, Loss: 1.1428
Epoch: 049, Loss: 1.0960
Epoch: 050, Loss: 1.0473
Epoch: 051, Loss: 1.0538
Epoch: 052, Loss: 1.0431
Epoch: 053, Loss: 1.0650
Epoch: 054, Loss: 0.9798
Epoch: 055, Loss: 0.9571
Epoch: 056, Loss: 0.8945
Epoch: 057, Loss: 0.9230
Epoch: 058, Loss: 0.9363
Epoch: 059, Loss: 0.9066
Epoch: 060, Loss: 0.8500
Epoch: 061, Loss: 0.8906
Epoch: 062, Loss: 0.8880
Epoch: 063, Loss: 0.8962
Epoch: 064, Loss: 0.8728
Epoch: 065, Loss: 0.8039
Epoch: 066, Loss: 0.8037
Epoch: 067, Loss: 0.7401
Epoch: 068, Loss: 0.7795
Epoch: 069, Loss: 0.7854
Epoch: 070, Loss: 0.7001
Epoch: 071, Loss: 0.7158
Epoch: 072, Loss: 0.8123
Epoch: 073, Loss: 0.7777
Epoch: 074, Loss: 0.7148
Epoch: 075, Loss: 0.7087
Epoch: 076, Loss: 0.7172
Epoch: 077, Loss: 0.7077
Epoch: 078, Loss: 0.6799
Epoch: 079, Loss: 0.6721
Epoch: 080, Loss: 0.6176
Epoch: 081, Loss: 0.7527
Epoch: 082, Loss: 0.7109
Epoch: 083, Loss: 0.6338
Epoch: 084, Loss: 0.6398
Epoch: 085, Loss: 0.5915
Epoch: 086, Loss: 0.5672
Epoch: 087, Loss: 0.5731
Epoch: 088, Loss: 0.5745
Epoch: 089, Loss: 0.6553
Epoch: 090, Loss: 0.5983
Epoch: 091, Loss: 0.6254
Epoch: 092, Loss: 0.6545
Epoch: 093, Loss: 0.6582
Epoch: 094, Loss: 0.5903
Epoch: 095, Loss: 0.5522
Epoch: 096, Loss: 0.5503
Epoch: 097, Loss: 0.6268
Epoch: 098, Loss: 0.5739
Epoch: 099, Loss: 0.5782
Epoch: 100, Loss: 0.5927
Epoch: 101, Loss: 0.5283
Epoch: 102, Loss: 0.5436
Epoch: 103, Loss: 0.6349
Epoch: 104, Loss: 0.5427
Epoch: 105, Loss: 0.5791
Epoch: 106, Loss: 0.5296
Epoch: 107, Loss: 0.6114
Epoch: 108, Loss: 0.5577
Epoch: 109, Loss: 0.5630
Epoch: 110, Loss: 0.5164
Epoch: 111, Loss: 0.5256
Epoch: 112, Loss: 0.5677
Epoch: 113, Loss: 0.5432
Epoch: 114, Loss: 0.4872
Epoch: 115, Loss: 0.5720
Epoch: 116, Loss: 0.4607
Epoch: 117, Loss: 0.4998
Epoch: 118, Loss: 0.4805
Epoch: 119, Loss: 0.5122
Epoch: 120, Loss: 0.5538
Epoch: 121, Loss: 0.4706
Epoch: 122, Loss: 0.5497
Epoch: 123, Loss: 0.5311
Epoch: 124, Loss: 0.5336
Epoch: 125, Loss: 0.5322
Epoch: 126, Loss: 0.5197
Epoch: 127, Loss: 0.4824
Epoch: 128, Loss: 0.4711
Epoch: 129, Loss: 0.4659
Epoch: 130, Loss: 0.5088
Epoch: 131, Loss: 0.4876
Epoch: 132, Loss: 0.4886
Epoch: 133, Loss: 0.5317
Epoch: 134, Loss: 0.4520
Epoch: 135, Loss: 0.4595
Epoch: 136, Loss: 0.4849
Epoch: 137, Loss: 0.5109
Epoch: 138, Loss: 0.4608
Epoch: 139, Loss: 0.5431
Epoch: 140, Loss: 0.5599
Epoch: 141, Loss: 0.5580
Epoch: 142, Loss: 0.4482
Epoch: 143, Loss: 0.4705
Epoch: 144, Loss: 0.5860
Epoch: 145, Loss: 0.5049
Epoch: 146, Loss: 0.4410
Epoch: 147, Loss: 0.4657
Epoch: 148, Loss: 0.5248
Epoch: 149, Loss: 0.4939
Epoch: 150, Loss: 0.4237
Epoch: 151, Loss: 0.4926
Epoch: 152, Loss: 0.4278
Epoch: 153, Loss: 0.4470
Epoch: 154, Loss: 0.4876
Epoch: 155, Loss: 0.4651
Epoch: 156, Loss: 0.4594
Epoch: 157, Loss: 0.4317
Epoch: 158, Loss: 0.4761
Epoch: 159, Loss: 0.3911
Epoch: 160, Loss: 0.4328
Epoch: 161, Loss: 0.4612
Epoch: 162, Loss: 0.4156
Epoch: 163, Loss: 0.3770
Epoch: 164, Loss: 0.4313
Epoch: 165, Loss: 0.4490
Epoch: 166, Loss: 0.4709
Epoch: 167, Loss: 0.4525
Epoch: 168, Loss: 0.4173
Epoch: 169, Loss: 0.4336
Epoch: 170, Loss: 0.4264
Epoch: 171, Loss: 0.4419
Epoch: 172, Loss: 0.3597
Epoch: 173, Loss: 0.3726
Epoch: 174, Loss: 0.4015
Epoch: 175, Loss: 0.4227
Epoch: 176, Loss: 0.4425
Epoch: 177, Loss: 0.4166
Epoch: 178, Loss: 0.3845
Epoch: 179, Loss: 0.4365
Epoch: 180, Loss: 0.4116
Epoch: 181, Loss: 0.3971
Epoch: 182, Loss: 0.3818
Epoch: 183, Loss: 0.4135
Epoch: 184, Loss: 0.3317
Epoch: 185, Loss: 0.4355
Epoch: 186, Loss: 0.4069
Epoch: 187, Loss: 0.4238
Epoch: 188, Loss: 0.4251
Epoch: 189, Loss: 0.4337
Epoch: 190, Loss: 0.3434
Epoch: 191, Loss: 0.3861
Epoch: 192, Loss: 0.4183
Epoch: 193, Loss: 0.3450
Epoch: 194, Loss: 0.3702
Epoch: 195, Loss: 0.4222
Epoch: 196, Loss: 0.4183
Epoch: 197, Loss: 0.3655
Epoch: 198, Loss: 0.4121
Epoch: 199, Loss: 0.4245
Epoch: 200, Loss: 0.3709
After training the model, we can call the test function to see how well our model performs on nodes whose labels were not seen during training. Here, we are interested in the accuracy of the model, i.e., the ratio of correctly classified nodes:
We also visualize the output embeddings. This gives us a visual hint of how well the model performs when compared with the embeddings of the graph-based models defined below.
test_acc = test()
out = model(data.x)
if enable_wandb:
    embedding_to_wandb(out, color=data.y, key="mlp/embedding/trained")
    wandb.summary["mlp/accuracy"] = test_acc
    wandb.log({"mlp/accuracy": test_acc})
else:
    visualize(out, data.y)
print(f'Test Accuracy: {test_acc:.4f}')
Test Accuracy: 0.5920
As one can see, our MLP performs rather poorly, with only about 59% test accuracy. But why does the MLP not perform better? The main reason is that this model suffers from heavy overfitting, since it only has access to a small number of training nodes, and therefore generalizes poorly to unseen node representations.
It also fails to incorporate an important inductive bias into the model: the categories of linked (citing or cited) papers are very likely related to the category of a document. That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model.
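One common way to quantify this bias is the edge homophily ratio, i.e., the fraction of edges whose two endpoints share the same class. The sketch below is a hypothetical NumPy helper evaluated on a toy graph; for Cora it could be called as edge_homophily(data.edge_index.numpy(), data.y.numpy()). A high ratio suggests that aggregating neighbor information should help classification.

```python
import numpy as np

def edge_homophily(edge_index: np.ndarray, y: np.ndarray) -> float:
    src, dst = edge_index
    # Fraction of stored (directed) edge entries whose endpoints have the same label.
    return float((y[src] == y[dst]).mean())

# Toy path graph 0 - 1 - 2 - 3, stored in both directions; classes {0, 1} vs {2, 3}.
edge_index = np.array([[0, 1, 1, 2, 2, 3],
                       [1, 0, 2, 1, 3, 2]])
y = np.array([0, 0, 1, 1])
print(edge_homophily(edge_index, y))  # 4 of 6 entries connect same-class nodes
```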
Training a Graph Neural Network (GNN)
We can easily convert our MLP to a GNN by swapping the torch.nn.Linear layers with PyG’s GNN operators.
Following up on the first part of this tutorial, we replace the linear layers with the `GCNConv <https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv>`__ module. To recap, the GCN layer (Kipf et al., 2017) is defined as
\[\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)}\]
where \(\mathbf{W}^{(\ell + 1)}\) denotes a trainable weight matrix of shape [num_output_features, num_input_features] and \(c_{w,v}\) refers to a fixed normalization coefficient for each edge. In contrast, a single Linear layer is defined as
\[\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)}\]
which does not make use of neighboring node information.
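To make the difference between the two update rules concrete, here is a dense NumPy sketch of one GCN layer on a toy 3-node path graph. It assumes the symmetric normalization \(c_{w,v} = \sqrt{\deg(w)} \cdot \sqrt{\deg(v)}\) with degrees counted after adding self-loops, which is GCNConv's default; PyG itself operates on a sparse edge_index rather than a dense adjacency matrix, and gcn_layer is an illustrative name. The final check compares the matrix form against the per-node sum for node 1.

```python
import numpy as np

def gcn_layer(x: np.ndarray, adj: np.ndarray, weight: np.ndarray) -> np.ndarray:
    n = adj.shape[0]
    adj_hat = adj + np.eye(n)          # Self-loops: v is part of its own neighborhood.
    deg = adj_hat.sum(axis=1)          # Degrees including the self-loop.
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Entry (v, w) equals 1 / c_{w,v} = 1 / (sqrt(deg(w)) * sqrt(deg(v))) for every edge.
    norm_adj = d_inv_sqrt[:, None] * adj_hat * d_inv_sqrt[None, :]
    # weight has shape [num_output_features, num_input_features], as in the text.
    return norm_adj @ x @ weight.T

rng = np.random.default_rng(0)
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)  # Path graph 0 - 1 - 2.
x = rng.normal(size=(3, 4))               # 3 nodes, 4 input features.
weight = rng.normal(size=(2, 4))          # 4 -> 2 output features.
out = gcn_layer(x, adj, weight)

# Per-node formula for v = 1, whose neighborhood including itself is {0, 1, 2}.
deg = adj.sum(axis=1) + 1
manual = sum(x[w] / np.sqrt(deg[w] * deg[1]) for w in (0, 1, 2)) @ weight.T
print(np.allclose(out[1], manual))  # True
```

Replacing norm_adj with the identity matrix, i.e. computing x @ weight.T, recovers the plain Linear layer, which updates each node from its own features only.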
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(1234567)
        self.conv1 = GCNConv(dataset.num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
model = GCN(hidden_channels=16)
print(model)
GCN(
(conv1): GCNConv(1433, 16)
(conv2): GCNConv(16, 7)
)
Let’s visualize the node embeddings of our untrained GCN network. For visualization, we make use of TSNE to embed our 7-dimensional node embeddings onto a 2D plane.
model = GCN(hidden_channels=16)
model.eval()
out = model(data.x, data.edge_index)
if enable_wandb:
    embedding_to_wandb(out, color=data.y, key="gcn/embedding/init")
else:
    visualize(out, data.y)
We certainly can do better by training our model. The training and testing procedure is once again the same, but this time we make use of the node features x and the graph connectivity edge_index as input to our GCN model.
# from IPython.display import Javascript # Restrict height of output cell.
# display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
model = GCN(hidden_channels=16)
if enable_wandb:
    wandb.watch(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x, data.edge_index)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss
def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
    return test_acc
for epoch in range(1, 101):
    loss = train()
    if enable_wandb:
        wandb.log({"gcn/loss": loss})
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
Epoch: 001, Loss: 1.9465
Epoch: 002, Loss: 1.9419
Epoch: 003, Loss: 1.9363
Epoch: 004, Loss: 1.9290
Epoch: 005, Loss: 1.9199
Epoch: 006, Loss: 1.9140
Epoch: 007, Loss: 1.9079
Epoch: 008, Loss: 1.8992
Epoch: 009, Loss: 1.8876
Epoch: 010, Loss: 1.8764
Epoch: 011, Loss: 1.8656
Epoch: 012, Loss: 1.8626
Epoch: 013, Loss: 1.8460
Epoch: 014, Loss: 1.8329
Epoch: 015, Loss: 1.8225
Epoch: 016, Loss: 1.8167
Epoch: 017, Loss: 1.7995
Epoch: 018, Loss: 1.7878
Epoch: 019, Loss: 1.7716
Epoch: 020, Loss: 1.7568
Epoch: 021, Loss: 1.7563
Epoch: 022, Loss: 1.7342
Epoch: 023, Loss: 1.7092
Epoch: 024, Loss: 1.7015
Epoch: 025, Loss: 1.6671
Epoch: 026, Loss: 1.6757
Epoch: 027, Loss: 1.6609
Epoch: 028, Loss: 1.6355
Epoch: 029, Loss: 1.6339
Epoch: 030, Loss: 1.6102
Epoch: 031, Loss: 1.5964
Epoch: 032, Loss: 1.5721
Epoch: 033, Loss: 1.5570
Epoch: 034, Loss: 1.5445
Epoch: 035, Loss: 1.5093
Epoch: 036, Loss: 1.4889
Epoch: 037, Loss: 1.4776
Epoch: 038, Loss: 1.4704
Epoch: 039, Loss: 1.4263
Epoch: 040, Loss: 1.3972
Epoch: 041, Loss: 1.3873
Epoch: 042, Loss: 1.3479
Epoch: 043, Loss: 1.3485
Epoch: 044, Loss: 1.3739
Epoch: 045, Loss: 1.3343
Epoch: 046, Loss: 1.3277
Epoch: 047, Loss: 1.2770
Epoch: 048, Loss: 1.2651
Epoch: 049, Loss: 1.2347
Epoch: 050, Loss: 1.2543
Epoch: 051, Loss: 1.1622
Epoch: 052, Loss: 1.1483
Epoch: 053, Loss: 1.1535
Epoch: 054, Loss: 1.1912
Epoch: 055, Loss: 1.0880
Epoch: 056, Loss: 1.1374
Epoch: 057, Loss: 1.0657
Epoch: 058, Loss: 1.0748
Epoch: 059, Loss: 1.0654
Epoch: 060, Loss: 1.0201
Epoch: 061, Loss: 0.9967
Epoch: 062, Loss: 1.0499
Epoch: 063, Loss: 1.0116
Epoch: 064, Loss: 0.9945
Epoch: 065, Loss: 0.9499
Epoch: 066, Loss: 0.9465
Epoch: 067, Loss: 0.9633
Epoch: 068, Loss: 0.9137
Epoch: 069, Loss: 0.9168
Epoch: 070, Loss: 0.8818
Epoch: 071, Loss: 0.8984
Epoch: 072, Loss: 0.8301
Epoch: 073, Loss: 0.8664
Epoch: 074, Loss: 0.8560
Epoch: 075, Loss: 0.8457
Epoch: 076, Loss: 0.8306
Epoch: 077, Loss: 0.8333
Epoch: 078, Loss: 0.8155
Epoch: 079, Loss: 0.7878
Epoch: 080, Loss: 0.8277
Epoch: 081, Loss: 0.7880
Epoch: 082, Loss: 0.7829
Epoch: 083, Loss: 0.7829
Epoch: 084, Loss: 0.7633
Epoch: 085, Loss: 0.7862
Epoch: 086, Loss: 0.7497
Epoch: 087, Loss: 0.7760
Epoch: 088, Loss: 0.7419
Epoch: 089, Loss: 0.6595
Epoch: 090, Loss: 0.6746
Epoch: 091, Loss: 0.7432
Epoch: 092, Loss: 0.6609
Epoch: 093, Loss: 0.6607
Epoch: 094, Loss: 0.6884
Epoch: 095, Loss: 0.6596
Epoch: 096, Loss: 0.6456
Epoch: 097, Loss: 0.6383
Epoch: 098, Loss: 0.7031
Epoch: 099, Loss: 0.6437
Epoch: 100, Loss: 0.6375
After training the model, we can check its test accuracy:
test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')
Test Accuracy: 0.8110
There it is! By simply swapping the linear layers with GNN layers, we reach 81.1% test accuracy! This is in stark contrast to the 59% test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance.
We can also verify that once again by looking at the output embeddings of our trained model, which now produces a far better clustering of nodes of the same category.
model.eval()
out = model(data.x, data.edge_index)
if enable_wandb:
    wandb.summary["gcn/accuracy"] = test_acc
    wandb.log({"gcn/accuracy": test_acc})
    embedding_to_wandb(out, color=data.y, key="gcn/embedding/trained")
    wandb.finish()
else:
    visualize(out, data.y)
Using W&B Sweeps
In this section, we’ll look into how we can use W&B Sweeps to perform a hyper-parameter search for the GCN. For this to work, it is essential for wandb to be enabled, i.e., enable_wandb should be set to True.
assert enable_wandb, "W&B not enabled. Please enable W&B and restart the notebook."
import tqdm
def agent_fn():
    wandb.init()
    model = GCN(hidden_channels=wandb.config.hidden_channels)
    wandb.watch(model)
    model.eval()  # Disable dropout for the snapshot of the untrained embeddings.
    with torch.no_grad():
        out = model(data.x, data.edge_index)
    embedding_to_wandb(out, color=data.y, key="gcn/embedding/init")
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=wandb.config.lr,
                                 weight_decay=wandb.config.weight_decay)
    criterion = torch.nn.CrossEntropyLoss()
    def train():
        model.train()
        optimizer.zero_grad()  # Clear gradients.
        out = model(data.x, data.edge_index)  # Perform a single forward pass.
        loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        return loss
    def test():
        model.eval()
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
        test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
        return test_acc
    for epoch in tqdm.tqdm(range(1, 101)):
        loss = train()
        wandb.log({"gcn/loss": loss})
    model.eval()
    out = model(data.x, data.edge_index)
    test_acc = test()
    wandb.summary["gcn/accuracy"] = test_acc
    wandb.log({"gcn/accuracy": test_acc})
    embedding_to_wandb(out, color=data.y, key="gcn/embedding/trained")
    wandb.finish()
sweep_config = {
    "name": "gcn-sweep",
    "method": "bayes",
    "metric": {
        "name": "gcn/accuracy",
        "goal": "maximize",
    },
    "parameters": {
        "hidden_channels": {
            "values": [8, 16, 32]
        },
        "weight_decay": {
            "distribution": "normal",
            "mu": 5e-4,
            "sigma": 1e-5,
        },
        "lr": {
            "min": 1e-4,
            "max": 1e-3
        }
    }
}
# Register the Sweep with W&B
sweep_id = wandb.sweep(sweep_config, project="node-classification")
# Run the Sweeps agent
wandb.agent(sweep_id, project="node-classification", function=agent_fn, count=50)
Conclusion
In this chapter, you have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model’s performance. In the next section, we will look into how GNNs can be used for the task of graph classification.
(Optional) Exercises
1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The Cora dataset provides a validation node set as data.val_mask, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to 82% accuracy.
2. How does GCN behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all?
3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all GCNConv instances with `GATConv <https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GATConv>`__ layers that make use of attention? Try to write a 2-layer GAT model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a dropout ratio of 0.6 inside and outside each GATConv call, and uses hidden_channels dimensions of 8 per head.
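For the first exercise, one possible approach is to evaluate on data.val_mask after every epoch and report the test accuracy of the epoch with the highest validation accuracy, rather than that of the last epoch (in a real training loop you would typically also keep a copy of the model weights at that epoch, e.g. with copy.deepcopy(model.state_dict())). The pure-Python sketch below shows only the selection logic; select_by_validation is a hypothetical helper and the accuracy values are made up, not results on Cora.

```python
def select_by_validation(history):
    """history: list of (epoch, val_acc, test_acc) tuples.

    Returns the entry with the highest val_acc; ties go to the earliest epoch.
    The test accuracy is only read off, never used for the decision.
    """
    best = None
    for epoch, val_acc, test_acc in history:
        if best is None or val_acc > best[1]:
            best = (epoch, val_acc, test_acc)
    return best

history = [(1, 0.70, 0.71), (2, 0.79, 0.80), (3, 0.82, 0.81), (4, 0.82, 0.83), (5, 0.80, 0.82)]
print(select_by_validation(history))  # (3, 0.82, 0.81)
```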
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, heads):
        super().__init__()
        torch.manual_seed(1234567)
        self.conv1 = GATConv(in_channels=dataset.num_features,
                             out_channels=hidden_channels,
                             heads=heads,
                             dropout=0.6)  # Dropout on the attention coefficients.
        # conv1 concatenates its heads, so conv2 receives hidden_channels * heads input features.
        self.conv2 = GATConv(in_channels=hidden_channels * heads,
                             out_channels=dataset.num_classes,
                             heads=1,
                             dropout=0.6)
    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
model = GAT(hidden_channels=8, heads=8)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x, data.edge_index)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss
def test(mask):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    correct = pred[mask] == data.y[mask]  # Check against ground-truth labels.
    acc = int(correct.sum()) / int(mask.sum())  # Derive ratio of correct predictions.
    return acc
for epoch in range(1, 201):
    loss = train()
    val_acc = test(data.val_mask)
    test_acc = test(data.test_mask)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
GAT(
(conv1): GATConv(1433, 8, heads=8)
(conv2): GATConv(64, 7, heads=1)
)
Epoch: 001, Loss: 1.9442, Val: 0.2680, Test: 0.2660
Epoch: 002, Loss: 1.9398, Val: 0.2860, Test: 0.2720
Epoch: 003, Loss: 1.9282, Val: 0.3740, Test: 0.3620
Epoch: 004, Loss: 1.9240, Val: 0.4820, Test: 0.4830
Epoch: 005, Loss: 1.9211, Val: 0.4620, Test: 0.4640
Epoch: 006, Loss: 1.9166, Val: 0.5820, Test: 0.6000
Epoch: 007, Loss: 1.9032, Val: 0.6820, Test: 0.6870
Epoch: 008, Loss: 1.8941, Val: 0.7080, Test: 0.7190
Epoch: 009, Loss: 1.8846, Val: 0.7160, Test: 0.7310
Epoch: 010, Loss: 1.8877, Val: 0.7380, Test: 0.7500
Epoch: 011, Loss: 1.8799, Val: 0.7400, Test: 0.7530
Epoch: 012, Loss: 1.8698, Val: 0.7480, Test: 0.7610
Epoch: 013, Loss: 1.8729, Val: 0.7520, Test: 0.7880
Epoch: 014, Loss: 1.8471, Val: 0.7700, Test: 0.7990
Epoch: 015, Loss: 1.8432, Val: 0.7680, Test: 0.7920
Epoch: 016, Loss: 1.8459, Val: 0.7680, Test: 0.7900
Epoch: 017, Loss: 1.8252, Val: 0.7820, Test: 0.7890
Epoch: 018, Loss: 1.8128, Val: 0.7800, Test: 0.7870
Epoch: 019, Loss: 1.8182, Val: 0.7880, Test: 0.7880
Epoch: 020, Loss: 1.7892, Val: 0.7920, Test: 0.7930
Epoch: 021, Loss: 1.7974, Val: 0.7840, Test: 0.8000
Epoch: 022, Loss: 1.7842, Val: 0.7780, Test: 0.8030
Epoch: 023, Loss: 1.7564, Val: 0.7660, Test: 0.7920
Epoch: 024, Loss: 1.7729, Val: 0.7720, Test: 0.7840
Epoch: 025, Loss: 1.7247, Val: 0.7740, Test: 0.7780
Epoch: 026, Loss: 1.7577, Val: 0.7580, Test: 0.7650
Epoch: 027, Loss: 1.7140, Val: 0.7560, Test: 0.7460
Epoch: 028, Loss: 1.7291, Val: 0.7580, Test: 0.7430
Epoch: 029, Loss: 1.7042, Val: 0.7580, Test: 0.7370
Epoch: 030, Loss: 1.6777, Val: 0.7580, Test: 0.7320
Epoch: 031, Loss: 1.7013, Val: 0.7640, Test: 0.7400
Epoch: 032, Loss: 1.6875, Val: 0.7700, Test: 0.7480
Epoch: 033, Loss: 1.6333, Val: 0.7760, Test: 0.7580
Epoch: 034, Loss: 1.6270, Val: 0.7820, Test: 0.7780
Epoch: 035, Loss: 1.5962, Val: 0.7820, Test: 0.7790
Epoch: 036, Loss: 1.6522, Val: 0.7920, Test: 0.7880
Epoch: 037, Loss: 1.5936, Val: 0.7940, Test: 0.7930
Epoch: 038, Loss: 1.6122, Val: 0.8020, Test: 0.8030
Epoch: 039, Loss: 1.5352, Val: 0.8060, Test: 0.8080
Epoch: 040, Loss: 1.5734, Val: 0.8080, Test: 0.8150
Epoch: 041, Loss: 1.4901, Val: 0.8080, Test: 0.8180
Epoch: 042, Loss: 1.5120, Val: 0.8040, Test: 0.8190
Epoch: 043, Loss: 1.4829, Val: 0.8000, Test: 0.8230
Epoch: 044, Loss: 1.5125, Val: 0.8020, Test: 0.8230
Epoch: 045, Loss: 1.5164, Val: 0.8040, Test: 0.8230
Epoch: 046, Loss: 1.4329, Val: 0.8100, Test: 0.8210
Epoch: 047, Loss: 1.4720, Val: 0.8080, Test: 0.8200
Epoch: 048, Loss: 1.4141, Val: 0.8020, Test: 0.8160
Epoch: 049, Loss: 1.4266, Val: 0.8020, Test: 0.8170
Epoch: 050, Loss: 1.4410, Val: 0.7980, Test: 0.8170
Epoch: 051, Loss: 1.4411, Val: 0.7980, Test: 0.8160
Epoch: 052, Loss: 1.3638, Val: 0.7960, Test: 0.8170
Epoch: 053, Loss: 1.4005, Val: 0.7920, Test: 0.8130
Epoch: 054, Loss: 1.3491, Val: 0.7900, Test: 0.8090
Epoch: 055, Loss: 1.3265, Val: 0.7920, Test: 0.8110
Epoch: 056, Loss: 1.4307, Val: 0.7940, Test: 0.8020
Epoch: 057, Loss: 1.2772, Val: 0.7920, Test: 0.8040
Epoch: 058, Loss: 1.3646, Val: 0.7880, Test: 0.8020
Epoch: 059, Loss: 1.2733, Val: 0.7880, Test: 0.8010
Epoch: 060, Loss: 1.2890, Val: 0.7860, Test: 0.7990
Epoch: 061, Loss: 1.2712, Val: 0.7840, Test: 0.8010
Epoch: 062, Loss: 1.3131, Val: 0.7800, Test: 0.7990
Epoch: 063, Loss: 1.2414, Val: 0.7860, Test: 0.8000
Epoch: 064, Loss: 1.2593, Val: 0.7960, Test: 0.7980
Epoch: 065, Loss: 1.2376, Val: 0.7960, Test: 0.8000
Epoch: 066, Loss: 1.2993, Val: 0.7900, Test: 0.7980
Epoch: 067, Loss: 1.1746, Val: 0.7880, Test: 0.8030
Epoch: 068, Loss: 1.2543, Val: 0.7900, Test: 0.8020
Epoch: 069, Loss: 1.2183, Val: 0.7940, Test: 0.8020
Epoch: 070, Loss: 1.1632, Val: 0.7940, Test: 0.7990
Epoch: 071, Loss: 1.1631, Val: 0.7980, Test: 0.8030
Epoch: 072, Loss: 1.0818, Val: 0.7980, Test: 0.7970
Epoch: 073, Loss: 1.1275, Val: 0.7980, Test: 0.7980
Epoch: 074, Loss: 1.0822, Val: 0.7980, Test: 0.7990
Epoch: 075, Loss: 1.1707, Val: 0.8000, Test: 0.8000
Epoch: 076, Loss: 1.1509, Val: 0.8040, Test: 0.8020
Epoch: 077, Loss: 1.0767, Val: 0.8020, Test: 0.8020
Epoch: 078, Loss: 1.1139, Val: 0.8020, Test: 0.8010
Epoch: 079, Loss: 1.1518, Val: 0.7980, Test: 0.8000
Epoch: 080, Loss: 1.0266, Val: 0.7940, Test: 0.8020
Epoch: 081, Loss: 1.0675, Val: 0.7920, Test: 0.8020
Epoch: 082, Loss: 1.0870, Val: 0.7920, Test: 0.8000
Epoch: 083, Loss: 1.0169, Val: 0.7900, Test: 0.8010
Epoch: 084, Loss: 0.9710, Val: 0.7920, Test: 0.8020
Epoch: 085, Loss: 1.1048, Val: 0.7920, Test: 0.8000
Epoch: 086, Loss: 1.0088, Val: 0.7920, Test: 0.8020
Epoch: 087, Loss: 0.9880, Val: 0.7940, Test: 0.8050
Epoch: 088, Loss: 1.0503, Val: 0.7980, Test: 0.8040
Epoch: 089, Loss: 1.0239, Val: 0.7960, Test: 0.8040
Epoch: 090, Loss: 1.0279, Val: 0.7920, Test: 0.8070
Epoch: 091, Loss: 1.0774, Val: 0.7900, Test: 0.8070
Epoch: 092, Loss: 0.9331, Val: 0.7940, Test: 0.8080
Epoch: 093, Loss: 0.9973, Val: 0.7900, Test: 0.8070
Epoch: 094, Loss: 0.9743, Val: 0.7940, Test: 0.8070
Epoch: 095, Loss: 0.9499, Val: 0.7940, Test: 0.8100
Epoch: 096, Loss: 1.0062, Val: 0.7960, Test: 0.8060
Epoch: 097, Loss: 0.9607, Val: 0.8020, Test: 0.8120
Epoch: 098, Loss: 0.9859, Val: 0.8040, Test: 0.8130
Epoch: 099, Loss: 0.9307, Val: 0.8060, Test: 0.8120
Epoch: 100, Loss: 0.9841, Val: 0.8140, Test: 0.8120
Epoch: 101, Loss: 1.0277, Val: 0.8140, Test: 0.8080
Epoch: 102, Loss: 0.9114, Val: 0.8200, Test: 0.8090
Epoch: 103, Loss: 0.9208, Val: 0.8160, Test: 0.8150
Epoch: 104, Loss: 0.9161, Val: 0.8180, Test: 0.8140
Epoch: 105, Loss: 0.8903, Val: 0.8140, Test: 0.8140
Epoch: 106, Loss: 0.9568, Val: 0.8140, Test: 0.8160
Epoch: 107, Loss: 0.9000, Val: 0.8060, Test: 0.8160
Epoch: 108, Loss: 0.8524, Val: 0.8060, Test: 0.8170
Epoch: 109, Loss: 1.0156, Val: 0.8060, Test: 0.8160
Epoch: 110, Loss: 0.9192, Val: 0.8020, Test: 0.8120
Epoch: 111, Loss: 0.9475, Val: 0.7980, Test: 0.8150
Epoch: 112, Loss: 0.9696, Val: 0.8000, Test: 0.8140
Epoch: 113, Loss: 0.8096, Val: 0.8000, Test: 0.8150
Epoch: 114, Loss: 0.8566, Val: 0.8020, Test: 0.8130
Epoch: 115, Loss: 0.8836, Val: 0.8040, Test: 0.8150
Epoch: 116, Loss: 0.8872, Val: 0.8040, Test: 0.8140
Epoch: 117, Loss: 0.8857, Val: 0.8020, Test: 0.8150
Epoch: 118, Loss: 0.8941, Val: 0.8060, Test: 0.8190
Epoch: 119, Loss: 0.8946, Val: 0.8080, Test: 0.8200
Epoch: 120, Loss: 0.8932, Val: 0.8060, Test: 0.8170
Epoch: 121, Loss: 0.8594, Val: 0.8040, Test: 0.8160
Epoch: 122, Loss: 0.9081, Val: 0.8060, Test: 0.8160
Epoch: 123, Loss: 0.8412, Val: 0.8060, Test: 0.8160
Epoch: 124, Loss: 0.8700, Val: 0.8100, Test: 0.8190
Epoch: 125, Loss: 0.7985, Val: 0.8080, Test: 0.8220
Epoch: 126, Loss: 0.7417, Val: 0.8060, Test: 0.8230
Epoch: 127, Loss: 0.8938, Val: 0.8080, Test: 0.8240
Epoch: 128, Loss: 0.8232, Val: 0.8060, Test: 0.8250
Epoch: 129, Loss: 0.9651, Val: 0.8080, Test: 0.8260
Epoch: 130, Loss: 0.7675, Val: 0.8100, Test: 0.8240
Epoch: 131, Loss: 0.9292, Val: 0.8100, Test: 0.8220
Epoch: 132, Loss: 0.7900, Val: 0.8080, Test: 0.8200
Epoch: 133, Loss: 0.7838, Val: 0.8060, Test: 0.8170
Epoch: 134, Loss: 0.7953, Val: 0.8060, Test: 0.8140
Epoch: 135, Loss: 0.9099, Val: 0.8060, Test: 0.8130
Epoch: 136, Loss: 0.8378, Val: 0.8040, Test: 0.8090
Epoch: 137, Loss: 0.8383, Val: 0.7960, Test: 0.8090
Epoch: 138, Loss: 0.7474, Val: 0.7940, Test: 0.8110
Epoch: 139, Loss: 0.8183, Val: 0.7920, Test: 0.8070
Epoch: 140, Loss: 0.8193, Val: 0.7920, Test: 0.8050
Epoch: 141, Loss: 0.8114, Val: 0.7940, Test: 0.8030
Epoch: 142, Loss: 0.8264, Val: 0.7920, Test: 0.8030
Epoch: 143, Loss: 0.7286, Val: 0.7960, Test: 0.8040
Epoch: 144, Loss: 0.8243, Val: 0.7980, Test: 0.8050
Epoch: 145, Loss: 0.8658, Val: 0.8000, Test: 0.8050
Epoch: 146, Loss: 0.8029, Val: 0.8020, Test: 0.8070
Epoch: 147, Loss: 0.8127, Val: 0.8000, Test: 0.8090
Epoch: 148, Loss: 0.7576, Val: 0.8060, Test: 0.8130
Epoch: 149, Loss: 0.6674, Val: 0.8080, Test: 0.8210
Epoch: 150, Loss: 0.8316, Val: 0.8140, Test: 0.8260
Epoch: 151, Loss: 0.8275, Val: 0.8180, Test: 0.8270
Epoch: 152, Loss: 0.8434, Val: 0.8160, Test: 0.8290
Epoch: 153, Loss: 0.9301, Val: 0.8140, Test: 0.8310
Epoch: 154, Loss: 0.7763, Val: 0.8120, Test: 0.8310
Epoch: 155, Loss: 0.7662, Val: 0.8100, Test: 0.8310
Epoch: 156, Loss: 0.7997, Val: 0.8060, Test: 0.8290
Epoch: 157, Loss: 0.7296, Val: 0.8040, Test: 0.8240
Epoch: 158, Loss: 0.7622, Val: 0.8000, Test: 0.8220
Epoch: 159, Loss: 0.8144, Val: 0.7940, Test: 0.8180
Epoch: 160, Loss: 0.7343, Val: 0.7960, Test: 0.8140
Epoch: 161, Loss: 0.7731, Val: 0.7920, Test: 0.8120
Epoch: 162, Loss: 0.7364, Val: 0.7840, Test: 0.8080
Epoch: 163, Loss: 0.7949, Val: 0.7900, Test: 0.8070
Epoch: 164, Loss: 0.7287, Val: 0.7880, Test: 0.8050
Epoch: 165, Loss: 0.7974, Val: 0.7860, Test: 0.7970
Epoch: 166, Loss: 0.6848, Val: 0.7900, Test: 0.7970
Epoch: 167, Loss: 0.6555, Val: 0.7900, Test: 0.8000
Epoch: 168, Loss: 0.6958, Val: 0.7900, Test: 0.8020
Epoch: 169, Loss: 0.8070, Val: 0.7920, Test: 0.8030
Epoch: 170, Loss: 0.7806, Val: 0.8020, Test: 0.8050
Epoch: 171, Loss: 0.8483, Val: 0.8020, Test: 0.8090
Epoch: 172, Loss: 0.7004, Val: 0.8040, Test: 0.8090
Epoch: 173, Loss: 0.7916, Val: 0.8040, Test: 0.8140
Epoch: 174, Loss: 0.7581, Val: 0.8100, Test: 0.8130
Epoch: 175, Loss: 0.6753, Val: 0.8080, Test: 0.8180
Epoch: 176, Loss: 0.8322, Val: 0.8080, Test: 0.8210
Epoch: 177, Loss: 0.6945, Val: 0.8100, Test: 0.8230
Epoch: 178, Loss: 0.8020, Val: 0.8120, Test: 0.8280
Epoch: 179, Loss: 0.6372, Val: 0.8140, Test: 0.8290
Epoch: 180, Loss: 0.7657, Val: 0.8100, Test: 0.8260
Epoch: 181, Loss: 0.8001, Val: 0.8080, Test: 0.8260
Epoch: 182, Loss: 0.7520, Val: 0.8080, Test: 0.8290
Epoch: 183, Loss: 0.6406, Val: 0.8040, Test: 0.8260
Epoch: 184, Loss: 0.7105, Val: 0.8060, Test: 0.8220
Epoch: 185, Loss: 0.8111, Val: 0.8060, Test: 0.8170
Epoch: 186, Loss: 0.7277, Val: 0.7980, Test: 0.8170
Epoch: 187, Loss: 0.7002, Val: 0.7980, Test: 0.8150
Epoch: 188, Loss: 0.6906, Val: 0.7980, Test: 0.8130
Epoch: 189, Loss: 0.7586, Val: 0.7980, Test: 0.8120
Epoch: 190, Loss: 0.7072, Val: 0.7980, Test: 0.8140
Epoch: 191, Loss: 0.7949, Val: 0.7980, Test: 0.8170
Epoch: 192, Loss: 0.6978, Val: 0.7960, Test: 0.8140
Epoch: 193, Loss: 0.7212, Val: 0.7980, Test: 0.8110
Epoch: 194, Loss: 0.7530, Val: 0.7980, Test: 0.8110
Epoch: 195, Loss: 0.6606, Val: 0.8020, Test: 0.8130
Epoch: 196, Loss: 0.6922, Val: 0.7980, Test: 0.8140
Epoch: 197, Loss: 0.6684, Val: 0.8020, Test: 0.8180
Epoch: 198, Loss: 0.7610, Val: 0.8060, Test: 0.8230
Epoch: 199, Loss: 0.8197, Val: 0.8060, Test: 0.8260
Epoch: 200, Loss: 0.6495, Val: 0.8060, Test: 0.8250
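Across these epochs the validation and test accuracies both fluctuate rather than rising steadily. Validation stays roughly between 0.78 and 0.82, and test between 0.80 and 0.83. Because of this, the number you report depends on which epoch you pick. One common convention is to report the test accuracy at the epoch with the best validation accuracy. That convention is an assumption here, not something the log above states.

The sketch below parses log lines in exactly the `Epoch: …, Loss: …, Val: …, Test: …` format printed above. It then selects the epoch with the highest validation accuracy. The names `parse_log`, `select_by_validation` and `EpochRecord` are hypothetical helpers for illustration, not part of PyTorch Geometric or W&B.

```python
import re
from typing import List, NamedTuple, Optional

# Hypothetical helper: matches lines such as
# "Epoch: 102, Loss: 0.9114, Val: 0.8200, Test: 0.8090"
LOG_PATTERN = re.compile(
    r"Epoch:\s*(\d+),\s*Loss:\s*([\d.]+),\s*Val:\s*([\d.]+),\s*Test:\s*([\d.]+)"
)


class EpochRecord(NamedTuple):
    """One parsed log line (hypothetical container type)."""
    epoch: int
    loss: float
    val_acc: float
    test_acc: float


def parse_log(text: str) -> List[EpochRecord]:
    """Extract one record per matching log line; non-matching lines are ignored."""
    records = []
    for match in LOG_PATTERN.finditer(text):
        epoch, loss, val_acc, test_acc = match.groups()
        records.append(
            EpochRecord(int(epoch), float(loss), float(val_acc), float(test_acc))
        )
    return records


def select_by_validation(records: List[EpochRecord]) -> Optional[EpochRecord]:
    """Return the record with the highest validation accuracy.

    Python's max() returns the first maximal element, so ties resolve
    to the earliest epoch. Returns None for an empty list.
    """
    if not records:
        return None
    return max(records, key=lambda r: r.val_acc)


if __name__ == "__main__":
    # A few lines copied from the log above, used as sample input.
    sample = """Epoch: 101, Loss: 1.0277, Val: 0.8140, Test: 0.8080
Epoch: 102, Loss: 0.9114, Val: 0.8200, Test: 0.8090
Epoch: 103, Loss: 0.9208, Val: 0.8160, Test: 0.8150"""
    best = select_by_validation(parse_log(sample))
    print(best.epoch, best.val_acc, best.test_acc)  # 102 0.82 0.809
```

Note that this only looks at the lines you pass in. Applied to this excerpt alone, it cannot tell whether an earlier epoch (before epoch 044) had a higher validation accuracy. In practice you would track the best validation score inside the training loop itself, for example by keeping a copy of the model state whenever it improves.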