Tutorial 5: Aggregation

In this tutorial we override the aggregation method of the GIN convolution module of PyTorch Geometric, implementing the following methods:

  • Principal Neighborhood Aggregation (PNA)

  • Learning Aggregation Functions (LAF)

[ ]:
import os
import torch
# Export the installed torch version (e.g. "1.12.1+cu113", see the cell output)
# as $TORCH so the shell commands below can select the matching wheel index.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# IPython shell escapes (notebook-only, not plain Python): install the compiled
# scatter/sparse extensions built for this torch version, then PyG from git.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
1.12.1+cu113
     |████████████████████████████████| 7.9 MB 5.0 MB/s
     |████████████████████████████████| 3.5 MB 4.9 MB/s
  Building wheel for torch-geometric (setup.py) ... done
[ ]:
import torch
# Seed torch's global RNG so weight initialisation (and anything else drawing
# from that RNG) is repeatable across runs of the notebook.
torch.manual_seed(42)
<torch._C.Generator at 0x7f1e2cb8ca30>

Message Passing Class

[ ]:
from torch_geometric.nn import MessagePassing
[ ]:
# List MessagePassing's attributes to locate the hooks a subclass can override
# (notably `aggregate` and `message_and_aggregate`).
dir(MessagePassing)
['T_destination',
 '__annotations__',
 '__call__',
 '__check_input__',
 '__class__',
 '__collect__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lift__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__set_size__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_call_impl',
 '_get_backward_hooks',
 '_get_name',
 '_load_from_state_dict',
 '_maybe_warn_non_full_backward_hook',
 '_named_members',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_slow_forward',
 '_version',
 'add_module',
 'aggregate',
 'apply',
 'bfloat16',
 'buffers',
 'children',
 'cpu',
 'cuda',
 'double',
 'dump_patches',
 'edge_update',
 'edge_updater',
 'eval',
 'explain',
 'explain_message',
 'extra_repr',
 'float',
 'forward',
 'get_buffer',
 'get_extra_state',
 'get_parameter',
 'get_submodule',
 'half',
 'ipu',
 'jittable',
 'load_state_dict',
 'message',
 'message_and_aggregate',
 'modules',
 'named_buffers',
 'named_children',
 'named_modules',
 'named_parameters',
 'parameters',
 'propagate',
 'register_aggregate_forward_hook',
 'register_aggregate_forward_pre_hook',
 'register_backward_hook',
 'register_buffer',
 'register_edge_update_forward_hook',
 'register_edge_update_forward_pre_hook',
 'register_forward_hook',
 'register_forward_pre_hook',
 'register_full_backward_hook',
 'register_load_state_dict_post_hook',
 'register_message_and_aggregate_forward_hook',
 'register_message_and_aggregate_forward_pre_hook',
 'register_message_forward_hook',
 'register_message_forward_pre_hook',
 'register_module',
 'register_parameter',
 'register_propagate_forward_hook',
 'register_propagate_forward_pre_hook',
 'requires_grad_',
 'set_extra_state',
 'share_memory',
 'special_args',
 'state_dict',
 'to',
 'to_empty',
 'train',
 'type',
 'update',
 'xpu',
 'zero_grad']

We are interested in the aggregate method or, if you are using a sparse adjacency matrix, in the message_and_aggregate method. Convolutional classes in PyG extend MessagePassing; we build our custom convolutional classes by extending GINConv.

Scatter operation in aggregate:

5fbdf952512749d6844eeff74170dbeb

[ ]:
from torch.nn import Parameter, Module, Sigmoid
import torch
import torch_scatter
import torch.nn.functional as F

class AbstractLAFLayer(Module):
    """Weights and constants shared by Learnable Aggregation Function layers.

    Holds a learnable ``(12, units)`` weight matrix. ``LAFLayer`` reads rows
    0-3 as inner exponents, rows 4-7 as outer exponents and rows 8-11 as
    linear coefficients. The module also holds three constant tensors that
    select the numerator/denominator terms.

    Keyword Args:
        units (int): number of LAF units; required unless ``weights`` is given.
        weights (Tensor): explicit initial weights; ``weights.shape[1]`` sets
            ``units``.
        device: where tensors are allocated. Defaults to CUDA when available,
            otherwise CPU.
        kernel_initializer (str): a key of ``_INITIALIZERS`` (default
            ``'random_normal'``). It is always validated, but only used when
            ``weights`` is not given.

    Raises:
        ValueError: if neither ``units`` nor ``weights`` is supplied, or if the
            initializer name is unknown.
    """

    # Name -> in-place torch initializer applied to freshly allocated weights.
    _INITIALIZERS = {
        'random_normal': torch.nn.init.normal_,
        'glorot_normal': torch.nn.init.xavier_normal_,
        'he_normal': torch.nn.init.kaiming_normal_,
        'random_uniform': torch.nn.init.uniform_,
        'glorot_uniform': torch.nn.init.xavier_uniform_,
        'he_uniform': torch.nn.init.kaiming_uniform_,
    }

    def __init__(self, **kwargs):
        super(AbstractLAFLayer, self).__init__()
        # Explicit exceptions instead of `assert`, which is stripped under -O.
        if 'units' not in kwargs and 'weights' not in kwargs:
            raise ValueError("AbstractLAFLayer needs either 'units' or 'weights'")

        if 'device' in kwargs:
            self.device = kwargs['device']
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.ngpus = torch.cuda.device_count()

        self.kernel_initializer = kwargs.get('kernel_initializer', 'random_normal')
        if self.kernel_initializer not in self._INITIALIZERS:
            raise ValueError('Unknown kernel_initializer {!r}; expected one of {}'.format(
                self.kernel_initializer, sorted(self._INITIALIZERS)))

        if 'weights' in kwargs:
            self.weights = Parameter(kwargs['weights'].to(self.device),
                                     requires_grad=True)
            self.units = self.weights.shape[1]
        else:
            self.units = kwargs['units']
            params = torch.empty(12, self.units, device=self.device)
            self._INITIALIZERS[self.kernel_initializer](params)
            self.weights = Parameter(params, requires_grad=True)

        # Fixed constants rather than learnable weights. Buffers follow .to(),
        # appear in state_dict under the same keys as before, and are not
        # returned by parameters(), so the optimizer never sees them.
        # Sign pattern: with x in (0, 1), (1 - e)/2 + x*e gives [x, 1-x, x, 1-x].
        self.register_buffer(
            'e', torch.tensor([1, -1, 1, -1], dtype=torch.float32, device=self.device))
        # Masks picking terms 0-1 (numerator) and 2-3 (denominator) along axis 2.
        self.register_buffer(
            'num_idx',
            torch.tensor([1, 1, 0, 0], dtype=torch.float32, device=self.device).view(1, 1, -1, 1))
        self.register_buffer(
            'den_idx',
            torch.tensor([0, 0, 1, 1], dtype=torch.float32, device=self.device).view(1, 1, -1, 1))


