Tutorial 11: DeepWalk and node2vec - Implementation details

Code:

Setup

[1]:
# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
[2]:
from torch_geometric.nn import Node2Vec
import os.path as osp
import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torch_geometric.datasets import Planetoid
from tqdm.notebook import tqdm
[3]:
# Name of the Planetoid dataset to load. It has its own variable so that
# `dataset` only ever refers to the dataset object, never to a string.
dataset_name = 'Cora'
# Root directory handed to Planetoid: ./data/Cora relative to the working dir.
path = osp.join('.', 'data', dataset_name)
dataset = Planetoid(path, dataset_name)
# Take the first (index 0) graph object of the dataset.
data = dataset[0]
[4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Node2Vec model used in the next cells to inspect the walks it samples.
model = Node2Vec(
    data.edge_index,
    embedding_dim=128,
    walk_length=20,            # length of each random walk
    context_size=10,           # window size; matches the 10 columns of each batch
    walks_per_node=20,         # number of walks sampled per node
    num_negative_samples=1,    # negative samples per positive sample
    p=200,                     # bias parameter: return parameter
    q=1,                       # bias parameter: in-out parameter
    sparse=True,
).to(device)

Random walks

The data loader

[5]:
# Data loader provided by the model: iterating it yields (pos_rw, neg_rw)
# pairs of walk tensors. batch_size=128 and shuffle=True are passed through;
# num_workers=4 presumably runs the sampling in 4 worker processes (standard
# DataLoader semantics) - confirm for the installed PyG version.
loader = model.loader(batch_size=128, shuffle=True, num_workers=4)
[6]:
# One full pass over the loader: print each batch's index together with the
# shapes of its positive and negative walk tensors.
for idx, (pos_rw, neg_rw) in enumerate(loader):
    print(idx, pos_rw.shape, neg_rw.shape)
0 torch.Size([28160, 10]) torch.Size([28160, 10])
1 torch.Size([28160, 10]) torch.Size([28160, 10])
2 torch.Size([28160, 10]) torch.Size([28160, 10])
3 torch.Size([28160, 10]) torch.Size([28160, 10])
4 torch.Size([28160, 10]) torch.Size([28160, 10])
5 torch.Size([28160, 10]) torch.Size([28160, 10])
6 torch.Size([28160, 10]) torch.Size([28160, 10])
7 torch.Size([28160, 10]) torch.Size([28160, 10])
8 torch.Size([28160, 10]) torch.Size([28160, 10])
9 torch.Size([28160, 10]) torch.Size([28160, 10])
10 torch.Size([28160, 10]) torch.Size([28160, 10])
11 torch.Size([28160, 10]) torch.Size([28160, 10])
12 torch.Size([28160, 10]) torch.Size([28160, 10])
13 torch.Size([28160, 10]) torch.Size([28160, 10])
14 torch.Size([28160, 10]) torch.Size([28160, 10])
15 torch.Size([28160, 10]) torch.Size([28160, 10])
16 torch.Size([28160, 10]) torch.Size([28160, 10])
17 torch.Size([28160, 10]) torch.Size([28160, 10])
18 torch.Size([28160, 10]) torch.Size([28160, 10])
19 torch.Size([28160, 10]) torch.Size([28160, 10])
20 torch.Size([28160, 10]) torch.Size([28160, 10])
21 torch.Size([4400, 10]) torch.Size([4400, 10])
[7]:
# Take only the first batch of a fresh pass over the loader; enumerate()
# pairs it with its index, so idx is 0 here.
idx, (pos_rw, neg_rw) = next(enumerate(loader))
[8]:
# Index of the batch taken in the previous cell.
idx
[8]:
0
[9]:
# Shapes of the positive and negative walk tensors of that batch.
(pos_rw.shape, neg_rw.shape)
[9]:
(torch.Size([28160, 10]), torch.Size([28160, 10]))
[10]:
# Positive samples: an integer tensor with one row per sample.
pos_rw
[10]:
tensor([[2571, 1386, 2570,  ..., 2045,  603, 1873],
        [1105,  578, 1974,  ...,  711, 2111,  711],
        [ 476,  306,  112,  ..., 1013,   69, 2189],
        ...,
        [2056, 1841, 2056,  ...,  751,  736,  751],
        [ 333, 1358, 1620,  ...,  562,  704, 2113],
        [ 702, 2069, 1377,  ..., 2388,  175, 2388]])
[11]:
# Negative samples: same shape as pos_rw.
neg_rw
[11]:
tensor([[2571,  792, 2686,  ...,   12, 1023, 2007],
        [1105, 2425,  283,  ..., 2594, 1697, 1755],
        [ 476,  415,  554,  ..., 2019, 1847,  369],
        ...,
        [ 409,  421, 2600,  ..., 2259, 2601,  270],
        [  67, 1305, 1259,  ..., 2100, 1527, 2139],
        [2404,   79, 2390,  ..., 1455,  248, 2223]])

Visualization

[12]:
import networkx as nx

# edge_index has shape (2, num_edges); transposing gives one (source, target)
# row per edge, which is turned into a tuple for networkx.
edge_tuples = [tuple(x) for x in data.edge_index.numpy().transpose()]
G = nx.from_edgelist(edge_tuples)
# spring_layout starts from random positions; a fixed seed makes the layout,
# and therefore the figure drawn from it, reproducible across runs.
pos = nx.spring_layout(G, center=[0.5, 0.5], seed=0)
# Store each node's 2-D position on the graph under the 'pos' attribute.
nx.set_node_attributes(G, pos, 'pos')
[13]:
# First positive walk of the first batch: a batch is (pos_rw, neg_rw), so
# [0] selects pos_rw and the second [0] selects its first row.
nodelist = next(iter(loader))[0][0].tolist()

# Represent the walk as a path graph: path node k is the k-th step of the
# walk, placed at the layout position of the graph node visited at that step.
walk = nx.path_graph(len(nodelist))
step_positions = {step: pos[node_id] for step, node_id in enumerate(nodelist)}
nx.set_node_attributes(walk, step_positions, 'pos')

# Drawing options shared by both panels for the walk itself.
walk_style = dict(
    pos=nx.get_node_attributes(walk, 'pos'),
    node_size=40,
    node_color='r',
    width=2,
    edge_color='r',
)

fig = plt.figure(figsize=(20, 10))

# Left panel: all graph nodes in faint blue with the walk drawn over them.
ax = fig.add_subplot(1, 2, 1)
nx.draw_networkx_nodes(
    G,
    pos=nx.get_node_attributes(G, 'pos'),
    ax=ax,
    node_size=5,
    alpha=0.3,
    node_color='b',
)
nx.draw(walk, ax=ax, **walk_style)

# Right panel: the walk on its own.
ax = fig.add_subplot(1, 2, 2)
nx.draw(walk, ax=ax, **walk_style)
../../../_images/ipynbs_colabs_pyg_tutorial_project_Tutorial11_17_0.png

Training

Model definition

[14]:
# Model that is actually trained. Compared with the exploration model above
# it uses walks_per_node=10 and p=1, q=1 (equal bias parameters).
model = Node2Vec(
    data.edge_index,
    embedding_dim=128,
    walk_length=20,
    context_size=10,
    walks_per_node=10,
    num_negative_samples=1,
    p=1,
    q=1,
    sparse=True,
).to(device)

# Loader yielding (positive walks, negative walks) batches for this model.
loader = model.loader(batch_size=128, shuffle=True, num_workers=4)

# SparseAdam is paired with sparse=True above, which presumably makes the
# embedding gradients sparse - confirm against the PyG and PyTorch docs.
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

Training function

[15]:
def train():
    """Train the model for one pass over ``loader``.

    Uses the notebook-level ``model``, ``loader``, ``optimizer`` and
    ``device``. Every batch is moved to ``device``, the model's loss is
    back-propagated and one optimizer step is taken.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    batch_losses = []
    for pos_batch, neg_batch in tqdm(loader):
        pos_batch = pos_batch.to(device)
        neg_batch = neg_batch.to(device)
        optimizer.zero_grad()
        batch_loss = model.loss(pos_batch, neg_batch)
        batch_loss.backward()
        optimizer.step()
        # .item() stores a plain Python number, not the tensor and its graph.
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

Test function

[16]:
# Summary of the graph object taken from the dataset.
data
[16]:
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
[17]:
# Per-node labels (used as classification targets in test() below).
data.y
[17]:
tensor([3, 4, 4,  ..., 3, 3, 3])
[18]:
# Distinct label values present in data.y.
data.y.unique()
[18]:
tensor([0, 1, 2, 3, 4, 5, 6])
[19]:
@torch.no_grad()
def test():
    """Score the current embeddings on the dataset's train/test split.

    Uses the notebook-level ``model`` and ``data``. The embeddings and labels
    selected by ``data.train_mask`` and ``data.test_mask`` are passed to
    ``model.test`` together with ``max_iter=150``.

    Returns:
        The value returned by ``model.test`` (printed as ``Acc`` by the
        training loop).
    """
    model.eval()
    embeddings = model()
    train_idx, test_idx = data.train_mask, data.test_mask
    return model.test(
        embeddings[train_idx], data.y[train_idx],
        embeddings[test_idx], data.y[test_idx],
        max_iter=150,
    )

Training loop

[20]:
# Train for 200 epochs (1..200). After every epoch the embeddings are scored
# with test() and the epoch's mean loss and that score are printed.
for epoch in range(1, 201):
    loss = train()
    acc = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Acc: {acc:.4f}')
Epoch: 01, Loss: 8.0718, Acc: 0.1460
Epoch: 02, Loss: 6.0427, Acc: 0.1700
Epoch: 03, Loss: 4.9415, Acc: 0.2070
Epoch: 04, Loss: 4.1228, Acc: 0.2420
Epoch: 05, Loss: 3.4709, Acc: 0.2740
Epoch: 06, Loss: 2.9523, Acc: 0.3100
Epoch: 07, Loss: 2.5429, Acc: 0.3370
Epoch: 08, Loss: 2.2041, Acc: 0.3760
Epoch: 09, Loss: 1.9471, Acc: 0.4130
Epoch: 10, Loss: 1.7284, Acc: 0.4430
Epoch: 11, Loss: 1.5604, Acc: 0.4610
Epoch: 12, Loss: 1.4219, Acc: 0.4850
Epoch: 13, Loss: 1.3128, Acc: 0.5050
Epoch: 14, Loss: 1.2242, Acc: 0.5230
Epoch: 15, Loss: 1.1563, Acc: 0.5340
Epoch: 16, Loss: 1.1002, Acc: 0.5490
Epoch: 17, Loss: 1.0553, Acc: 0.5660
Epoch: 18, Loss: 1.0215, Acc: 0.5840
Epoch: 19, Loss: 0.9915, Acc: 0.5950
Epoch: 20, Loss: 0.9686, Acc: 0.6130
Epoch: 21, Loss: 0.9492, Acc: 0.6250
Epoch: 22, Loss: 0.9332, Acc: 0.6340
Epoch: 23, Loss: 0.9207, Acc: 0.6440
Epoch: 24, Loss: 0.9103, Acc: 0.6570
Epoch: 25, Loss: 0.9001, Acc: 0.6620
Epoch: 26, Loss: 0.8923, Acc: 0.6650
Epoch: 27, Loss: 0.8840, Acc: 0.6700
Epoch: 28, Loss: 0.8794, Acc: 0.6700
Epoch: 29, Loss: 0.8735, Acc: 0.6750
Epoch: 30, Loss: 0.8683, Acc: 0.6800
Epoch: 31, Loss: 0.8653, Acc: 0.6750
Epoch: 32, Loss: 0.8623, Acc: 0.6790
Epoch: 33, Loss: 0.8581, Acc: 0.6840
Epoch: 34, Loss: 0.8556, Acc: 0.6880
Epoch: 35, Loss: 0.8526, Acc: 0.6940
Epoch: 36, Loss: 0.8505, Acc: 0.6870
Epoch: 37, Loss: 0.8496, Acc: 0.6970
Epoch: 38, Loss: 0.8468, Acc: 0.6980
Epoch: 39, Loss: 0.8455, Acc: 0.6970
Epoch: 40, Loss: 0.8439, Acc: 0.6920
Epoch: 41, Loss: 0.8434, Acc: 0.6950
Epoch: 42, Loss: 0.8420, Acc: 0.6960
Epoch: 43, Loss: 0.8397, Acc: 0.6970
Epoch: 44, Loss: 0.8397, Acc: 0.7010
Epoch: 45, Loss: 0.8377, Acc: 0.7020
Epoch: 46, Loss: 0.8364, Acc: 0.6950
Epoch: 47, Loss: 0.8354, Acc: 0.6990
Epoch: 48, Loss: 0.8353, Acc: 0.6930
Epoch: 49, Loss: 0.8341, Acc: 0.7030
Epoch: 50, Loss: 0.8340, Acc: 0.7030
Epoch: 51, Loss: 0.8334, Acc: 0.7010
Epoch: 52, Loss: 0.8327, Acc: 0.7050
Epoch: 53, Loss: 0.8334, Acc: 0.7040
Epoch: 54, Loss: 0.8317, Acc: 0.7030
Epoch: 55, Loss: 0.8319, Acc: 0.7110
Epoch: 56, Loss: 0.8311, Acc: 0.7060
Epoch: 57, Loss: 0.8303, Acc: 0.7040
Epoch: 58, Loss: 0.8290, Acc: 0.7010
Epoch: 59, Loss: 0.8291, Acc: 0.7020
Epoch: 60, Loss: 0.8287, Acc: 0.7000
Epoch: 61, Loss: 0.8278, Acc: 0.7040
Epoch: 62, Loss: 0.8281, Acc: 0.6990
Epoch: 63, Loss: 0.8285, Acc: 0.6970
Epoch: 64, Loss: 0.8285, Acc: 0.7000
Epoch: 65, Loss: 0.8275, Acc: 0.7050
Epoch: 66, Loss: 0.8281, Acc: 0.7030
Epoch: 67, Loss: 0.8263, Acc: 0.7040
Epoch: 68, Loss: 0.8275, Acc: 0.7030
Epoch: 69, Loss: 0.8271, Acc: 0.7040
Epoch: 70, Loss: 0.8278, Acc: 0.7020
Epoch: 71, Loss: 0.8270, Acc: 0.7040
Epoch: 72, Loss: 0.8257, Acc: 0.7070
Epoch: 73, Loss: 0.8269, Acc: 0.7090
Epoch: 74, Loss: 0.8239, Acc: 0.7160
Epoch: 75, Loss: 0.8255, Acc: 0.7040
Epoch: 76, Loss: 0.8261, Acc: 0.7040
Epoch: 77, Loss: 0.8247, Acc: 0.7020
Epoch: 78, Loss: 0.8260, Acc: 0.7040
Epoch: 79, Loss: 0.8252, Acc: 0.7010
Epoch: 80, Loss: 0.8261, Acc: 0.7100
Epoch: 81, Loss: 0.8256, Acc: 0.7110
Epoch: 82, Loss: 0.8258, Acc: 0.7090
Epoch: 83, Loss: 0.8237, Acc: 0.7130
Epoch: 84, Loss: 0.8250, Acc: 0.7100
Epoch: 85, Loss: 0.8250, Acc: 0.7120
Epoch: 86, Loss: 0.8249, Acc: 0.7120
Epoch: 87, Loss: 0.8248, Acc: 0.7130
Epoch: 88, Loss: 0.8244, Acc: 0.7070
Epoch: 89, Loss: 0.8241, Acc: 0.7100
Epoch: 90, Loss: 0.8243, Acc: 0.7030
Epoch: 91, Loss: 0.8248, Acc: 0.7110
Epoch: 92, Loss: 0.8244, Acc: 0.7110
Epoch: 93, Loss: 0.8254, Acc: 0.7040
Epoch: 94, Loss: 0.8256, Acc: 0.6940
Epoch: 95, Loss: 0.8248, Acc: 0.6940
Epoch: 96, Loss: 0.8237, Acc: 0.6890
Epoch: 97, Loss: 0.8251, Acc: 0.6930
Epoch: 98, Loss: 0.8250, Acc: 0.6920
Epoch: 99, Loss: 0.8246, Acc: 0.7040
Epoch: 100, Loss: 0.8245, Acc: 0.6910
Epoch: 101, Loss: 0.8237, Acc: 0.6990
Epoch: 102, Loss: 0.8237, Acc: 0.6980
Epoch: 103, Loss: 0.8226, Acc: 0.6980
Epoch: 104, Loss: 0.8243, Acc: 0.6820
Epoch: 105, Loss: 0.8245, Acc: 0.6870
Epoch: 106, Loss: 0.8245, Acc: 0.6910
Epoch: 107, Loss: 0.8254, Acc: 0.6940
Epoch: 108, Loss: 0.8256, Acc: 0.7000
Epoch: 109, Loss: 0.8237, Acc: 0.6980
Epoch: 110, Loss: 0.8236, Acc: 0.6880
Epoch: 111, Loss: 0.8244, Acc: 0.7010
Epoch: 112, Loss: 0.8238, Acc: 0.7070
Epoch: 113, Loss: 0.8236, Acc: 0.7080
Epoch: 114, Loss: 0.8248, Acc: 0.7110
Epoch: 115, Loss: 0.8234, Acc: 0.7090
Epoch: 116, Loss: 0.8246, Acc: 0.7100
Epoch: 117, Loss: 0.8244, Acc: 0.6980
Epoch: 118, Loss: 0.8249, Acc: 0.7080
Epoch: 119, Loss: 0.8254, Acc: 0.7140
Epoch: 120, Loss: 0.8236, Acc: 0.7170
Epoch: 121, Loss: 0.8235, Acc: 0.7200
Epoch: 122, Loss: 0.8243, Acc: 0.7160
Epoch: 123, Loss: 0.8251, Acc: 0.7110
Epoch: 124, Loss: 0.8236, Acc: 0.7160
Epoch: 125, Loss: 0.8244, Acc: 0.7000
Epoch: 126, Loss: 0.8249, Acc: 0.7100
Epoch: 127, Loss: 0.8237, Acc: 0.7120
Epoch: 128, Loss: 0.8256, Acc: 0.7010
Epoch: 129, Loss: 0.8236, Acc: 0.7070
Epoch: 130, Loss: 0.8241, Acc: 0.7050
Epoch: 131, Loss: 0.8238, Acc: 0.7020
Epoch: 132, Loss: 0.8248, Acc: 0.7120
Epoch: 133, Loss: 0.8241, Acc: 0.7120
Epoch: 134, Loss: 0.8239, Acc: 0.7210
Epoch: 135, Loss: 0.8235, Acc: 0.7230
Epoch: 136, Loss: 0.8252, Acc: 0.7050
Epoch: 137, Loss: 0.8252, Acc: 0.7090
Epoch: 138, Loss: 0.8236, Acc: 0.7060
Epoch: 139, Loss: 0.8236, Acc: 0.7060
Epoch: 140, Loss: 0.8235, Acc: 0.7060
Epoch: 141, Loss: 0.8245, Acc: 0.7000
Epoch: 142, Loss: 0.8249, Acc: 0.6900
Epoch: 143, Loss: 0.8239, Acc: 0.6980
Epoch: 144, Loss: 0.8246, Acc: 0.7000
Epoch: 145, Loss: 0.8248, Acc: 0.7010
Epoch: 146, Loss: 0.8235, Acc: 0.7010
Epoch: 147, Loss: 0.8238, Acc: 0.6990
Epoch: 148, Loss: 0.8234, Acc: 0.6920
Epoch: 149, Loss: 0.8241, Acc: 0.7020
Epoch: 150, Loss: 0.8253, Acc: 0.6980
Epoch: 151, Loss: 0.8250, Acc: 0.7030
Epoch: 152, Loss: 0.8249, Acc: 0.7040
Epoch: 153, Loss: 0.8262, Acc: 0.7010
Epoch: 154, Loss: 0.8248, Acc: 0.7090
Epoch: 155, Loss: 0.8241, Acc: 0.7070
Epoch: 156, Loss: 0.8244, Acc: 0.7100
Epoch: 157, Loss: 0.8255, Acc: 0.7050
Epoch: 158, Loss: 0.8251, Acc: 0.7090
Epoch: 159, Loss: 0.8245, Acc: 0.7060
Epoch: 160, Loss: 0.8243, Acc: 0.7050
Epoch: 161, Loss: 0.8247, Acc: 0.7130
Epoch: 162, Loss: 0.8247, Acc: 0.7090
Epoch: 163, Loss: 0.8243, Acc: 0.7140
Epoch: 164, Loss: 0.8242, Acc: 0.7190
Epoch: 165, Loss: 0.8252, Acc: 0.7170
Epoch: 166, Loss: 0.8250, Acc: 0.7110
Epoch: 167, Loss: 0.8252, Acc: 0.7120
Epoch: 168, Loss: 0.8254, Acc: 0.7130
Epoch: 169, Loss: 0.8250, Acc: 0.7170
Epoch: 170, Loss: 0.8247, Acc: 0.7190
Epoch: 171, Loss: 0.8236, Acc: 0.7130
Epoch: 172, Loss: 0.8246, Acc: 0.7070
Epoch: 173, Loss: 0.8243, Acc: 0.6980
Epoch: 174, Loss: 0.8249, Acc: 0.6980
Epoch: 175, Loss: 0.8253, Acc: 0.7070
Epoch: 176, Loss: 0.8250, Acc: 0.7160
Epoch: 177, Loss: 0.8245, Acc: 0.7080
Epoch: 178, Loss: 0.8251, Acc: 0.7030
Epoch: 179, Loss: 0.8255, Acc: 0.7060
Epoch: 180, Loss: 0.8247, Acc: 0.7260
Epoch: 181, Loss: 0.8257, Acc: 0.7220
Epoch: 182, Loss: 0.8251, Acc: 0.7120
Epoch: 183, Loss: 0.8243, Acc: 0.7050
Epoch: 184, Loss: 0.8242, Acc: 0.7100
Epoch: 185, Loss: 0.8255, Acc: 0.7160
Epoch: 186, Loss: 0.8256, Acc: 0.7130
Epoch: 187, Loss: 0.8249, Acc: 0.7040
Epoch: 188, Loss: 0.8255, Acc: 0.7030
Epoch: 189, Loss: 0.8249, Acc: 0.7150
Epoch: 190, Loss: 0.8259, Acc: 0.7050
Epoch: 191, Loss: 0.8257, Acc: 0.7060
Epoch: 192, Loss: 0.8254, Acc: 0.7100
Epoch: 193, Loss: 0.8246, Acc: 0.7080
Epoch: 194, Loss: 0.8255, Acc: 0.6980
Epoch: 195, Loss: 0.8241, Acc: 0.7050
Epoch: 196, Loss: 0.8250, Acc: 0.7150
Epoch: 197, Loss: 0.8246, Acc: 0.7080
Epoch: 198, Loss: 0.8234, Acc: 0.7140
Epoch: 199, Loss: 0.8251, Acc: 0.7210
Epoch: 200, Loss: 0.8255, Acc: 0.7220

Visualization

[21]:
@torch.no_grad()
def plot_points(colors):
    """Project the node embeddings to 2-D with t-SNE and scatter-plot them.

    Uses the notebook-level ``model``, ``data``, ``dataset`` and ``device``.
    One scatter series is drawn per class index in
    ``range(dataset.num_classes)``.

    Args:
        colors: Sequence of matplotlib colors indexed by class id; it needs
            at least ``dataset.num_classes`` entries.
    """
    model.eval()
    embeddings = model(torch.arange(data.num_nodes, device=device))
    # init and learning_rate are set explicitly because their defaults are
    # scheduled to change (the FutureWarnings this cell used to emit), so the
    # result no longer depends on the installed scikit-learn version
    # ('auto' needs scikit-learn >= 1.0). random_state makes the projection
    # repeatable across runs.
    tsne = TSNE(n_components=2, init='pca', learning_rate='auto',
                random_state=0)
    points_2d = tsne.fit_transform(embeddings.cpu().numpy())
    y = data.y.cpu().numpy()

    plt.figure(figsize=(8, 8))
    for i in range(dataset.num_classes):
        in_class = y == i  # boolean mask of the nodes labelled i
        plt.scatter(points_2d[in_class, 0], points_2d[in_class, 1],
                    s=20, color=colors[i])
    plt.axis('off')
    plt.show()

# One hex color per class: seven entries for the seven label values 0..6.
colors = [
    '#ffc0cb',
    '#bada55',
    '#008080',
    '#420420',
    '#7fe5f0',
    '#065535',
    '#ffd700',
]
plot_points(colors)
/home/user/anaconda3/envs/gnn/lib/python3.7/site-packages/sklearn/manifold/_t_sne.py:783: FutureWarning: The default initialization in TSNE will change from 'random' to 'pca' in 1.2.
  FutureWarning,
/home/user/anaconda3/envs/gnn/lib/python3.7/site-packages/sklearn/manifold/_t_sne.py:793: FutureWarning: The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2.
  FutureWarning,
../../../_images/ipynbs_colabs_pyg_tutorial_project_Tutorial11_31_1.png