PyTorchのPERFORMANCE TUNING GUIDEの効果を確認してみる その2 「Fuse pointwise operations」
PyTorchには「PERFORMANCE TUNING GUIDE」という学習を速くするためのテクニック集があります。このドキュメントでは個々のテクニックでどれくらい速くなるか具体的な数値が示されていないので、それを確認するということをここ最近やっています。この記事はそのシリーズの第二弾として、「Fuse pointwise operations」を試してみたまとめです。
ちなみに、測定するときにいろいろ気を付けないといけないポイントがあったので、Fuse pointwise operationsのために利用したtorch.jit.script
の謎現象で困る人が減るように、それについても後半で説明していきます。
第一弾の「parameter.grad = Noneを使う」というのもありますので、PyTorchの高速化に興味がある方はそちらも合わせてご覧ください。
目次
Fuse pointwise operationsとは?
elementwiseの加算や乗算、sin()
, cos()
, sigmoid()
などなど、行列やベクトルの要素単位で実行される演算をまとめてpointwise operationsと呼ぶときがあります。これらの演算は一つの演算にかかる時間は非常に短いため、GPUのような関数1回の実行のオーバーヘッドやメモリアクセスのオーバーヘッドが大きい演算器では計算量のわりに長い計算時間がかかってしまいます。
このようなメモリアクセスや関数の実行のオーバーヘッドを削減する工夫として、複数の独立した演算を一つの関数にまとめる(fuse)という方法が良く用いられます。
PyTorchでも演算をまとめる仕組みがあります。その中でもpointwise operationsをfuseする仕組みとしてよく例で用いられるのが torch.jit.script
です。
今回はこのtorch.jit.script
によって、どれくらいfuseしたpointwise operationsが速くなるのかを確認していきます。
実際に効果を測定してみる
torch.jit.script
でfuseするとどれくらい速くなるのか?を測定するための環境と実際に用いたコードは以下の通りです。
- 実行環境:Google Colab
- GPU: Tesla T4
- PyTorch: 1.8.1
- torchvision: 0.9.1
- 測定に使ったnotebook: https://github.com/shu65/pyorch_performance_tuning_guide_examples/blob/main/Fuse_pointwise_operations.ipynb
また、今回測定に利用した関数は「PERFORMANCE TUNING GUIDE」で示されていたGELUです。実装自体は単純で、以下の通りです。
def gelu(x):
return x * 0.5 * (1.0 + torch.erf(x / 1.41421))
また、今回はGPUのだけでなく念のためCPUも測定しました。
測定した結果は以下の通りです。
平均実行時間 (sec.) | デフォルトとの速度比 | |
CPU デフォルト | 0.106 | 1.00 |
CPU torch.jit.scriptあり | 0.105 | 1.00 |
平均実行時間 (sec.) | デフォルトとの速度比 | |
GPU デフォルト | 0.00356 | 1.00 |
GPU torch.jit.scriptあり | 0.000789 | 4.51 |
CPUのほうはあまり期待してなかったですが、予想通りほぼ変わらずという結果でした。一方、GPUのほうは劇的に速度が変化し、今回のGULEの例では4.5倍速くなることが確認できました。個人的にはtorch.jit.script
で速くなることはあまりないようなイメージだったので、シンプルなFuse pointwise operationsならちゃんと速くなるというのがわかって少し感動してます。
torch.jit.script を使った実行時間測定の注意点
さて、この記事を書くにあたってかなり苦労したので、その苦労話もちゃんと書いておこうと思います。この分量の内容の記事なら数時間で実験して書けるだろうと当初は思っていたのですが、torch.jit.script
の謎現象に悩まされて実験がちゃんと安定して取れるようになるまで、実は数日かかりました。なので、torch.jit.script
を使った計算時間測定の注意点をまとめておきます。
1. GPUの計算時間を正しく測定する
以前自分で「PyTorchでGPUの計算時間を正しく計測する」という記事を書きましたが、恥ずかしながら最初は正しく測定するのを忘れていました。なので、自戒も込めて何度も書きますが、GPUの計算時間を測定するときは注意してください。
2. torch.jit.scriptの1回目の実行はオーバーヘッドが大きいので無視する
torch.jit.script
は名前の通りJITなので、1回目の実行時はオーバーヘッドが大きいです。このため、1回も含めていて、かつ、少ない実行回数で平均を取るとtorch.jit.scriptを使っているのに速くなっていないというような状態になります。このため、ちゃんと測定する場合は1回目の実行は別にするようにするとよいかと思います。
3. 入力のTensorのshapeやdeviceが違う場合はtorch.jit.script()の実行前にキャッシュをクリアする
今回の測定で気が付くのに苦労した点がこれです。PyTorch 1.8.1現在、torch.jit.script()
は一度関数オブジェクトをtorch.jit.script化したあと、2回目以降はこの部分をスキップするためにキャッシュしています。このため、全く別のshapeやdeviceのTensorを入力に使う場合はtorch.jit.script()
を実行する前にキャッシュをクリアしておかないと、本来はJITを使って速くなるはずなのにキャッシュに残ったものがそのまま使われて全然速くならないという現象が発生します。
今回の測定に用いたnotebookではCPUを測定したあとGPUの測定をしています。このため、何もしていないとCPUでJITが走っているので、その後、いくら入力をGPUにしていてもGPU用のJITが走らず、torch.jit.script
をGPUで使っているのに全然速くならないという状態になります。
これを回避するために以下のようにキャッシュのクリアしてからtorch.jit.script化して測定を行うようにしています。
torch.jit._state._jit_function_overload_caching.clear()
torch.jit._state._jit_caching_layer.clear()
scripted_gelu = torch.jit.script(gelu)
ちなみにちゃんと最適化が走っているか確認する際はtorch.jit.last_executed_optimized_graph()
で直前の関数の実行時のグラフが出力できるので、JITが走っているはずの1回目の実行で「prim::profile」というものが出てきているか確認してください。現状のPyTorchのデフォルトだと最初の1回目はプロファイル測定のためにこのようなIRが挿入されるようになっています。
最後に
PyTorchには「PERFORMANCE TUNING GUIDE」の「Fuse pointwise operations」を試したときのまとめを書きました。個人的にはtorch.jit.script()
を使う際の注意点がいろいろわかってかなり勉強になりました。他にもまだまだ試したい高速化テクニックがあるので、試した際はまたこうしたまとめ記事を書こうと思います。