JAXとPyTorch、どっちが速いのか検証してみた
高速化が趣味&仕事なので、最近よく目にするJAXの速度が気になってました。このため、今回は日ごろ使っているPyTorchと比較したので、その結果のまとめを紹介します。
目次
結論
結果だけ知りたい方が多いだろうと思ったので先に結論から書くと、私のPyTorch力では力及ばず、今回の検証では
JAXのほうがPyTorchの2.2倍速い
という結果でした。ここから詳しく評価について説明します。
評価方法
今回、JAXとPyTorchを比較するにあたり、この前紹介したSmooth Smith Watermanのコードを利用しました。Smooth Smith Watermanについて知りたいという方は以下の記事をご覧ください。
この記事で紹介したJAXコードは論文の著者が頑張って高速化した結果なため、十分最適化された結果であるという認識です。このため、今回はPyTorchのコードを私が作成し、測定を行いました。
今回の検証コードはここに置いてあります。
https://github.com/shu65/blog-jax-notebook/blob/main/Smooth_Smith_Waterman_PyTorch_vs_JAX.ipynb
今回は3パターン実装したので、それぞれについて順番に紹介します。
実行はGoogle Colab上で行いました。この際、使用したGPUやライブラリのバージョンは以下の通りです。
- GPU: K80
- CUDA: 11.2
- PyTorch: 1.10.0
- JAX: 0.2.21
また、Smooth Smith Watermanは2つの配列の最大長を100, 120とした64個の配列のペアを入力に与えて測定しました。今回は少し測定誤差が入ることも考慮して10回平均で比較します。
JAXのコードをそのままPyTorchにする
JAXのコードで利用されているアルゴリズムはPyTorchでも十分速くなるようにみえました。このため、まずはそのまま適用してみました。PyTorchのコードとしては以下の通りです。
class SwTorch(nn.Module):
def __init__(self, unroll=2, NINF=-1e30, device="cpu"):
super(SwTorch, self).__init__()
self.unroll = unroll
self.NINF = torch.tensor(NINF, device=device)
self.device = device
def _make_mask(self, score_matrix, lengths):
a,b = score_matrix.shape
real_a = lengths[0]
real_b = lengths[1]
mask = (torch.arange(a, device=self.device) < real_a)[:,None] & (torch.arange(b, device=self.device) < real_b)[None,:]
return mask
def _rotate(self, score_matrix):
a,b = score_matrix.shape
n,m = (a+b-1),(a+b)//2
ar = torch.flip(torch.arange(a, device=self.device), [0])[:, None]
br = torch.arange(b, device=self.device)[None,:]
i,j = (br-ar)+(a-1),(ar+br)//2
rotated_score_matrix = torch.full([n,m], self.NINF, dtype=score_matrix.dtype, device=self.device)
rotated_score_matrix[i, j] = score_matrix
reverse_idx = (i, j)
return rotated_score_matrix, reverse_idx
def _step(self, prev, gap_cell_condition, rotated_score_matrix, gap, temp):
h2,h1 = prev # previous two rows of scoring (hij) mtx
h1_T = self._get_prev_gap_cell_score(
gap_cell_condition,
torch.nn.functional.pad(h1[:-1], [1,0], value=self.NINF),
torch.nn.functional.pad(h1[1:], [0,1], value=self.NINF),
)
a = h2 + rotated_score_matrix
g0 = h1 + gap
g1 = h1_T + gap
s = rotated_score_matrix
h0 = torch.stack([a, g0, g1, s], -1)
h0 = self._soft_maximum(h0, temp, -1)
return (h1,h0), h0
def _rotate_in_reverse(self, rotated_dp_matrix, reverse_idx):
return rotated_dp_matrix[reverse_idx]
def _logsumexp(self, y, axis):
y = torch.maximum(y,self.NINF)
return torch.logsumexp(y, axis=axis)
def _logsumexp_with_mask(self, y, axis, mask):
y = torch.maximum(y,self.NINF)
if axis is None:
return torch.max(y) + torch.log(torch.sum(mask * torch.exp(y - torch.max(y))))
else:
return torch.max(y, axis)[0] + torch.log(torch.sum(mask * torch.exp(y - torch.max(y, axis, keepdims=True)[0]), axis=axis))
def _soft_maximum(self, x, temp, axis=None):
return temp*self._logsumexp(x/temp, axis)
def _soft_maximum_with_mask(self, x, temp, mask, axis=None):
return temp*self._logsumexp_with_mask(x/temp, axis, mask)
def _get_prev_gap_cell_score(self, cond, true, false):
return cond*true + (1-cond)*false
def forward(self, score_matrix, lengths, gap=0, temp=1.0):
mask = self._make_mask(score_matrix, lengths)
masked_score_matrix = score_matrix + self.NINF * (~mask)
rotated_score_matrix, reverse_idx = self._rotate(masked_score_matrix)
a,b = score_matrix.shape
n,m = rotated_score_matrix.shape
gap_cell_condition = (torch.arange(n, device=self.device)+a%2)%2
prev = (torch.full((m,), self.NINF, device=self.device), torch.full((m,), self.NINF, device=self.device))
rotated_hij = [None for _ in range(n)]
for i in range(n):
prev, h = self._step(prev, gap_cell_condition[i], rotated_score_matrix[i], gap, temp)
rotated_hij[i] = h
rotated_hij = torch.stack(rotated_hij)
hij = self._rotate_in_reverse(rotated_hij, reverse_idx)
score = self._soft_maximum_with_mask(hij, temp, mask=mask, axis=None)
return score
class BatchSwTorch(nn.Module):
def __init__(self, unroll=2, NINF=-1e30, device="cpu"):
super(BatchSwTorch, self).__init__()
self.device = device
self.sw = SwTorch(unroll=unroll, NINF=NINF, device=device)
def forward(self, batch_score_matrix, batch_lengths, gap=0, temp=1.0):
n_batches = batch_score_matrix.shape[0]
ret = torch.empty((n_batches,), dtype=batch_score_matrix.dtype, device=self.device)
for i in range(n_batches):
ret[i] = self.sw(batch_score_matrix[i], batch_lengths[i], gap=gap, temp=temp)
return ret
ちなみに最初はシンプルなコードと比較しようと思ったので、この時点ではまだtorch.jit
は使っていません。このコードの結果は以下の通りです。
平均実行時間 (sec) | |
numpy | 34.5 |
JAX jit版 | 0.0142 |
JAXのコードをそのままPyTorchにする | 7.89 |
見ての通り、JAXが圧倒的。PyTorchもnumpyに比べて速くなってはいるのでGPUを使っている効果が出ていると考えられますが、それ以上にJAXが速い。PyTorchと比較してJAXのほうが556倍も速いという結果でした。JAXのほうがバグっているのか?とも一瞬思ったのですが、ちゃんと正しい答えを出力しているし、nsysでプロファイル結果を取ってみた限りそれっぽい時間で1回の計算が終わっているので、測定ミスでもなさそうでした。
というわけで、圧倒的にPyTorchがこのままでは遅いので、高速化したバージョンを作成して評価したので、次で紹介します。
Batchの軸を行列の一番内側にもってくる + torch.jitで高速化する
PyTorchを普段から使っていると気にならない部分ではありますが、CUDAの高速化のつもりで考えると、 Batchの軸を行列の一番内側にもってくるほうがCUDA的には速くなりそうな気がします。また、JAXはJITを使っているのでPyTorchもJIT使うほうがいいだろうということでJITを使いました。
これに伴ってコードは以下のように変更しました。
from typing import Tuple
def _make_batch_mask(batch_score_matrix, batch_lengths):
a, b, batch_size = batch_score_matrix.shape
real_a = batch_lengths[:, 0]
real_b = batch_lengths[:, 1]
mask_a = torch.arange(a, device=batch_score_matrix.device)[:, None].repeat(1, batch_size) < real_a[None, :]
mask_b = torch.arange(b, device=batch_score_matrix.device)[:, None].repeat(1, batch_size) < real_b[None, :]
mask = mask_a[:, None] & mask_b[None, :]
return mask
def _logsumexp(y: torch.Tensor, axis: int, NINF: torch.Tensor) -> torch.Tensor:
y = torch.maximum(y, NINF)
return torch.logsumexp(y, dim=axis)
def _logsumexp_with_mask(y: torch.Tensor, axis: int, mask: torch.Tensor, NINF: torch.Tensor) -> torch.Tensor:
y = torch.maximum(y, NINF)
return torch.max(y, axis)[0] + torch.log(torch.sum(mask * torch.exp(y - torch.max(y, dim=axis, keepdim=True)[0]), dim=axis))
def _soft_maximum(x: torch.Tensor, temp: torch.Tensor, axis: int, NINF: torch.Tensor) -> torch.Tensor:
return temp*_logsumexp(x/temp, axis=axis, NINF=NINF)
def _soft_maximum_with_mask(x: torch.Tensor, temp: torch.Tensor, axis: int, mask: torch.Tensor, NINF: torch.Tensor) -> torch.Tensor:
return temp*_logsumexp_with_mask(x/temp, axis=axis, mask=mask, NINF=NINF)
def _rotate(batch_score_matrix: torch.Tensor, NINF: torch.Tensor, rotated_batch_score_matrix: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso\
r, torch.Tensor]:
a, b, batch_size = batch_score_matrix.shape
n,m = (a+b-1),(a+b)//2
ar = torch.flip(torch.arange(a, device=batch_score_matrix.device), [0])[:, None]
br = torch.arange(b, device=batch_score_matrix.device)[None,:]
i,j = (br-ar)+(a-1),(ar+br)//2
rotated_batch_score_matrix[:, :, :] = NINF
rotated_batch_score_matrix[i, j, :] = batch_score_matrix
return rotated_batch_score_matrix, i, j
def _rotate_in_reverse(rotated_dp_matrix, i, j):
return rotated_dp_matrix[i, j]
def _get_prev_gap_cell_score(cond, true, false):
return cond*true + (1-cond)*false
@torch.jit.script
def _step(h2, h1, gap_cell_condition, rotated_score_matrix, gap, temp, NINF, prev_gap_cell_true, prev_gap_cell_false):
prev_gap_cell_true[1:, :] = h1[:-1, :]
prev_gap_cell_false[:-1, :] = h1[1:, :]
h1_T = _get_prev_gap_cell_score(
gap_cell_condition,
prev_gap_cell_true,
prev_gap_cell_false,
)
a = h2 + rotated_score_matrix
g0 = h1 + gap
g1 = h1_T + gap
s = rotated_score_matrix
h0 = torch.stack([a, g0, g1, s], -1)
h0 = _soft_maximum(h0, temp, axis=-1, NINF=NINF)
return h1, h0, h0
@torch.jit.script
def _step_loop(init_h1, init_h0, gap_cell_condition, rotated_batch_score_matrix, gap, temp, NINF, prev_gap_cell_true, prev_gap_cell_false):
n, _, _ = rotated_batch_score_matrix.shape
rotated_hij = torch.empty((n, init_h1.shape[0], init_h1.shape[1]), dtype=init_h1.dtype, device=init_h1.device)
h1 = init_h1
h0 = init_h0
h1[:, :] = NINF
h0[:, :] = NINF
for i in range(n):
h1, h0, h = _step(h1, h0, gap_cell_condition=gap_cell_condition[i], rotated_score_matrix=rotated_batch_score_matrix[i], gap=gap, temp=temp, NINF=NINF, prev_gap_cell_true=prev_gap_cell_true, prev_gap_cell_false=prev_gap_cell_false,)
rotated_hij[i] = h
return rotated_hij
@torch.jit.script
def batch_sw_func(batch_score_matrix, batch_lengths, gap, temp, NINF, rotated_batch_score_matrix, init_h1, init_h0, prev_gap_cell_true, prev_gap_cell_false):
transposed_batch_score_matrix = batch_score_matrix.permute(1, 2, 0)
mask = _make_batch_mask(transposed_batch_score_matrix, batch_lengths)
masked_batch_score_matrix = transposed_batch_score_matrix + NINF * (~mask)
rotated_batch_score_matrix, reverse_idx_i, reverse_idx_j = _rotate(masked_batch_score_matrix, NINF=NINF, rotated_batch_score_matrix=rotated_batch_score_matrix)
a, b, batch_size = transposed_batch_score_matrix.shape
n, m, _ = rotated_batch_score_matrix.shape
gap_cell_condition = (torch.arange(n, device=rotated_batch_score_matrix.device)+a%2)%2
rotated_hij = _step_loop(init_h1, init_h0, gap_cell_condition, rotated_batch_score_matrix, gap, temp, prev_gap_cell_true=prev_gap_cell_true, prev_gap_cell_false=prev_gap_cell_false, NINF=NINF)
hij = _rotate_in_reverse(rotated_hij, reverse_idx_i, reverse_idx_j)
score = _soft_maximum_with_mask(hij.reshape(a*b,batch_size), temp=temp, mask=mask.reshape(a*b, batch_size), axis=0, NINF=NINF)
return score
次に紹介する高速化の関係で一時領域も引数として与えていますが気にしないでください。後ほど説明します。
このコードの測定結果も加えると以下の通りです。
平均実行時間 (sec) | |
numpy | 34.5 |
JAX jit版 | 0.0142 |
JAXのコードをそのままPyTorchにする | 7.89 |
Batchの軸を行列の一番内側にもってくる + torch.jitで高速化する | 0.0655 |
正直、この時点でJAXと並ぶだろうとやる前は思っていたのですが、JAXのほうがまだPyTorchの4.6倍速いという結果でした。JAX速い・・・。でも、できることはまだある!ということでもう一工夫やります。
CUDA Graphsを使う
PyTorchのコードのプロファイル結果を見るとかなり実行時間の短いCUDA Kernelが大量に実行されているという状態でした。このため、CUDA Kernelの実行のオーバーヘッドがかなり入っているのでは?と考えて、これを削減するCUDA Graphsを使ってみます。
CUDA Graphsが何かわからない方はこちらの記事をご覧ください。
さて、CUDA Graphsで実行するようのコードとしては一つ前のコードのtorch.jit.trace
でコンパイルしたものを利用します。CUDA Graphsで実行できるようにするために、一時領域を一部入力として入れていました。
この測定結果は以下の通りです。
平均実行時間 (sec) | |
numpy | 34.5 |
JAX jit版 | 0.0142 |
JAXのコードをそのままPyTorchにする | 7.89 |
Batchの軸を行列の一番内側にもってくる + torch.jitで高速化する | 0.0655 |
CUDA Graphsを使う | 0.0324 |
なんと、JAXのほうがPyTorchのコードよりも2.2倍速いという結果でした。CUDA Graphsでも勝てないなんて・・・。
現状のPyTorchのコードの敗因
自分なりにnsysを使ってプロファイルとって結果を見た印象では主に以下の二つの原因があると考えています。
- PyTorchのJITがJAXに比べてあまりfuseしてくれない
- そもそも入力サイズがGPUで実行するには小さすぎる
1についてはPyTorchとJAXの二つのコードを見てみるとJAXのコードのほうがJITされて1つのCUDA Karnelの実行時間が長い印象でした。JAX、PyTorchどちらもどの関数がfuseされたかをどうやってみるのかわからないので憶測になりますが、おそらくJAXのほうがより多くの処理を1つのCUDA Karnelまとめてくれていて、結果としてCUDA Karnelの実行数が減り、CUDA Karnelの実行オーバーヘッド小さくなったためJAXのほうが速くなった、ということを考えています。
2に関しては、そもそもGPUで実行するには入力サイズが小さすぎる印象です。実際配列の長さや配列ペアの数を大きくしてもPyTorchのコードはあまり実行時間が増加しないことを確認しています。じゃあ、入力サイズを大きくすればいいのでは?とも思ったのですが、今回のSmooth Smith Watermanではあと2,3倍くらいにはしてもよさそうですが、どこまでできるかは問題依存なため、ひとまずそのままにしておきました。なんとなく実際に使うときもあと数倍くらいは大きくできそうだけど、100倍とかはあまりつかわなさそうだなと思っています。ただ、この辺りはいろいろ意見が分かれそうかなと思っています。
終わりに
今回は個人的に前々から気になっていた「JAXって速いの?」という問いに答えるための検証の一環で行いました。結果はまさかのPyTorchと比べてこんなに差がでるとは、という感じした。ただ、PyTorchのJITはまだ使い慣れていない感があるので、何か高速化のアイディアが浮かんだら再チャレンジしたいと思います。