class LAFLayer(AbstractLAFLayer):
    """Learnable Aggregation Function (LAF) applied per scatter group.

    With ``w`` the ``(12, units)`` weights and ``p = relu(w)``, every group of
    rows of ``data`` that share an ``index`` value is reduced (per feature and
    per unit) to ``num / den``::

        num = w[8]  * (sum x**p[0])**p[4] + w[9]  * (sum (1-x)**p[1])**p[5]
        den = w[10] * (sum x**p[2])**p[6] + w[11] * (sum (1-x)**p[3])**p[7]

    Args:
        eps: clamp margin and minimum denominator magnitude.
        **kwargs: forwarded to ``AbstractLAFLayer`` (``units``/``weights``,
            ``device``, ``kernel_initializer``).
    """

    def __init__(self, eps=1e-7, **kwargs):
        super(LAFLayer, self).__init__(**kwargs)
        self.eps = eps

    def forward(self, data, index, dim=0, dim_size=None, **kwargs):
        """Aggregate ``data`` rows into groups given by ``index``.

        Args:
            data: 2-D tensor ``(rows, features)``; the ``view(1, 1, ...)``
                reshapes below fix this rank. Values are clamped into
                ``[eps, 1 - eps]``, so inputs are expected to lie in [0, 1].
            index: group id of every row.
            dim: axis scattered over (default 0).
            dim_size: number of output groups. When ``None`` the output has
                ``index.max() + 1`` groups, which silently drops trailing
                groups that receive no rows.
            **kwargs: accepted and ignored.

        Returns:
            Tensor of shape ``(groups, features, units)``.
        """
        eps = self.eps
        sup = 1.0 - eps

        # Keep x strictly inside (0, 1) so x**p and (1-x)**p stay finite.
        x = torch.clamp(data, eps, sup)
        x = torch.unsqueeze(x, -1)
        e = self.e.view(1, 1, -1)

        # e = [1, -1, 1, -1] turns this into [x, 1-x, x, 1-x] on the last axis.
        exps = (1. - e) / 2. + x * e
        exps = torch.unsqueeze(exps, -1)
        exps = torch.pow(exps, torch.relu(self.weights[0:4]))

        scatter = torch_scatter.scatter_add(exps, index.view(-1), dim=dim,
                                            dim_size=dim_size)
        # Avoid 0 ** p, whose gradient w.r.t. the exponent is undefined.
        scatter = torch.clamp(scatter, eps)

        sqrt = torch.pow(scatter, torch.relu(self.weights[4:8]))
        alpha_beta = self.weights[8:12].view(1, 1, 4, -1)
        terms = sqrt * alpha_beta

        num = torch.sum(terms * self.num_idx, dim=2)
        den = torch.sum(terms * self.den_idx, dim=2)

        # +1 where den > 0, otherwise -1: the sign used to replace a
        # near-zero denominator by +/- eps and avoid division by ~0.
        multiplier = 2.0 * torch.clamp(torch.sign(den), min=0.0) - 1.0
        den = torch.where((den < eps) & (den > -eps), multiplier * eps, den)

        return num / den
[ ]:
from torch_geometric.nn import GINConv
from torch.nn import Linear

LAF Aggregation Module

158a8e68187b435c88c6a6f099568445

[ ]:
class GINLAFConv(GINConv):
    """GIN convolution whose neighbourhood aggregation is a learnable LAF.

    Messages are squashed into (0, 1) with a sigmoid and aggregated per target
    node by ``LAFLayer`` (``units`` LAF functions per feature). The resulting
    ``node_dim * units`` features are then projected back to ``node_dim``.

    Args:
        nn: update MLP forwarded to ``GINConv``.
        units: number of LAF units per feature.
        node_dim: feature size of the incoming messages and of the
            aggregated output.
        **kwargs: forwarded to ``GINConv``.
    """

    def __init__(self, nn, units=1, node_dim=32, **kwargs):
        super(GINLAFConv, self).__init__(nn, **kwargs)
        self.laf = LAFLayer(units=units, kernel_initializer='random_uniform')
        self.mlp = torch.nn.Linear(node_dim * units, node_dim)
        self.dim = node_dim
        self.units = units

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        """Aggregate edge messages ``inputs`` into their target nodes ``index``.

        Declaring ``dim_size`` makes MessagePassing pass the number of target
        nodes. Forwarding it keeps one output row per node, even when the
        highest-numbered nodes receive no message. ``ptr`` is accepted only to
        match the base-class signature and is unused.
        """
        x = torch.sigmoid(inputs)
        x = self.laf(x, index, dim_size=dim_size)
        # (nodes, node_dim, units) -> (nodes, node_dim * units)
        x = x.view((-1, self.dim * self.units))
        return self.mlp(x)

PNA Aggregation

ddfc044276934f5e8db3fa273a73ed13

