[1]:
# Install required packages.
import os
import torch

# Expose the installed torch version as $TORCH so that the optional pip
# commands below can pick the matching PyG wheel index.
torch_version = torch.__version__
os.environ['TORCH'] = torch_version
print(torch_version)

# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# !pip install -q captum

# Helper function for visualization.
%matplotlib inline
import matplotlib.pyplot as plt
1.11.0

6. Explaining GNN Model Predictions using Captum

In this tutorial we demonstrate how to apply feature attribution methods to graphs. Specifically, we try to find the most important edges for each instance prediction.

We use the Mutagenicity dataset from TUDatasets. This dataset consists of 4337 molecule graphs where the task is to predict the molecule mutagenicity.

Loading the dataset

We load the dataset and use 10% of the data as the test split.

[2]:
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset

path = '.'
# Shuffle once, then hold out the first 10% of the shuffled graphs for testing.
# NOTE(review): no random seed is set above, so the split presumably differs
# between runs.
dataset = TUDataset(path, name='Mutagenicity').shuffle()
num_test = len(dataset) // 10
test_dataset, train_dataset = dataset[:num_test], dataset[num_test:]

batch_size = 128
test_loader = DataLoader(test_dataset, batch_size=batch_size)
train_loader = DataLoader(train_dataset, batch_size=batch_size)

Visualizing the data

We define some utility functions for visualizing the molecules and draw a random molecule.

[3]:
import networkx as nx
import numpy as np

from torch_geometric.utils import to_networkx


def draw_molecule(g, edge_mask=None, draw_edge_labels=False):
    """Draw a molecule graph, optionally highlighting edges by importance.

    Args:
        g: networkx graph whose nodes carry a ``'name'`` attribute (the atom
            symbol), e.g. the output of ``to_molecule``. It is copied and
            converted to an undirected graph; the input is not modified.
        edge_mask: optional mapping ``(u, v) -> importance`` for undirected
            edges. A key may be stored in either orientation. Each value sets
            the edge color (``Blues`` colormap) and the line width
            (``value * 10``).
        draw_edge_labels: if True and ``edge_mask`` is given, also print the
            importance value (two decimals, in red) on each edge.
    """
    g = g.copy().to_undirected()
    node_labels = {u: attrs['name'] for u, attrs in g.nodes(data=True)}

    # A planar embedding gives a crossing-free starting point for the spring
    # layout. planar_layout raises for non-planar graphs; in that case let
    # spring_layout start from its own initial positions instead of crashing.
    try:
        init_pos = nx.planar_layout(g)
    except nx.NetworkXException:
        init_pos = None
    pos = nx.spring_layout(g, pos=init_pos)

    if edge_mask is None:
        edge_color = 'black'
        widths = None
    else:
        # Accept (u, v) or (v, u): the orientation networkx reports for an
        # undirected edge need not match how the mask keys were normalized.
        edge_color = [edge_mask[(u, v)] if (u, v) in edge_mask else edge_mask[(v, u)]
                      for u, v in g.edges()]
        widths = [x * 10 for x in edge_color]
    nx.draw(g, pos=pos, labels=node_labels, width=widths,
            edge_color=edge_color, edge_cmap=plt.cm.Blues,
            node_color='azure')

    if draw_edge_labels and edge_mask is not None:
        edge_labels = {k: ('%.2f' % v) for k, v in edge_mask.items()}
        nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels,
                                     font_color='red')
    plt.show()


def to_molecule(data):
    """Convert a PyG graph into a networkx graph labeled with atom symbols.

    Each node's ``'x'`` attribute is replaced by a ``'name'`` attribute
    holding the atom symbol at the position of the 1.0 entry in ``x``.
    Assumes ``x`` is a one-hot list over the 14 atom types below; verify
    against the dataset's node-label encoding.
    """
    atom_symbols = ['C', 'O', 'Cl', 'H', 'N', 'F',
                    'Br', 'S', 'P', 'I', 'Na', 'K', 'Li', 'Ca']
    mol_graph = to_networkx(data, node_attrs=['x'])
    # Use a distinct loop name so the `data` parameter is not shadowed.
    for _, attrs in mol_graph.nodes(data=True):
        one_hot = attrs.pop('x')
        attrs['name'] = atom_symbols[one_hot.index(1.0)]
    return mol_graph

