您的位置 首页 PyTorch 教程

图神经网络(GNN)教程 – 用 PyTorch 和 PyTorch Geometric 实现 Graph Neural Networks

PyTorch入门实战教程

图神经网络(Graph Neural Networks)最近是越来越火,很多问题都可以用图神经网络找到新的解决方法。 今天我们就来看怎么用 PyTorch 和 PyTorch Geometric (PyG) 实现图神经网络。PyG 是一款号称比 DGL 快 14 倍的基于 PyTorch 的几何深度学习框架,可以简单方便的实现图神经网络。

另外,在 PyG 的 Github 页面上,有一个现在已经在 PyG 里实现的网络列表,大家可以参考一下。

系统需求

要求至少安装 PyTorch 1.2.0 版本。

PyTorch Geometric 基础知识

这一部分我们介绍一下 PyG 的基础知识,主要包括 torch_geometric.data 和 torch_geometric.nn 部分。另外,还会介绍怎么设计自己的 Message Passing Layer。

Data 类

torch_geometric.data 包里有一个 Data 类,通过 Data 类我们可以很方便的创建图结构。

定义一个图结构,需要以下变量:

  1. 每个节点(node)的 features
  2. 边的连接关系或者边的 features

我们以下面的图结构为例,看看怎么用 Data 类创建图结构:

在上图中,一共有四个节点 \(v_1, v_2, v_3, v_4\),其中每个节点都有一个二维的特征向量和一个标签 \(y\)。这个特征向量和标签可以用 FloatTensor 来表示:

图的连接关系(边)可以用 COO 格式表示。COO 格式的维度是 [2, num_edges],其中第一个列表是所有边上起始节点的 index,第二个列表是对应边上目标节点的 index:

注意上面的数据里定义边的顺序是无关紧要的,这个数据仅仅用来计算邻接矩阵用的,比如上面的定义和下面的定义是等价的:

综上所述,我们可以这样定义上面的图结构:

Dataset

PyG 里有两种数据集类型:InMemoryDataset 和 Dataset,第一种适用于可以全部放进内存中的小数据集,第二种则适用于不能一次性放进内存中的大数据集。我们以 InMemoryDataset 为例。

InMemoryDataset 中有下列四个函数需要我们实现:

raw_file_names()

返回一个包含所有未处理过的数据文件的文件名的列表。

起始也可以返回一个空列表,然后在后面要说的 process() 函数里再定义。

processed_file_names()

返回一个包含所有处理过的数据文件的文件名的列表。

download()

如果在数据加载前需要先下载,则在这里定义下载过程,下载到 self.raw_dir 中定义的文件夹位置。

如果不需要下载,返回 pass 即可。

process()

这是最重要的一个函数,我们需要在这个函数里把数据处理成一个 Data 对象。下面是官方的一个示例代码:

本文接下来会介绍如何用 RecSys Challenge 2015 的数据创建一个自定义数据集。

DataLoader

这个类可以帮助我们将数据按 batch 传给 model,定义的方法如下,需要制定 batch_sizedataset

每个 loader 的循环都返回一个 Batch 对象:

Batch 相比 Data 对象多了一个 batch 参数,告诉我们这个 batch 里都包含哪些 nodes,便于计算。

MessagePassing

Message Passing 是图网络中学习 node embedding 的重要方法。点击这里查看官方文档对这个的详细说明,我们接下来也将基于官方的说明来讲解。

Message Passing 的公示如下:

其中,\(x\) 表示表格节点的 embedding,\(e\) 表示边的特征,\(\phi\) 表示 message 函数,\(□\) 表示聚合 aggregation 函数,\(\gamma\) 表示 update 函数。上标表示层的 index,比如说,当 k = 1 时,\(x\) 则表示所有输入网络的图结构的数据。

下面是每个函数的介绍:

propagate(edge_index, size=None, **kwargs)

这个函数最终会调用 messageupdate 函数。

message(**kwargs)

这个函数定义了对于每个节点对 \((x_i, x_j)\),怎样生成信息(message)。

update(aggr_out, **kwargs)

这个函数利用聚合好的信息(message)更新每个节点的 embedding。

示例:SageConv

我们来看看怎样实现论文 “Inductive Representation Learning on Large Graphs” 中的 SageConv 层。SageConv 的 Message Passing 定义如下:

聚合函数(aggregation)我们用最大池化(max pooling),这样上述公示中的 AGGREGATE 可以写为:

上述公式中,对于每个邻居节点,都和一个 weighted matrix 相乘,并且加上一个 bias,传给一个激活函数。相关代码如下:

对于 update 方法,我们需要聚合更新每个节点的 embedding,然后加上权重矩阵和偏置:

综上所述,SageConv 层的定于方法如下:

示例:RecSys Challenge 2015

RecSys Challenge 2015 是一个挑战赛,主要目的是创建一个 session-based recommender system。主要任务有两个:

  1. 预测经过一系列的点击后,是否会产生购买行为。
  2. 预测购买的商品。

数据下载地址在这里。数据主要包含两部分:yoochoose-clicks.dat 点击数据, 和 yoochoose-buys.dat 购买行为数据。

点击数据的示例如下:

购买行为数据示例如下:

数据预处理 Preprocessing

下载好数据后,我们先进行一些预处理:

处理后的数据示例如下:

因为数据太多,我们随机进行取样以方便讲解:

取样的数据统计如下:

另外,为获取标签,即对于某个特定的 session,是否产生了购买行为,我们只需要检查文件 yoochoose-clicks.dat 中的 session_id 是否在文件 yoochoose-buys.dat 中出现即可:

结果如下:

创建 Dataset

这里我们将预处理过的数据创建成为 Dataset 对象。对于每个 session,里面的每个商品(item)看作一个节点,因此每个 session 里所有的商品组成一个图。

首先,我们将数据集按照 session_id 进行分组,分组过程中 item_id 也要被重新编码,因为对于每个图,每个节点的 index 应该从 0 开始:

然后我们对数据集进行随机排序,分成 training, validation 和 testing 三个子数据集:

创建图网络(Graph Neural Network)

下列代码过程参考了官方的一个示例并做了适当的修改:

训练

训练过程中,我们使用 Adam 优化器,学习率 0.005,损失函数是 BCE:

Validation

这个数据集非常的不平衡,因为大多数的 session 里没有购买行为。也就是说,如果一个模型将所有的结果都预测为 false,也能达到 90% 的准确率。因此,这里我们不使用 accuracy 作为评测标准,而是使用 Area Under Curve (AUC):

训练结果

下面是模型训练了 1 epoch 的结果:

可以看到,我们用非常少的数据,在只训练了 1 个 epoch 的情况下,测试集的 AUC 也能达到 0.73。如果用更多的数据训练,应该可以达到更好的结果。

本站微信群、QQ群(三群号 726282629):

PyTorch入门实战教程

发表评论

电子邮件地址不会被公开。 必填项已用*标注

返回顶部