[ ]:
class GINPNAConv(GINConv):
    """GIN convolution whose aggregation follows Principal Neighbourhood Aggregation.

    Four aggregators (sum, max, mean, variance) are each combined with three
    degree scalers (identity, amplification, attenuation). The resulting
    ``12 * node_dim`` features are projected back to ``node_dim`` by a linear
    layer.

    Args:
        nn: update MLP forwarded to ``GINConv``.
        node_dim: feature size of the incoming messages and of the
            aggregated output.
        delta: degree normalisation constant of the scalers (default 2.5749,
            the value this tutorial hard-coded).
        log_scalers: if True, scale by ``log(deg + 1)`` as in the PNA paper.
            The default False keeps the original linear-degree scalers.
        **kwargs: forwarded to ``GINConv``.
    """

    def __init__(self, nn, node_dim=32, delta=2.5749, log_scalers=False, **kwargs):
        super(GINPNAConv, self).__init__(nn, **kwargs)
        self.mlp = torch.nn.Linear(node_dim * 12, node_dim)
        self.delta = delta
        self.log_scalers = log_scalers

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        """Aggregate edge messages ``inputs`` into their target nodes ``index``.

        Declaring ``dim_size`` makes MessagePassing pass the number of target
        nodes, so every node gets an output row, even trailing nodes with no
        incoming edge. ``ptr`` is accepted only to match the base-class
        signature and is unused.
        """
        sums = torch_scatter.scatter_add(inputs, index, dim=0, dim_size=dim_size)
        maxs = torch_scatter.scatter_max(inputs, index, dim=0, dim_size=dim_size)[0]
        means = torch_scatter.scatter_mean(inputs, index, dim=0, dim_size=dim_size)
        mean_sq = torch_scatter.scatter_mean(inputs ** 2, index, dim=0, dim_size=dim_size)
        # relu guards against tiny negative values from floating-point cancellation.
        var = torch.relu(mean_sq - means ** 2)
        aggrs = [sums, maxs, means, var]

        # In-degree per target node, padded to dim_size so rows line up with aggrs.
        deg = index.bincount(minlength=dim_size or 0).to(inputs.dtype).view(-1, 1)
        # Nodes without incoming edges have all-zero aggregates. Clamping their
        # degree to 1 avoids delta / 0 * 0 = NaN in the attenuation scaler.
        deg = deg.clamp(min=1)
        if self.log_scalers:
            deg = torch.log(deg + 1.)

        amplification_scaler = [deg / self.delta * a for a in aggrs]
        attenuation_scaler = [self.delta / deg * a for a in aggrs]
        combinations = torch.cat(aggrs + amplification_scaler + attenuation_scaler, dim=1)
        return self.mlp(combinations)

Test the new classes

[ ]:
from torch_geometric.nn import MessagePassing, SAGEConv, GINConv, global_add_pool
import torch_scatter
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
import os.path as osp

[ ]:
# Non-deprecated loader; the old torch_geometric.data.DataLoader only warns and
# forwards here (see the deprecation warning in this cell's original output).
from torch_geometric.loader import DataLoader

path = osp.join('./', 'data', 'TU')
dataset = TUDataset(path, name='MUTAG').shuffle()
# Hold out the first 10% of the (already shuffled) graphs for testing.
test_dataset = dataset[:len(dataset) // 10]
train_dataset = dataset[len(dataset) // 10:]
test_loader = DataLoader(test_dataset, batch_size=128)
# Reshuffle every epoch so training mini-batches are not identical each pass.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip
Extracting data/TU/MUTAG/MUTAG.zip
Processing...
Done!
/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead
  warnings.warn(out)
[ ]:
class LAFNet(torch.nn.Module):
    """Graph classifier: five LAF-aggregating GIN layers, sum pooling, MLP head.

    Args:
        num_features: input node-feature size. Defaults to
            ``dataset.num_features`` of the notebook-level ``dataset``.
        num_classes: number of output classes. Defaults to
            ``dataset.num_classes``.
        dim: hidden size of every layer (default 32).
        units: LAF units per aggregation (default 3).
    """

    def __init__(self, num_features=None, num_classes=None, dim=32, units=3):
        super(LAFNet, self).__init__()

        # Fall back to the global dataset so `LAFNet()` behaves as before.
        if num_features is None:
            num_features = dataset.num_features
        if num_classes is None:
            num_classes = dataset.num_classes

        # Layer 1 aggregates raw input features (node_dim=num_features);
        # later layers aggregate `dim`-sized hidden features.
        nn1 = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim))
        self.conv1 = GINLAFConv(nn1, units=units, node_dim=num_features)
        self.bn1 = torch.nn.BatchNorm1d(dim)

        nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv2 = GINLAFConv(nn2, units=units, node_dim=dim)
        self.bn2 = torch.nn.BatchNorm1d(dim)

        nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv3 = GINLAFConv(nn3, units=units, node_dim=dim)
        self.bn3 = torch.nn.BatchNorm1d(dim)

        nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv4 = GINLAFConv(nn4, units=units, node_dim=dim)
        self.bn4 = torch.nn.BatchNorm1d(dim)

        nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv5 = GINLAFConv(nn5, units=units, node_dim=dim)
        self.bn5 = torch.nn.BatchNorm1d(dim)

        self.fc1 = Linear(dim, dim)
        self.fc2 = Linear(dim, num_classes)

    def forward(self, x, edge_index, batch):
        """Return class log-probabilities, one row per graph pooled from ``batch``."""
        x = F.relu(self.conv1(x, edge_index))
        x = self.bn1(x)
        x = F.relu(self.conv2(x, edge_index))
        x = self.bn2(x)
        x = F.relu(self.conv3(x, edge_index))
        x = self.bn3(x)
        x = F.relu(self.conv4(x, edge_index))
        x = self.bn4(x)
        x = F.relu(self.conv5(x, edge_index))
        x = self.bn5(x)
        x = global_add_pool(x, batch)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)

