PyTorch Geometricを使ってVariational Graph Auto-Encodersを作って学習してみる
目次
はじめに
最近読んだ論文にVariational Graph Auto-Encoders (VGAE) を使ったモデルがあったので、自分でもやってみようと思い、作ってみました。本日はそのまとめになります。
本日紹介する使うコードは以下のものです。
また、このコード自体、以下のPyTorch Geometricのexampleのコードとほぼ同じです。
このblog記事ではVGAEで必要な機能がPyTorch Geometricでどう実装されているのかわからなかった部分がいくつかあるのでその部分を解説していく記事になります。
PyTorch Geometricとは
PyTorch GeometricはPyTorchを使って構築されたGraph Neural Network向けのライブラリになります。
GitHubのURLは以下の通りです。
https://github.com/pyg-team/pytorch_geometric
最新のPyTorchやCUDAにもちゃんと対応しており、Graph Neural Networkで必要な基本的な機能はそろっている印象です。
Variational Graph Auto-Encoders (VGAE)とは
VGAEはVariational Auto-Encoder (VAE) というモデルをGraphデータ向けに拡張したモデルです。VAEの説明を始めるとそれだけですごく長くなりますので、今回はVGAEを実装するうえで必要なところだけ紹介します。
VAEは以下のようにEncoderとDecoderという二つのモデルを組み合わせたモデルになります。
このうち、EncoderとDecoderは以下のようなモデルになります。
- Encoder: 入力Xを受け取って潜在変数Zの分布のパラメータを出力する
- Decoder: 潜在変数Zを受け取って入力Xを再構成する
VAEで重要なのがEncoderの部分と潜在変数Zのサンプリングの部分です。この潜在変数Zの分布が標準正規分布という仮定のもと学習させながら、Encoderで潜在変数Zの分布のパラメータを出力し、その分布のパラメータを使って潜在変数ZをサンプリングしてDecoderに渡すということを行います。
このVAEをGraph データに拡張するためにVGAEはEncoderとDecoderを以下のようなモデルにしています。
- Encoder: ノードの特徴ベクトルXと隣接行列Aを入力として受け取り、潜在変数Zの分布のパラメータを出力する
- Decoder: 潜在変数Zを受け取り隣接行列Aを再構築する
図にすると以下のようなイメージです。
VGAEとVAEとの違いはEncoderでグラフの情報であるノード情報と隣接行列を受け取れるようにしたことと、Decoderが出力するものが隣接行列になることです。
VGAEをPyTorch Geometricを使って実装する
VGAEの概略を説明したので次は実際に実装を紹介していきます。まずはEncoderであるVariationalGCNEncoder
から見ていきます。EncoderではPyTorch Geometricに実装されている GCNConv
を使って実装します。
from torch_geometric.nn import GCNConv
class VariationalGCNEncoder(torch.nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, 2 * out_channels)
self.conv_mu = GCNConv(2 * out_channels, out_channels)
self.conv_logstd = GCNConv(2 * out_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)
GCNConv
はノードのインプットのチャンネル数、アウトプットのチャネル数を引数にとってインスタンスを作ります。そしてforwardではノードのtensor x
と隣接行列のかわりにどのノード同士がつながっているか?を示すedge_index
を渡します。GCNConv
の中身についてはドキュメントに詳しく書かれているのでそちらをご覧ください。
https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv
このEncoderではVGAEの概要でも説明した通り、潜在変数の分布のパラメータを返します。ここではガウス分布の平均を表すmuと標準偏差にlogを適用したlogstdを返しています。
モデルの実装としてはあとはPyTorch Geometricで実装されているVGAE
というクラスに渡せば終わりになります。
from torch_geometric.nn import VGAE
model = VGAE(VariationalGCNEncoder(in_channels, out_channels))
ただ、これだとさすがに初見だと何が何だかわからなかったので、少し説明します。
まず、Decoderについてです。DecoderはVGAE
のデフォルトではInnerProductDecoder
というものが使われます。これはVGAEの元論文でも使われていたDecoderの実装で、エッジの両端のノードに対応する潜在変数の各要素の積を取って総和を取り、sigmoidを適用して0-1の値にして出力します。出力値が0-1の値になっているのでDecoderの出力値は計算に使った二つのノードの間にエッジがある確率とみることができます。
詳しくは以下のドキュメントをご覧ください。
また、ロス関数についてですが、VGAE
の中にVGAEで必要な以下の二つが実装されています。
recon_loss
: 潜在変数zとノード同士のつながりを示すpos_edge_indexを入力にとり、Decoderを利用して各エッジのある確率を計算、その確率に対してbinary cross entropyを計算してlossとして返す関数
https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/autoencoder.html#GAE.recon_losskl_loss
: Encoderの出力したmuとlogstdを使って標準正規分布とのKLダイバージェンスを計算しlossとして返す関数
https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.VGAE.kl_loss
これを以下のように学習ループで利用して学習をおこないます。
for epoch in range(0, 400):
model.train()
optimizer.zero_grad()
z = model.encode(train_data.x, train_data.edge_index)
recon_loss = model.recon_loss(z, train_data.pos_edge_label_index)
kl_loss = (1 / train_data.num_nodes) * model.kl_loss()
loss = recon_loss + kl_loss
loss.backward()
optimizer.step()
最後に上のコードではノード間にエッジがあるところの情報はtrain_data.pos_edge_label_index
で渡しているのですが、ノード間にエッジがないという情報はどこで渡しているか?ということについて説明します。
コードを読むと実はrecon_lossの中で自動的にエッジがないという情報を生成してそれを込みでロスが計算されています。具体的には以下の部分です。
ここで引数でneg_edge_index
がNone
のときは自動でエッジが存在しないノードのペアをサンプリングするという処理になっています。
以下です。その他の部分で気になるところがある場合は全体のコードを以下のところに置いてありますのでご覧ください。
終わりに
今回はPyTorch Geometricの練習として、VGAEを実装してみたのでまとめの記事を書きました。PyTorch Geometricを今回初めて使ったのですが、Graph Neural Networkに必要な基本的な機能はそろっていそうなので、今後もGraph Neural Networkを使う機会があれば使ってみようと思います。