Sample visualization

We sample a single molecule from train_dataset and visualize it.

[4]:
import random

# Draw one random training molecule. random.choice indexes the dataset
# directly (it only needs len() and integer indexing). This avoids
# materializing every training graph into a Python list first, and it
# consumes the same random draw as before.
data = random.choice(train_dataset)
mol = to_molecule(data)
plt.figure(figsize=(10, 5))
draw_molecule(mol)
../../../_images/ipynbs_colabs_pytorch_geometric_6_GNN_Explanation_7_0.png

Training the model

In the next section, we train a GNN model with 5 convolution layers. We use GraphConv, which accepts edge_weight as an argument; many convolution layers in PyTorch Geometric support this argument.

Define the model

[5]:
import torch
from torch.nn import Linear
import torch.nn.functional as F

from torch_geometric.nn import GraphConv, global_add_pool

class Net(torch.nn.Module):
    """Graph classifier built from GraphConv layers.

    Architecture:
        - five GraphConv layers with ReLU;
        - global add pooling;
        - Linear -> ReLU -> dropout(0.5) -> Linear;
        - log-softmax.

    Args:
        dim: hidden size used by every convolution and by the first linear layer.
        num_features: size of the input node features. Defaults to
            ``dataset.num_features`` (the notebook's global dataset).
        num_classes: number of output classes. Defaults to
            ``dataset.num_classes``.
    """

    def __init__(self, dim, num_features=None, num_classes=None):
        super().__init__()

        # Fall back to the global dataset so existing `Net(dim=...)` calls keep working.
        if num_features is None:
            num_features = dataset.num_features
        if num_classes is None:
            num_classes = dataset.num_classes
        self.dim = dim

        self.conv1 = GraphConv(num_features, dim)
        self.conv2 = GraphConv(dim, dim)
        self.conv3 = GraphConv(dim, dim)
        self.conv4 = GraphConv(dim, dim)
        self.conv5 = GraphConv(dim, dim)

        self.lin1 = Linear(dim, dim)
        self.lin2 = Linear(dim, num_classes)

    def forward(self, x, edge_index, batch, edge_weight=None):
        """Return per-graph class log-probabilities.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity.
            batch: graph id of each node, used by ``global_add_pool``.
            edge_weight: optional per-edge weights passed to every GraphConv.
                The explanation code feeds the attribution mask in here.
        """
        x = self.conv1(x, edge_index, edge_weight).relu()
        x = self.conv2(x, edge_index, edge_weight).relu()
        x = self.conv3(x, edge_index, edge_weight).relu()
        x = self.conv4(x, edge_index, edge_weight).relu()
        x = self.conv5(x, edge_index, edge_weight).relu()
        x = global_add_pool(x, batch)
        x = self.lin1(x).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        # Log-probabilities pair with F.nll_loss in the training loop.
        return F.log_softmax(x, dim=-1)

Define train and test functions

[6]:
def train(epoch):
    """Run one training epoch over `train_loader`.

    The learning rate of every optimizer parameter group is halved once, when
    `epoch == 51`.

    Returns:
        Mean NLL loss per training graph for this epoch.
    """
    model.train()

    if epoch == 51:
        for group in optimizer.param_groups:
            group['lr'] *= 0.5

    total_loss = 0
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        log_probs = model(batch_data.x, batch_data.edge_index, batch_data.batch)
        loss = F.nll_loss(log_probs, batch_data.y)
        loss.backward()
        # Weight each batch's mean loss by its graph count so the result is a per-graph mean.
        total_loss += loss.item() * batch_data.num_graphs
        optimizer.step()
    return total_loss / len(train_dataset)


@torch.no_grad()
def test(loader):
    """Return the classification accuracy of `model` on every graph in `loader`.

    Runs in eval mode (dropout off) and without autograd tracking. Evaluation
    never calls backward, so recording a graph would only waste memory and time.
    """
    model.eval()

    correct = 0
    for data in loader:
        data = data.to(device)
        output = model(data.x, data.edge_index, data.batch)
        pred = output.argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)