[ ]:
class PNANet(torch.nn.Module):
    """Graph classifier: five PNA-aggregating GIN layers, sum pooling, MLP head.

    Args:
        num_features: input node-feature size. Defaults to
            ``dataset.num_features`` of the notebook-level ``dataset``.
        num_classes: number of output classes. Defaults to
            ``dataset.num_classes``.
        dim: hidden size of every layer (default 32).
    """

    def __init__(self, num_features=None, num_classes=None, dim=32):
        super(PNANet, self).__init__()

        # Fall back to the global dataset so `PNANet()` behaves as before.
        if num_features is None:
            num_features = dataset.num_features
        if num_classes is None:
            num_classes = dataset.num_classes

        # Layer 1 aggregates raw input features (node_dim=num_features);
        # later layers aggregate `dim`-sized hidden features.
        nn1 = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim))
        self.conv1 = GINPNAConv(nn1, node_dim=num_features)
        self.bn1 = torch.nn.BatchNorm1d(dim)

        nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv2 = GINPNAConv(nn2, node_dim=dim)
        self.bn2 = torch.nn.BatchNorm1d(dim)

        nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv3 = GINPNAConv(nn3, node_dim=dim)
        self.bn3 = torch.nn.BatchNorm1d(dim)

        nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv4 = GINPNAConv(nn4, node_dim=dim)
        self.bn4 = torch.nn.BatchNorm1d(dim)

        nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv5 = GINPNAConv(nn5, node_dim=dim)
        self.bn5 = torch.nn.BatchNorm1d(dim)

        self.fc1 = Linear(dim, dim)
        self.fc2 = Linear(dim, num_classes)

    def forward(self, x, edge_index, batch):
        """Return class log-probabilities, one row per graph pooled from ``batch``."""
        x = F.relu(self.conv1(x, edge_index))
        x = self.bn1(x)
        x = F.relu(self.conv2(x, edge_index))
        x = self.bn2(x)
        x = F.relu(self.conv3(x, edge_index))
        x = self.bn3(x)
        x = F.relu(self.conv4(x, edge_index))
        x = self.bn4(x)
        x = F.relu(self.conv5(x, edge_index))
        x = self.bn5(x)
        x = global_add_pool(x, batch)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)
[ ]:
class GINNet(torch.nn.Module):
    """Baseline graph classifier: five plain GIN layers, sum pooling, MLP head.

    Args:
        num_features: input node-feature size. Defaults to
            ``dataset.num_features`` of the notebook-level ``dataset``.
        num_classes: number of output classes. Defaults to
            ``dataset.num_classes``.
        dim: hidden size of every layer (default 32).
    """

    def __init__(self, num_features=None, num_classes=None, dim=32):
        super(GINNet, self).__init__()

        # Fall back to the global dataset so `GINNet()` behaves as before.
        if num_features is None:
            num_features = dataset.num_features
        if num_classes is None:
            num_classes = dataset.num_classes

        nn1 = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim))
        self.conv1 = GINConv(nn1)
        self.bn1 = torch.nn.BatchNorm1d(dim)

        nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv2 = GINConv(nn2)
        self.bn2 = torch.nn.BatchNorm1d(dim)

        nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv3 = GINConv(nn3)
        self.bn3 = torch.nn.BatchNorm1d(dim)

        nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv4 = GINConv(nn4)
        self.bn4 = torch.nn.BatchNorm1d(dim)

        nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
        self.conv5 = GINConv(nn5)
        self.bn5 = torch.nn.BatchNorm1d(dim)

        self.fc1 = Linear(dim, dim)
        self.fc2 = Linear(dim, num_classes)

    def forward(self, x, edge_index, batch):
        """Return class log-probabilities, one row per graph pooled from ``batch``."""
        x = F.relu(self.conv1(x, edge_index))
        x = self.bn1(x)
        x = F.relu(self.conv2(x, edge_index))
        x = self.bn2(x)
        x = F.relu(self.conv3(x, edge_index))
        x = self.bn3(x)
        x = F.relu(self.conv4(x, edge_index))
        x = self.bn4(x)
        x = F.relu(self.conv5(x, edge_index))
        x = self.bn5(x)
        x = global_add_pool(x, batch)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1)
[ ]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = "LAF"
# Every branch must assign `model`. A missing assignment would silently keep
# training whatever model an earlier cell created.
if net == "LAF":
    model = LAFNet().to(device)
elif net == "PNA":
    model = PNANet().to(device)
elif net == "GIN":
    model = GINNet().to(device)
else:
    raise ValueError("Unknown net: {!r}".format(net))

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train(epoch):
    """Run one training epoch and return the mean loss per training graph.

    Works on the notebook globals ``model``, ``optimizer``, ``train_loader``,
    ``train_dataset`` and ``device``. At epoch 51 every learning rate of
    ``optimizer`` is halved.
    """
    model.train()

    if epoch == 51:
        for group in optimizer.param_groups:
            group['lr'] *= 0.5

    total_loss = 0
    for graphs in train_loader:
        graphs = graphs.to(device)
        optimizer.zero_grad()
        log_probs = model(graphs.x, graphs.edge_index, graphs.batch)
        loss = F.nll_loss(log_probs, graphs.y)
        loss.backward()
        # Weight the batch-mean loss by its graph count for a per-graph average.
        total_loss += loss.item() * graphs.num_graphs
        optimizer.step()
    return total_loss / len(train_dataset)


@torch.no_grad()
def test(loader):
    """Return the classification accuracy of the global ``model`` on ``loader``.

    Runs in eval mode without autograd tracking: no backward pass is needed
    here, so no computation graph is built. Reads the globals ``model`` and
    ``device``.
    """
    model.eval()

    correct = 0
    for data in loader:
        data = data.to(device)
        output = model(data.x, data.edge_index, data.batch)
        pred = output.argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)


for epoch in range(1, 101):
    # One optimisation pass, then accuracy on both splits.
    train_loss = train(epoch)
    train_acc, test_acc = test(train_loader), test(test_loader)
    print('Epoch: {:03d}, Train Loss: {:.7f}, '
          'Train Acc: {:.7f}, Test Acc: {:.7f}'.format(
              epoch, train_loss, train_acc, test_acc))
