NUTSとADVI(自動変分ベイズ)の比較
RStan2.9.0がリリースされました。今まで{rstan}
パッケージのsampling
関数を使っていたところを、vb
関数に変更するだけでサンプリングのアルゴリズムをNUTSからADVI(Automatic Differentiation Variational Inference)に変更することができます。ADVIはユーザーが変分下限の導出や近似分布qを用意をすることなしに、自動的に変分ベイズしてくれます。得られるアウトプットはNUTSとほぼ同様で近似事後分布からの乱数サンプルです。ウリはスピードです。NUTSもADVIもデフォルトのオプションのまま実行して、NUTSと比べて50倍ぐらいスピードが出ることもあります。
NUTSと同様にADVIは効率的な探索のため偏微分を使っているので、離散値をとるパラメータは使えませんが、やはり同様に離散パラメータを消去すれば実行できます。そして、微分可能ならばどんな関数でもADVIで変分ベイズ化可能です。アルゴリズムの詳細は、arXivの論文(URL)を、調節するパラメータの詳細はStanのマニュアルを読むとよいでしょう。
気になる部分は、NUTSと推定値が同じぐらいになるかです。詳しく言うと変分ベイズの平均場近似の仮定(事後分布は因数分解可能)でどれほど推定が悪くなるかです。いくつか試しましたので順に見ていきます。
Eight Schools
まずはRStanの公式の導入ページにある8学校の問題をやります。モデルはそのまま拝借するとして、NUTSとADVIを実行するRコードの例は以下の通りです。
library(rstan) schools_dat <- list( J=8, y=c(28, 8, -3, 7, -1, 1, 18, 12), sigma=c(15, 10, 16, 11, 9, 11, 10, 18) ) stanmodel <- stan_model(file='model/8schools.stan') fit <- sampling(stanmodel, data=schools_dat, iter=1000, warmup=500, chains=4, seed=123) fit_vb <- vb(stanmodel, data=schools_dat, output_samples=2000, seed=123)
- 12行目:ADVIを実行しています。アルゴリズムは2種類選択できて、
meanfield
(平均場)とfullrank
です。デフォルトはmeanfield
です。NUTSで最終的に得られるMCMCサンプルの数は(iter-warmup)/thin*chains
なので、thinはデフォルトは1であることを考慮すると今回は2000です。これにあわせるため、「output_samples=2000
」を指定しています。
さて推定結果は以下になりました。横軸がパラメータで縦軸が推定値です。乱数サンプルの中央値と95%ベイズ信頼区間を表示しています。少々のズレはありますね。
ロジスティック回帰
次の例はロジスティック回帰です。今までの僕の個人的な経験だと、回帰はADVIはそのままではうまくいかないことがあります。切片と回帰係数の同時分布はネチョネチョに関係してる同時分布なので、単純な平均場近似がうまく機能していないのではないかと予想しています。例えば、以下のStanコードでロジスティック回帰をします(分かりやすさのためvector
型は使っていません)。
data { int N; real X1[N]; real X2[N]; int Y[N]; } parameters { real b[3]; } model { for (n in 1:N) Y[n] ~ bernoulli_logit(b[1] + b[2]*X1[n] + b[3]*X2[n]); }
実行するRコードは以下になります。8学校の場合と同様です。
library(rstan) data(mtcars) data <- list(N=nrow(mtcars), X1=mtcars$mpg/20, X2=mtcars$am, Y=mtcars$vs) stanmodel <- stan_model(file='model/mtcars.stan') fit <- sampling(stanmodel, data=data, iter=1000, warmup=500, chains=4, seed=123) fit_vb <- vb(stanmodel, data=data, output_samples=2000, seed=123)
この結果は以下になります。図の凡例も8学校と同じです。
事後分布の中央値はNUTSとそこまで変わらないのですが、バラツキが小さめに推定されやすいようです(infer.NETもそうでした)。対処方法としては、説明変数の行列にQR分解を使って、モデルを再パラメータ化するという方法があるようです。
https://groups.google.com/forum/#!searchin/stan-dev/ADVI$20QR/stan-dev/qkQsG0f8krk/3hmIB9afBQAJ
Stanコードを書かなくても基本的なモデルをRから実行できるようにした{rstanarm}
パッケージにおいては、そのようなQR分解を使った再パラメータ化がオプションで用意されているようです。すごい。
https://github.com/stan-dev/rstanarm/blob/23dd5558e5eee31f12a8e1f9d1be9279b830ae88/R/stan_glm.fit.R
トピックモデルのLDA
データ(data2
)とStanコードは以前の記事と同じものを使いました。時間も測り直しました。
- NUTSの場合: 4900.91 seconds (Total) (
iter=1000, warmup=500, chains=1, seed=123
で実行) - ADVIの場合: 116.31 seconds (Total) (
output_samples=2000, seed=123
で実行)
驚きの50倍近いスピードです。結果は以下になります。以前のLDAの記事と同様、各文書のトピック分布thetaをプロットしました。1つ目の図がADVIで、2つ目の図がNUTSの結果です。
- ADVI
- NUTS
ここで、黒の横棒はデータを生成する際に使用した真の値、色の丸点は乱数サンプルの中央値、色の線は乱数サンプルの80%ベイズ信頼区間です。やはり、ADVIでは少しズレているところが散見されます。しかしながら、このスピードは本当に魅力的で、データが非常に多い場合にはADVIがいい選択肢になると思いました。