Train the model for 100 epochs

The test accuracy should be around 80% by the end of training.

[7]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = Net(dim=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Epochs are 1-based; train() keys its learning-rate halving on the epoch number.
num_epochs = 100
for epoch in range(1, num_epochs + 1):
    loss = train(epoch)
    train_acc, test_acc = test(train_loader), test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Loss: 0.7109, Train Acc: 0.6539, Test Acc: 0.6651
Epoch: 002, Loss: 0.6226, Train Acc: 0.6862, Test Acc: 0.6859
Epoch: 003, Loss: 0.5847, Train Acc: 0.7111, Test Acc: 0.6975
Epoch: 004, Loss: 0.5810, Train Acc: 0.7241, Test Acc: 0.7182
Epoch: 005, Loss: 0.5675, Train Acc: 0.7364, Test Acc: 0.7206
Epoch: 006, Loss: 0.5437, Train Acc: 0.7400, Test Acc: 0.7113
Epoch: 007, Loss: 0.5419, Train Acc: 0.7564, Test Acc: 0.7413
Epoch: 008, Loss: 0.5263, Train Acc: 0.7515, Test Acc: 0.7159
Epoch: 009, Loss: 0.5201, Train Acc: 0.7731, Test Acc: 0.7344
Epoch: 010, Loss: 0.5145, Train Acc: 0.7861, Test Acc: 0.7760
Epoch: 011, Loss: 0.5038, Train Acc: 0.7828, Test Acc: 0.7529
Epoch: 012, Loss: 0.4997, Train Acc: 0.7897, Test Acc: 0.7644
Epoch: 013, Loss: 0.4861, Train Acc: 0.7951, Test Acc: 0.7598
Epoch: 014, Loss: 0.4728, Train Acc: 0.7966, Test Acc: 0.7575
Epoch: 015, Loss: 0.4796, Train Acc: 0.7969, Test Acc: 0.7621
Epoch: 016, Loss: 0.4708, Train Acc: 0.8015, Test Acc: 0.7691
Epoch: 017, Loss: 0.4733, Train Acc: 0.8058, Test Acc: 0.7760
Epoch: 018, Loss: 0.4614, Train Acc: 0.8074, Test Acc: 0.7714
Epoch: 019, Loss: 0.4577, Train Acc: 0.8135, Test Acc: 0.7783
Epoch: 020, Loss: 0.4407, Train Acc: 0.8243, Test Acc: 0.7945
Epoch: 021, Loss: 0.4447, Train Acc: 0.8215, Test Acc: 0.7760
Epoch: 022, Loss: 0.4353, Train Acc: 0.8291, Test Acc: 0.7852
Epoch: 023, Loss: 0.4267, Train Acc: 0.8212, Test Acc: 0.7829
Epoch: 024, Loss: 0.4377, Train Acc: 0.8281, Test Acc: 0.7898
Epoch: 025, Loss: 0.4326, Train Acc: 0.8320, Test Acc: 0.7875
Epoch: 026, Loss: 0.4175, Train Acc: 0.8353, Test Acc: 0.7875
Epoch: 027, Loss: 0.4142, Train Acc: 0.8381, Test Acc: 0.8037
Epoch: 028, Loss: 0.4050, Train Acc: 0.8394, Test Acc: 0.7921
Epoch: 029, Loss: 0.4101, Train Acc: 0.8394, Test Acc: 0.7875
Epoch: 030, Loss: 0.4103, Train Acc: 0.8389, Test Acc: 0.7945
Epoch: 031, Loss: 0.4033, Train Acc: 0.8412, Test Acc: 0.8014
Epoch: 032, Loss: 0.3946, Train Acc: 0.8448, Test Acc: 0.7852
Epoch: 033, Loss: 0.3957, Train Acc: 0.8502, Test Acc: 0.8083
Epoch: 034, Loss: 0.3920, Train Acc: 0.8397, Test Acc: 0.7991
Epoch: 035, Loss: 0.3902, Train Acc: 0.8445, Test Acc: 0.7991
Epoch: 036, Loss: 0.3948, Train Acc: 0.8404, Test Acc: 0.7945
Epoch: 037, Loss: 0.3874, Train Acc: 0.8550, Test Acc: 0.8014
Epoch: 038, Loss: 0.3776, Train Acc: 0.8543, Test Acc: 0.8106
Epoch: 039, Loss: 0.3715, Train Acc: 0.8581, Test Acc: 0.8060
Epoch: 040, Loss: 0.3705, Train Acc: 0.8514, Test Acc: 0.8060
Epoch: 041, Loss: 0.3660, Train Acc: 0.8589, Test Acc: 0.8106
Epoch: 042, Loss: 0.3659, Train Acc: 0.8624, Test Acc: 0.8037
Epoch: 043, Loss: 0.3685, Train Acc: 0.8617, Test Acc: 0.7991
Epoch: 044, Loss: 0.3612, Train Acc: 0.8586, Test Acc: 0.7991
Epoch: 045, Loss: 0.3624, Train Acc: 0.8624, Test Acc: 0.8245
Epoch: 046, Loss: 0.3590, Train Acc: 0.8599, Test Acc: 0.8060
Epoch: 047, Loss: 0.3436, Train Acc: 0.8686, Test Acc: 0.8037
Epoch: 048, Loss: 0.3425, Train Acc: 0.8642, Test Acc: 0.8222
Epoch: 049, Loss: 0.3469, Train Acc: 0.8665, Test Acc: 0.8176
Epoch: 050, Loss: 0.3345, Train Acc: 0.8699, Test Acc: 0.7991
Epoch: 051, Loss: 0.3262, Train Acc: 0.8753, Test Acc: 0.8129
Epoch: 052, Loss: 0.3198, Train Acc: 0.8755, Test Acc: 0.8222
Epoch: 053, Loss: 0.3167, Train Acc: 0.8735, Test Acc: 0.8245
Epoch: 054, Loss: 0.3173, Train Acc: 0.8730, Test Acc: 0.8337
Epoch: 055, Loss: 0.3125, Train Acc: 0.8747, Test Acc: 0.8314
Epoch: 056, Loss: 0.3090, Train Acc: 0.8783, Test Acc: 0.8268
Epoch: 057, Loss: 0.3097, Train Acc: 0.8837, Test Acc: 0.8360
Epoch: 058, Loss: 0.3018, Train Acc: 0.8863, Test Acc: 0.8337
Epoch: 059, Loss: 0.3048, Train Acc: 0.8863, Test Acc: 0.8268
Epoch: 060, Loss: 0.3075, Train Acc: 0.8822, Test Acc: 0.8360
Epoch: 061, Loss: 0.3042, Train Acc: 0.8860, Test Acc: 0.8360
Epoch: 062, Loss: 0.2940, Train Acc: 0.8873, Test Acc: 0.8383
Epoch: 063, Loss: 0.3028, Train Acc: 0.8806, Test Acc: 0.8291
Epoch: 064, Loss: 0.2970, Train Acc: 0.8809, Test Acc: 0.8245
Epoch: 065, Loss: 0.2973, Train Acc: 0.8904, Test Acc: 0.8199
Epoch: 066, Loss: 0.2938, Train Acc: 0.8888, Test Acc: 0.8314
Epoch: 067, Loss: 0.2940, Train Acc: 0.8883, Test Acc: 0.8268
Epoch: 068, Loss: 0.2930, Train Acc: 0.8899, Test Acc: 0.8360
Epoch: 069, Loss: 0.2825, Train Acc: 0.8911, Test Acc: 0.8337
Epoch: 070, Loss: 0.2839, Train Acc: 0.8934, Test Acc: 0.8199
Epoch: 071, Loss: 0.2817, Train Acc: 0.8911, Test Acc: 0.8360
Epoch: 072, Loss: 0.2835, Train Acc: 0.8932, Test Acc: 0.8314
Epoch: 073, Loss: 0.2915, Train Acc: 0.8927, Test Acc: 0.8222
Epoch: 074, Loss: 0.2787, Train Acc: 0.8965, Test Acc: 0.8245
Epoch: 075, Loss: 0.2788, Train Acc: 0.8934, Test Acc: 0.8268
Epoch: 076, Loss: 0.2808, Train Acc: 0.8909, Test Acc: 0.8291
Epoch: 077, Loss: 0.2809, Train Acc: 0.8832, Test Acc: 0.8222
Epoch: 078, Loss: 0.2886, Train Acc: 0.8947, Test Acc: 0.8314
Epoch: 079, Loss: 0.2717, Train Acc: 0.8922, Test Acc: 0.8199
Epoch: 080, Loss: 0.2808, Train Acc: 0.8952, Test Acc: 0.8314
Epoch: 081, Loss: 0.2700, Train Acc: 0.8888, Test Acc: 0.8360
Epoch: 082, Loss: 0.2687, Train Acc: 0.8970, Test Acc: 0.8268
Epoch: 083, Loss: 0.2753, Train Acc: 0.8909, Test Acc: 0.8268
Epoch: 084, Loss: 0.2717, Train Acc: 0.8927, Test Acc: 0.8291
Epoch: 085, Loss: 0.2644, Train Acc: 0.8991, Test Acc: 0.8337
Epoch: 086, Loss: 0.2717, Train Acc: 0.9011, Test Acc: 0.8337
Epoch: 087, Loss: 0.2650, Train Acc: 0.8957, Test Acc: 0.8291
Epoch: 088, Loss: 0.2662, Train Acc: 0.8993, Test Acc: 0.8337
Epoch: 089, Loss: 0.2629, Train Acc: 0.8942, Test Acc: 0.8291
Epoch: 090, Loss: 0.2681, Train Acc: 0.8998, Test Acc: 0.8360
Epoch: 091, Loss: 0.2920, Train Acc: 0.8950, Test Acc: 0.8337
Epoch: 092, Loss: 0.2706, Train Acc: 0.8855, Test Acc: 0.8199
Epoch: 093, Loss: 0.2753, Train Acc: 0.8916, Test Acc: 0.8245
Epoch: 094, Loss: 0.2655, Train Acc: 0.9024, Test Acc: 0.8337
Epoch: 095, Loss: 0.2573, Train Acc: 0.9039, Test Acc: 0.8245
Epoch: 096, Loss: 0.2532, Train Acc: 0.9024, Test Acc: 0.8245
Epoch: 097, Loss: 0.2481, Train Acc: 0.9022, Test Acc: 0.8337
Epoch: 098, Loss: 0.2493, Train Acc: 0.9022, Test Acc: 0.8222
Epoch: 099, Loss: 0.2471, Train Acc: 0.8973, Test Acc: 0.8152
Epoch: 100, Loss: 0.2455, Train Acc: 0.9088, Test Acc: 0.8222

Explaining the predictions

Now we look at two popular attribution methods. First, we calculate the gradient of the output with respect to the edge weights \(w_{e_i}\). Edge weights are initially one for all edges. For the saliency method, we use the absolute value of the gradient as the attribution value for each edge:

\[\text{Attribution}_{e_i} = \left|\frac{\partial F(x)}{\partial w_{e_i}}\right|\]

Where \(x\) is the input and \(F(x)\) is the output of the GNN model on input \(x\).

For Integrated Gradients method, we interpolate between the current input and a baseline input where the weight of all edges is zero and accumulate the gradient values for each edge:

\[\text{Attribution}_{e_i} = \int_{\alpha=0}^{1} \frac{\partial F(x_{\alpha})}{\partial w_{e_i}} \, d\alpha\]

Where \(x_{\alpha}\) is the same as the original input graph but with the weight of every edge set to \(\alpha\). The complete formulation of Integrated Gradients is more complicated, but since our initial edge weights are all one and the baseline is zero, it simplifies to the formula above. You can read more about this method here. Of course, the integral cannot be computed exactly; it is approximated by a discrete sum.

We use the captum library to calculate the attribution values. We define the model_forward function, which builds the batch argument under the assumption that we are explaining only a single graph at a time.

[8]:
from captum.attr import Saliency, IntegratedGradients

def model_forward(edge_mask, data):
    """Forward function handed to captum: run `model` on one graph.

    `edge_mask` is used as the edge weights. Every node is assigned to graph 0,
    which is valid only because a single graph is explained at a time.
    """
    node_to_graph = torch.zeros(data.x.shape[0], dtype=torch.long, device=device)
    return model(data.x, data.edge_index, node_to_graph, edge_mask)


def explain(method, data, target=0):
    """Compute a per-edge attribution mask for a single graph.

    Args:
        method: ``'ig'`` for Integrated Gradients or ``'saliency'`` for Saliency.
        data: the graph to explain; expected to already be on ``device``.
        target: index of the output class whose score is attributed.

    Returns:
        NumPy array with one non-negative value per column of
        ``data.edge_index``. The values are divided by their maximum, so the
        largest is 1.0, unless every value is zero.

    Raises:
        ValueError: if ``method`` is not recognized.
    """
    # Attributions must be deterministic: make sure dropout is disabled
    # rather than relying on whichever mode the model was last left in.
    model.eval()

    # One weight per directed edge; all-ones reproduces the unweighted graph.
    # Create it directly on the device so it stays a leaf tensor. The old
    # `.requires_grad_(True).to(device)` produced a non-leaf copy on GPU.
    input_mask = torch.ones(data.edge_index.shape[1], device=device,
                            requires_grad=True)
    if method == 'ig':
        ig = IntegratedGradients(model_forward)
        # NOTE(review): the mask has no batch dimension, so captum presumably
        # treats each edge as a sample. Chunking by the edge count should
        # then evaluate one interpolation step per forward call; confirm
        # against captum's internal_batch_size semantics.
        mask = ig.attribute(input_mask, target=target,
                            additional_forward_args=(data,),
                            internal_batch_size=data.edge_index.shape[1])
    elif method == 'saliency':
        saliency = Saliency(model_forward)
        mask = saliency.attribute(input_mask, target=target,
                                  additional_forward_args=(data,))
    else:
        raise ValueError('Unknown explanation method')

    edge_mask = np.abs(mask.detach().cpu().numpy())
    if edge_mask.max() > 0:  # avoid division by zero
        edge_mask = edge_mask / edge_mask.max()
    return edge_mask

Finally, we take a random sample from the test dataset and run the explanation methods. For a simpler visualization, we make the graph undirected and merge the explanations of each edge in both directions.

It is known that the NO2 (nitro) substructure makes molecules mutagenic in many cases, and you can verify this with the model explanations.

Mutagenic molecules have label 0 in this dataset. We only sample from those molecules, but you can change the code to see the explanations for the other class as well.

In this visualization, edge color and thickness represent importance. You can also display the numeric values by passing draw_edge_labels=True to the draw_molecule function.

As you can see, Integrated Gradients tends to produce more accurate explanations.

[9]:
import random
from collections import defaultdict

def aggregate_edge_directions(edge_mask, data):
    """Merge per-direction edge scores into undirected edge scores.

    `edge_mask[i]` belongs to edge column `i` of `data.edge_index`. Scores of
    (u, v) and (v, u) are summed under the single key `(min(u, v), max(u, v))`.
    Returns a `defaultdict(float)` keyed by those node-id pairs.
    """
    merged = defaultdict(float)
    sources, targets = data.edge_index
    for score, src, dst in zip(edge_mask, sources, targets):
        a, b = src.item(), dst.item()
        key = (a, b) if a <= b else (b, a)
        merged[key] += score
    return merged


# Pick a random mutagenic molecule (label 0) from the test split and explain
# the model's class-0 score with both attribution methods.
mutagenic_graphs = [t for t in test_dataset if t.y.item() == 0]
data = random.choice(mutagenic_graphs)
data = data.to(device)
mol = to_molecule(data)

explanation_methods = {'Integrated Gradients': 'ig', 'Saliency': 'saliency'}
for title, method in explanation_methods.items():
    undirected_mask = aggregate_edge_directions(explain(method, data, target=0), data)
    plt.figure(figsize=(10, 5))
    plt.title(title)
    draw_molecule(mol, undirected_mask)
../../../_images/ipynbs_colabs_pytorch_geometric_6_GNN_Explanation_18_0.png
../../../_images/ipynbs_colabs_pytorch_geometric_6_GNN_Explanation_18_1.png