Epoch: 001, Train Loss: 0.8650472, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 002, Train Loss: 0.7599028, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 003, Train Loss: 0.8972220, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 004, Train Loss: 0.6185434, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 005, Train Loss: 0.6005230, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 006, Train Loss: 0.5512175, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 007, Train Loss: 0.5332195, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 008, Train Loss: 0.5134736, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 009, Train Loss: 0.4718563, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 010, Train Loss: 0.4698687, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 011, Train Loss: 0.4464772, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 012, Train Loss: 0.4414581, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 013, Train Loss: 0.4507246, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 014, Train Loss: 0.4593955, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 015, Train Loss: 0.4188018, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 016, Train Loss: 0.3976869, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 017, Train Loss: 0.4080824, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 018, Train Loss: 0.4642429, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 019, Train Loss: 0.3612275, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 020, Train Loss: 0.3702769, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 021, Train Loss: 0.3751319, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 022, Train Loss: 0.3421200, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 023, Train Loss: 0.3866120, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 024, Train Loss: 0.3492658, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 025, Train Loss: 0.3558516, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 026, Train Loss: 0.3727173, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 027, Train Loss: 0.3154053, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 028, Train Loss: 0.3201577, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 029, Train Loss: 0.3272583, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 030, Train Loss: 0.3112883, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 031, Train Loss: 0.3407421, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 032, Train Loss: 0.2899052, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 033, Train Loss: 0.3580514, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 034, Train Loss: 0.2954516, Train Acc: 0.6764706, Test Acc: 0.6666667
Epoch: 035, Train Loss: 0.2975145, Train Acc: 0.6941176, Test Acc: 0.7777778
Epoch: 036, Train Loss: 0.3173143, Train Acc: 0.7235294, Test Acc: 0.8333333
Epoch: 037, Train Loss: 0.2602276, Train Acc: 0.7000000, Test Acc: 0.7777778
Epoch: 038, Train Loss: 0.2713226, Train Acc: 0.7117647, Test Acc: 0.8333333
Epoch: 039, Train Loss: 0.2706065, Train Acc: 0.6941176, Test Acc: 0.7777778
Epoch: 040, Train Loss: 0.2786444, Train Acc: 0.6941176, Test Acc: 0.6666667
Epoch: 041, Train Loss: 0.2833781, Train Acc: 0.7117647, Test Acc: 0.7777778
Epoch: 042, Train Loss: 0.2924134, Train Acc: 0.7117647, Test Acc: 0.7777778
Epoch: 043, Train Loss: 0.2812036, Train Acc: 0.8235294, Test Acc: 0.9444444
Epoch: 044, Train Loss: 0.2887369, Train Acc: 0.7941176, Test Acc: 0.8888889
Epoch: 045, Train Loss: 0.2367283, Train Acc: 0.7529412, Test Acc: 0.7777778
Epoch: 046, Train Loss: 0.2811927, Train Acc: 0.7941176, Test Acc: 0.8888889
Epoch: 047, Train Loss: 0.2571158, Train Acc: 0.7588235, Test Acc: 0.7777778
Epoch: 048, Train Loss: 0.2812370, Train Acc: 0.5588235, Test Acc: 0.6111111
Epoch: 049, Train Loss: 0.2664493, Train Acc: 0.8294118, Test Acc: 0.7222222
Epoch: 050, Train Loss: 0.2698024, Train Acc: 0.8117647, Test Acc: 0.8333333
Epoch: 051, Train Loss: 0.2950443, Train Acc: 0.7647059, Test Acc: 0.6666667
Epoch: 052, Train Loss: 0.2750369, Train Acc: 0.7352941, Test Acc: 0.6666667
Epoch: 053, Train Loss: 0.2846459, Train Acc: 0.7882353, Test Acc: 0.7777778
Epoch: 054, Train Loss: 0.2428172, Train Acc: 0.7588235, Test Acc: 0.8888889
Epoch: 055, Train Loss: 0.2569554, Train Acc: 0.7647059, Test Acc: 0.8888889
Epoch: 056, Train Loss: 0.2893244, Train Acc: 0.7647059, Test Acc: 0.8888889
Epoch: 057, Train Loss: 0.2695741, Train Acc: 0.8058824, Test Acc: 0.8333333
Epoch: 058, Train Loss: 0.2683432, Train Acc: 0.8235294, Test Acc: 0.8333333
Epoch: 059, Train Loss: 0.2483253, Train Acc: 0.8294118, Test Acc: 0.8333333
Epoch: 060, Train Loss: 0.2359067, Train Acc: 0.8117647, Test Acc: 0.8333333
Epoch: 061, Train Loss: 0.2581255, Train Acc: 0.8235294, Test Acc: 0.8333333
Epoch: 062, Train Loss: 0.2320385, Train Acc: 0.8882353, Test Acc: 0.8333333
Epoch: 063, Train Loss: 0.2304887, Train Acc: 0.8647059, Test Acc: 0.8333333
Epoch: 064, Train Loss: 0.2351827, Train Acc: 0.8176471, Test Acc: 0.8333333
Epoch: 065, Train Loss: 0.2371133, Train Acc: 0.7647059, Test Acc: 0.7777778
Epoch: 066, Train Loss: 0.2476480, Train Acc: 0.7705882, Test Acc: 0.7777778
Epoch: 067, Train Loss: 0.2557588, Train Acc: 0.9176471, Test Acc: 0.7777778
Epoch: 068, Train Loss: 0.2158999, Train Acc: 0.8352941, Test Acc: 0.7222222
Epoch: 069, Train Loss: 0.2353542, Train Acc: 0.7882353, Test Acc: 0.6111111
Epoch: 070, Train Loss: 0.2403484, Train Acc: 0.7705882, Test Acc: 0.5555556
Epoch: 071, Train Loss: 0.2292482, Train Acc: 0.7764706, Test Acc: 0.7222222
Epoch: 072, Train Loss: 0.2588242, Train Acc: 0.8000000, Test Acc: 0.7222222
Epoch: 073, Train Loss: 0.2330211, Train Acc: 0.8882353, Test Acc: 0.6666667
Epoch: 074, Train Loss: 0.2573530, Train Acc: 0.8235294, Test Acc: 0.7777778
Epoch: 075, Train Loss: 0.2250361, Train Acc: 0.8941176, Test Acc: 0.7777778
Epoch: 076, Train Loss: 0.2089147, Train Acc: 0.9058824, Test Acc: 0.7222222
Epoch: 077, Train Loss: 0.2145577, Train Acc: 0.9117647, Test Acc: 0.7777778
Epoch: 078, Train Loss: 0.2313216, Train Acc: 0.8705882, Test Acc: 0.7222222
Epoch: 079, Train Loss: 0.2348573, Train Acc: 0.8470588, Test Acc: 0.6666667
Epoch: 080, Train Loss: 0.2337190, Train Acc: 0.8176471, Test Acc: 0.6666667
Epoch: 081, Train Loss: 0.2247560, Train Acc: 0.7352941, Test Acc: 0.6111111
Epoch: 082, Train Loss: 0.2352007, Train Acc: 0.6352941, Test Acc: 0.5000000
Epoch: 083, Train Loss: 0.2404233, Train Acc: 0.7411765, Test Acc: 0.6111111
Epoch: 084, Train Loss: 0.2203369, Train Acc: 0.8000000, Test Acc: 0.6666667
Epoch: 085, Train Loss: 0.2096777, Train Acc: 0.8647059, Test Acc: 0.7777778
Epoch: 086, Train Loss: 0.2133037, Train Acc: 0.8411765, Test Acc: 0.8333333
Epoch: 087, Train Loss: 0.1921520, Train Acc: 0.8411765, Test Acc: 0.8333333
Epoch: 088, Train Loss: 0.2259413, Train Acc: 0.9117647, Test Acc: 0.7777778
Epoch: 089, Train Loss: 0.2021636, Train Acc: 0.9176471, Test Acc: 0.7777778
Epoch: 090, Train Loss: 0.1980333, Train Acc: 0.9294118, Test Acc: 0.7777778
Epoch: 091, Train Loss: 0.2226479, Train Acc: 0.8882353, Test Acc: 0.7222222
Epoch: 092, Train Loss: 0.2319808, Train Acc: 0.8294118, Test Acc: 0.6666667
Epoch: 093, Train Loss: 0.2365441, Train Acc: 0.8117647, Test Acc: 0.6666667
Epoch: 094, Train Loss: 0.2389078, Train Acc: 0.8764706, Test Acc: 0.6666667
Epoch: 095, Train Loss: 0.1904467, Train Acc: 0.9000000, Test Acc: 0.7777778
Epoch: 096, Train Loss: 0.2198207, Train Acc: 0.9235294, Test Acc: 0.7777778
Epoch: 097, Train Loss: 0.2253925, Train Acc: 0.9235294, Test Acc: 0.7777778
Epoch: 098, Train Loss: 0.1958107, Train Acc: 0.8352941, Test Acc: 0.7777778
Epoch: 099, Train Loss: 0.2277206, Train Acc: 0.8294118, Test Acc: 0.7777778
Epoch: 100, Train Loss: 0.2341601, Train Acc: 0.7058824, Test Acc: 0.7222222
[ ]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = "PNA"
# Every branch must assign `model`. A missing assignment would silently keep
# training whatever model an earlier cell created.
if net == "LAF":
    model = LAFNet().to(device)
