前回の記事のスピードアップをします。
まずは分散共分散行列をコレスキー分解して multi_normal() から multi_normal_cholesky() を使うようにする方法です。このテの高速化の基本とのことです。コレスキー分解をするメリットはzがi.i.d.から生成される、すなわち、
に従う変数とすると、分散共分散行列をコレスキー分解して得られた行列をLとすると、
に従うようになります。
Stanではmulti_normal_cholesky()
という関数が用意されているのでこれを使います。Stanコードは以下になります(GP2.stan)。
data { int<lower=1> N1; int<lower=1> N2; vector[N1] X1; vector[N1] Y1; vector[N2] X2; } transformed data { int<lower=1> N; vector[N1+N2] X; vector[N1+N2] Mu; cov_matrix[N1+N2] Cov; matrix[N1+N2,N1+N2] L; N <- N1 + N2; for (n in 1:N1) X[n] <- X1[n]; for (n in 1:N2) X[N1 + n] <- X2[n]; for (i in 1:N) Mu[i] <- 0; for (i in 1:N) for (j in 1:N) Cov[i,j] <- exp(-pow(X[i] - X[j],2)) + if_else(i==j, 0.1, 0.0); L <- cholesky_decompose(Cov); } parameters { vector[N2] y2; } model { vector[N] y; for (n in 1:N1) y[n] <- Y1[n]; for (n in 1:N2) y[N1 + n] <- y2[n]; y ~ multi_normal_cholesky(Mu, L); }
Cov
はデータから決まるのでコレスキー分解は1度しか実行しないことに注意です。
キックするRコードは同じなので省略します。計算時間は1chainあたり17sでした。結果の図は前回と変わりません。
次に条件付き多変量正規分布の解析的な結果を使う方法です。Y1
とy2
を(同時に)生成する多変量正規分布が以下のように表せるとき、Y1
で条件付けたy2
の分布は手計算で求めることができて、以下のようになります(PRML 2.3.1を参照)
今はとしているので、これを使ったStanコードが以下になります(GP3.stan)。
data { int<lower=1> N1; int<lower=1> N2; vector[N1] X1; vector[N1] Y1; vector[N2] X2; } transformed data { vector[N2] Mu; matrix[N2,N2] L; { matrix[N1,N1] Sigma; matrix[N2,N2] Omega; matrix[N1,N2] K; matrix[N2,N1] K_transpose_div_Sigma; matrix[N2,N2] Tau; for (i in 1:N1) for (j in 1:N1) Sigma[i,j] <- exp(-pow(X1[i] - X1[j],2)) + if_else(i==j, 0.1, 0.0); for (i in 1:N2) for (j in 1:N2) Omega[i,j] <- exp(-pow(X2[i] - X2[j],2)) + if_else(i==j, 0.1, 0.0); for (i in 1:N1) for (j in 1:N2) K[i,j] <- exp(-pow(X1[i] - X2[j],2)); K_transpose_div_Sigma <- K' / Sigma; # ':transpose Mu <- K_transpose_div_Sigma * Y1; Tau <- Omega - K_transpose_div_Sigma * K; for (i in 1:N2) for (j in (i+1):N2) Tau[i,j] <- Tau[j,i]; L <- cholesky_decompose(Tau); } } parameters { vector[N2] z; } model { z ~ normal(0,1); } generated quantities { vector[N2] y2; y2 <- Mu + L * z; }
このGP3.stanはGP2.stanの時と異なり、すでにデータで条件付けた多変量正規分布を求めているため尤度を計算する部分が必要ありません。よって、その多変量正規分布から予測値y2
を生成するところだけ、parameters
ブロック以下で書くことになります。ここはコレスキー分解を使うことで、その恩恵を大きく受けることができます。
計算時間は1chainあたり0.4sでした。尤度の計算をしてパラメーターを探索するようなことをしていませんので前回の記事より100倍近く速くなっている上に収束もいいです。しかしこのわざはクラス分類の時には使えません。これが回帰の時だけStanをすすめる理由です。結果の図は前回と変わりません。
最後にhyper parameterも推定するフルベイズのStanコードの例は以下になります(GP4.stan)。
data { int<lower=1> N; vector[N] X; vector[N] Y; int<lower=1> N_new; vector[N_new] X_new; } transformed data { vector[N] Mu; for (i in 1:N) Mu[i] <- 0; } parameters { real<lower=0> eta_sq; real<lower=0> rho_sq; real<lower=0> sigma_sq; } transformed parameters { matrix[N,N] cov; for (i in 1:N) for (j in 1:N) cov[i,j] <- eta_sq * exp(-rho_sq * pow(X[i] - X[j],2)) + if_else(i==j, sigma_sq, 0.0); } model { Y ~ multi_normal(Mu, cov); eta_sq ~ cauchy(0,5); rho_sq ~ cauchy(0,5); sigma_sq ~ cauchy(0,5); } generated quantities { vector[N_new] y_new; vector[N_new] z; vector[N_new] Mu_new; matrix[N_new,N_new] L; { matrix[N_new,N_new] Omega; matrix[N,N_new] K; matrix[N_new,N] K_transpose_div_cov; matrix[N_new,N_new] Tau; for (i in 1:N_new) for (j in 1:N_new) Omega[i,j] <- eta_sq * exp(-rho_sq * pow(X_new[i] - X_new[j],2)) + if_else(i==j, sigma_sq, 0.0); for (i in 1:N) for (j in 1:N_new) K[i,j] <- eta_sq * exp(-rho_sq * pow(X[i] - X_new[j],2)); K_transpose_div_cov <- K' / cov; # ':transpose Mu_new <- K_transpose_div_cov * Y; Tau <- Omega - K_transpose_div_cov * K; for (i in 1:N_new) for (j in (i+1):N_new) Tau[i,j] <- Tau[j,i]; L <- cholesky_decompose(Tau); } for (j in 1:N_new) z[j] <- normal_rng(0, 1); y_new <- Mu_new + L*z; }
hyper parameterが変化する→cov
が変化する なのでコレスキー分解を使った高速化はできません。model
ブロック(サンプルがいらない場合)もしくはtransformed parameters
ブロック(サンプルが欲しい場合)でcov
を作ることになります。
- 15~17行目:値を0以上に制限し、
- 29~31行目:cauchy分布を使うことで、すその広いhalf cauchy分布からhyper parameterを生成しています。
- 34~66行目:新しいX_newに対応する部分だけgenerated quantitiesブロックに書きます。
計算時間は1chainあたり10秒以下でした。結果の図は以下の通りです(濃い灰帯は50%信用区間です)。
なお、 y[i] ~ normal(beta * Treatment[i] + f[i], sigma_sq)
(f ~ GPに従う)のように他の情報も取り込んだ回帰ができると非常にうれしいと思ったのですが、少なくとも現状のStanではこれをやると全く収束しなくなります。また、補助変数法も機能しませんでした。自分でC++等で実装する必要がありそうです。