elif net == "PNA":
    model = PNANet().to(device)
elif net == "GIN":
    model = GINNet().to(device)
else:
    raise ValueError("Unknown net: {!r}".format(net))

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


for epoch in range(1, 101):
    # One optimisation pass, then accuracy on both splits.
    train_loss = train(epoch)
    train_acc, test_acc = test(train_loader), test(test_loader)
    print('Epoch: {:03d}, Train Loss: {:.7f}, '
          'Train Acc: {:.7f}, Test Acc: {:.7f}'.format(
              epoch, train_loss, train_acc, test_acc))
Epoch: 001, Train Loss: 1.3497391, Train Acc: 0.3294118, Test Acc: 0.3888889
Epoch: 002, Train Loss: 0.8684199, Train Acc: 0.3294118, Test Acc: 0.3888889
Epoch: 003, Train Loss: 0.7279473, Train Acc: 0.3294118, Test Acc: 0.3888889
Epoch: 004, Train Loss: 0.7402998, Train Acc: 0.3294118, Test Acc: 0.3888889
Epoch: 005, Train Loss: 0.7657306, Train Acc: 0.3294118, Test Acc: 0.3888889
Epoch: 006, Train Loss: 0.7453549, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 007, Train Loss: 0.5307669, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 008, Train Loss: 0.4403997, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 009, Train Loss: 0.4544508, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 010, Train Loss: 0.4488042, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 011, Train Loss: 0.4297011, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 012, Train Loss: 0.3781979, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 013, Train Loss: 0.4004532, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 014, Train Loss: 0.3619624, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 015, Train Loss: 0.3303704, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 016, Train Loss: 0.3489703, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 017, Train Loss: 0.2879844, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 018, Train Loss: 0.2992957, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 019, Train Loss: 0.2941008, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 020, Train Loss: 0.2742822, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 021, Train Loss: 0.2828649, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 022, Train Loss: 0.2448842, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 023, Train Loss: 0.2426275, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 024, Train Loss: 0.2311836, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 025, Train Loss: 0.1974013, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 026, Train Loss: 0.1801775, Train Acc: 0.6647059, Test Acc: 0.6111111
Epoch: 027, Train Loss: 0.2027490, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 028, Train Loss: 0.1469845, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 029, Train Loss: 0.1597498, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 030, Train Loss: 0.1537053, Train Acc: 0.6823529, Test Acc: 0.6666667
Epoch: 031, Train Loss: 0.1487810, Train Acc: 0.7117647, Test Acc: 0.6666667
Epoch: 032, Train Loss: 0.1435155, Train Acc: 0.7176471, Test Acc: 0.7222222
Epoch: 033, Train Loss: 0.1291888, Train Acc: 0.7176471, Test Acc: 0.7222222
Epoch: 034, Train Loss: 0.1226045, Train Acc: 0.7470588, Test Acc: 0.7222222
Epoch: 035, Train Loss: 0.1137753, Train Acc: 0.7588235, Test Acc: 0.6666667
Epoch: 036, Train Loss: 0.1119650, Train Acc: 0.8058824, Test Acc: 0.7222222
Epoch: 037, Train Loss: 0.1085063, Train Acc: 0.8235294, Test Acc: 0.7222222
Epoch: 038, Train Loss: 0.1228803, Train Acc: 0.8352941, Test Acc: 0.7222222
Epoch: 039, Train Loss: 0.0784458, Train Acc: 0.8411765, Test Acc: 0.7777778
Epoch: 040, Train Loss: 0.0854303, Train Acc: 0.8764706, Test Acc: 0.8333333
Epoch: 041, Train Loss: 0.1073735, Train Acc: 0.9176471, Test Acc: 0.8333333
Epoch: 042, Train Loss: 0.0834263, Train Acc: 0.9176471, Test Acc: 0.8333333
Epoch: 043, Train Loss: 0.0607265, Train Acc: 0.9294118, Test Acc: 0.8333333
Epoch: 044, Train Loss: 0.0718378, Train Acc: 0.9352941, Test Acc: 0.8333333
Epoch: 045, Train Loss: 0.0689468, Train Acc: 0.9411765, Test Acc: 0.8333333
Epoch: 046, Train Loss: 0.0382091, Train Acc: 0.9529412, Test Acc: 0.7777778
Epoch: 047, Train Loss: 0.0879735, Train Acc: 0.9352941, Test Acc: 0.7777778
Epoch: 048, Train Loss: 0.0623109, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 049, Train Loss: 0.0908670, Train Acc: 0.9058824, Test Acc: 0.7777778
Epoch: 050, Train Loss: 0.0681597, Train Acc: 0.9000000, Test Acc: 0.8333333
Epoch: 051, Train Loss: 0.0693567, Train Acc: 0.9176471, Test Acc: 0.8333333
Epoch: 052, Train Loss: 0.0478872, Train Acc: 0.9352941, Test Acc: 0.8333333
Epoch: 053, Train Loss: 0.0506401, Train Acc: 0.9352941, Test Acc: 0.8333333
Epoch: 054, Train Loss: 0.0308294, Train Acc: 0.9294118, Test Acc: 0.8333333
Epoch: 055, Train Loss: 0.0326454, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 056, Train Loss: 0.0250327, Train Acc: 0.9411765, Test Acc: 0.8333333
Epoch: 057, Train Loss: 0.0349234, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 058, Train Loss: 0.0380354, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 059, Train Loss: 0.0238508, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 060, Train Loss: 0.0260360, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 061, Train Loss: 0.0156592, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 062, Train Loss: 0.0397532, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 063, Train Loss: 0.0147181, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 064, Train Loss: 0.0263763, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 065, Train Loss: 0.0219646, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 066, Train Loss: 0.0174770, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 067, Train Loss: 0.0233532, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 068, Train Loss: 0.0329869, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 069, Train Loss: 0.0267206, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 070, Train Loss: 0.0195115, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 071, Train Loss: 0.0263306, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 072, Train Loss: 0.0161402, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 073, Train Loss: 0.0138596, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 074, Train Loss: 0.0176732, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 075, Train Loss: 0.0140430, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 076, Train Loss: 0.0223834, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 077, Train Loss: 0.0151263, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 078, Train Loss: 0.0113194, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 079, Train Loss: 0.0178343, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 080, Train Loss: 0.0132977, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 081, Train Loss: 0.0099823, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 082, Train Loss: 0.0103535, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 083, Train Loss: 0.0049559, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 084, Train Loss: 0.0115411, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 085, Train Loss: 0.0132454, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 086, Train Loss: 0.0139688, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 087, Train Loss: 0.0082945, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 088, Train Loss: 0.0144088, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 089, Train Loss: 0.0116169, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 090, Train Loss: 0.0115055, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 091, Train Loss: 0.0044924, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 092, Train Loss: 0.0073951, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 093, Train Loss: 0.0098597, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 094, Train Loss: 0.0071243, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 095, Train Loss: 0.0084314, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 096, Train Loss: 0.0116200, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 097, Train Loss: 0.0109158, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 098, Train Loss: 0.0088956, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 099, Train Loss: 0.0098493, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 100, Train Loss: 0.0082795, Train Acc: 0.9529412, Test Acc: 0.8333333
[ ]:
# Pick the compute device, build the requested network, and train it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = "GIN"  # which network to train: "LAF", "PNA" or "GIN"
if net == "LAF":
    model = LAFNet().to(device)
elif net == "PNA":
    model = PNANet().to(device)
elif net == "GIN":
    # The result must be bound to `model`: otherwise the optimizer below
    # would silently wrap whatever network an earlier cell left in `model`.
    model = GINNet().to(device)
else:
    raise ValueError(
        f"Unknown net {net!r}; expected 'LAF', 'PNA' or 'GIN'")

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 100
# NOTE(review): train()/test() are defined in an earlier cell; presumably
# train() reads the global `model` and `optimizer` and returns the epoch loss,
# and test(loader) returns accuracy on that loader -- confirm there.
for epoch in range(1, num_epochs + 1):
    train_loss = train(epoch)
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print('Epoch: {:03d}, Train Loss: {:.7f}, '
          'Train Acc: {:.7f}, Test Acc: {:.7f}'.format(epoch, train_loss,
                                                       train_acc, test_acc))
Epoch: 001, Train Loss: 0.1006957, Train Acc: 0.8705882, Test Acc: 0.8888889
Epoch: 002, Train Loss: 0.5604971, Train Acc: 0.9235294, Test Acc: 0.7222222
Epoch: 003, Train Loss: 0.4205183, Train Acc: 0.8588235, Test Acc: 0.6111111
Epoch: 004, Train Loss: 0.1701126, Train Acc: 0.8176471, Test Acc: 0.6111111
Epoch: 005, Train Loss: 0.1697284, Train Acc: 0.8411765, Test Acc: 0.7777778
Epoch: 006, Train Loss: 0.1547725, Train Acc: 0.8176471, Test Acc: 0.7222222
Epoch: 007, Train Loss: 0.1122712, Train Acc: 0.7529412, Test Acc: 0.7777778
Epoch: 008, Train Loss: 0.1182288, Train Acc: 0.7470588, Test Acc: 0.7222222
Epoch: 009, Train Loss: 0.1329069, Train Acc: 0.9117647, Test Acc: 0.8333333
Epoch: 010, Train Loss: 0.1019645, Train Acc: 0.9411765, Test Acc: 0.8333333
Epoch: 011, Train Loss: 0.0771604, Train Acc: 0.9529412, Test Acc: 0.7777778
Epoch: 012, Train Loss: 0.0847688, Train Acc: 0.9176471, Test Acc: 0.7777778
Epoch: 013, Train Loss: 0.0684039, Train Acc: 0.9235294, Test Acc: 0.7777778
Epoch: 014, Train Loss: 0.0651711, Train Acc: 0.9470588, Test Acc: 0.7777778
Epoch: 015, Train Loss: 0.0518811, Train Acc: 0.9588235, Test Acc: 0.7777778
Epoch: 016, Train Loss: 0.0677568, Train Acc: 0.9705882, Test Acc: 0.7777778
Epoch: 017, Train Loss: 0.0393111, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 018, Train Loss: 0.0367973, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 019, Train Loss: 0.0366539, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 020, Train Loss: 0.0568547, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 021, Train Loss: 0.0447065, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 022, Train Loss: 0.0352459, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 023, Train Loss: 0.0249647, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 024, Train Loss: 0.0145648, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 025, Train Loss: 0.0205373, Train Acc: 0.9823529, Test Acc: 0.7777778
Epoch: 026, Train Loss: 0.0161799, Train Acc: 0.9647059, Test Acc: 0.7777778
Epoch: 027, Train Loss: 0.0125704, Train Acc: 0.9588235, Test Acc: 0.7777778
Epoch: 028, Train Loss: 0.0112206, Train Acc: 0.9411765, Test Acc: 0.8333333
Epoch: 029, Train Loss: 0.0095180, Train Acc: 0.9411765, Test Acc: 0.8333333
Epoch: 030, Train Loss: 0.0139799, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 031, Train Loss: 0.0133235, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 032, Train Loss: 0.0116212, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 033, Train Loss: 0.0074385, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 034, Train Loss: 0.0063465, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 035, Train Loss: 0.0093689, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 036, Train Loss: 0.0118155, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 037, Train Loss: 0.0166583, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 038, Train Loss: 0.0108432, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 039, Train Loss: 0.0092749, Train Acc: 0.9529412, Test Acc: 0.7777778
Epoch: 040, Train Loss: 0.0081560, Train Acc: 0.9588235, Test Acc: 0.7222222
Epoch: 041, Train Loss: 0.0145553, Train Acc: 0.9588235, Test Acc: 0.7222222
Epoch: 042, Train Loss: 0.0051442, Train Acc: 0.9764706, Test Acc: 0.7222222
Epoch: 043, Train Loss: 0.0128016, Train Acc: 0.9823529, Test Acc: 0.7222222
Epoch: 044, Train Loss: 0.0083365, Train Acc: 0.9823529, Test Acc: 0.6666667
Epoch: 045, Train Loss: 0.0449262, Train Acc: 0.9470588, Test Acc: 0.7222222
Epoch: 046, Train Loss: 0.1241174, Train Acc: 0.9352941, Test Acc: 0.7222222
Epoch: 047, Train Loss: 0.0577372, Train Acc: 0.9588235, Test Acc: 0.7222222
Epoch: 048, Train Loss: 0.0158565, Train Acc: 0.9588235, Test Acc: 0.7222222
Epoch: 049, Train Loss: 0.0264535, Train Acc: 0.9647059, Test Acc: 0.7222222
Epoch: 050, Train Loss: 0.0493911, Train Acc: 0.9470588, Test Acc: 0.7777778
Epoch: 051, Train Loss: 0.0307947, Train Acc: 0.9647059, Test Acc: 0.7777778
Epoch: 052, Train Loss: 0.0502689, Train Acc: 0.9705882, Test Acc: 0.7777778
Epoch: 053, Train Loss: 0.0220471, Train Acc: 0.9411765, Test Acc: 0.7777778
Epoch: 054, Train Loss: 0.0271277, Train Acc: 0.9294118, Test Acc: 0.8333333
Epoch: 055, Train Loss: 0.0193326, Train Acc: 0.9529412, Test Acc: 0.8888889
Epoch: 056, Train Loss: 0.0085988, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 057, Train Loss: 0.0200223, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 058, Train Loss: 0.0065113, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 059, Train Loss: 0.0118877, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 060, Train Loss: 0.0138910, Train Acc: 0.9470588, Test Acc: 0.8888889
Epoch: 061, Train Loss: 0.0099632, Train Acc: 0.9411765, Test Acc: 0.8888889
Epoch: 062, Train Loss: 0.0104697, Train Acc: 0.9411765, Test Acc: 0.8888889
Epoch: 063, Train Loss: 0.0117506, Train Acc: 0.9411765, Test Acc: 0.8888889
Epoch: 064, Train Loss: 0.0129451, Train Acc: 0.9470588, Test Acc: 0.8888889
Epoch: 065, Train Loss: 0.0049019, Train Acc: 0.9529412, Test Acc: 0.8888889
Epoch: 066, Train Loss: 0.0059774, Train Acc: 0.9529412, Test Acc: 0.8888889
Epoch: 067, Train Loss: 0.0029972, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 068, Train Loss: 0.0070204, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 069, Train Loss: 0.0058905, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 070, Train Loss: 0.0122656, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 071, Train Loss: 0.0080602, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 072, Train Loss: 0.0048456, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 073, Train Loss: 0.0024820, Train Acc: 0.9647059, Test Acc: 0.8888889
Epoch: 074, Train Loss: 0.0066221, Train Acc: 0.9647059, Test Acc: 0.8888889
Epoch: 075, Train Loss: 0.0054791, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 076, Train Loss: 0.0041069, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 077, Train Loss: 0.0039224, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 078, Train Loss: 0.0038528, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 079, Train Loss: 0.0026217, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 080, Train Loss: 0.0029335, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 081, Train Loss: 0.0046612, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 082, Train Loss: 0.0039182, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 083, Train Loss: 0.0048634, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 084, Train Loss: 0.0021048, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 085, Train Loss: 0.0055200, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 086, Train Loss: 0.0030238, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 087, Train Loss: 0.0048484, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 088, Train Loss: 0.0029367, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 089, Train Loss: 0.0014313, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 090, Train Loss: 0.0061374, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 091, Train Loss: 0.0055641, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 092, Train Loss: 0.0049570, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 093, Train Loss: 0.0008855, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 094, Train Loss: 0.0033897, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 095, Train Loss: 0.0020113, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 096, Train Loss: 0.0019971, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 097, Train Loss: 0.0073497, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 098, Train Loss: 0.0037542, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 099, Train Loss: 0.0011019, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 100, Train Loss: 0.0037228, Train Acc: 0.9647059, Test Acc: 0.8888889