読者です 読者をやめる 読者になる 読者になる

StatModeling Memorandum

StanとRでベイズ統計モデリングします. たまに書評.

WAICとWBICを事後分布から計算する

Stan R

前回の理論的なまとめを踏まえてStanでWAICとWBICを計算してみます。

今回は例題として混合正規分布から発生させたデータ100個を用いて、2種類のモデルで推定を行い、それぞれに対してWAICとWBICを求めてみます。まずはデータ生成部分のRコードは以下です。

N <- 100
a_true <- 0.4
mean1 <- 0
mean2 <- 3
sd1 <- 1
sd2 <- 1
set.seed(1)
Y <- c(rnorm((1-a_true)*N, mean1, sd1), rnorm(a_true*N, mean2, sd2))

次にモデルその1の説明です。ここでは2つの正規分布のうち平均0の方は固定で、もう片方の正規分布の平均(mu)とそれらの混ぜ具合(a)を推定することにします。Stanコードは以下です(model1a.stan)。

data {
  int<lower=1> N;
  vector[N] Y;
}

parameters {
  real<lower=0, upper=1> a;
  real<lower=-50, upper=50> mu;
}

model {
  for(n in 1:N)
    target += log_sum_exp(
      log(1-a) + normal_lpdf(Y[n] | 0, 1),
      log(a) + normal_lpdf(Y[n] | mu, 1)
    );
}

generated quantities {
  vector[N] log_likelihood;
  int index;
  real y_pred;
  for(n in 1:N)
    log_likelihood[n] = log_sum_exp(
      log(1-a) + normal_lpdf(Y[n] | 0, 1),
      log(a) + normal_lpdf(Y[n] | mu, 1)
    );
  index = bernoulli_rng(a);
  y_pred = normal_rng(index == 1 ? mu : 0, 1);
}
  • 23~27行目: RでWAICの計算をする際に1データごとの対数尤度が必要になるのでそれを計算しています。
  • 28~29行目: Rで汎化誤差(generalization error)の計算をする際に予測分布が必要になるのでそれを計算しています。

WBICは逆温度が1/log(データ数)の時の事後分布を用いて計算されます。この事後分布をStanで求めるには対数尤度の部分だけ 1/log(データ数) を掛けておけばよいと理解しています。よってStanコードは以下になりました(model1b.stan)。

data {
  int<lower=1> N;
  vector[N] Y;
}

parameters {
  real<lower=0, upper=1> a;
  real<lower=-50, upper=50> mu;
}

model {
  for(n in 1:N)
    target += 1/log(N) * log_sum_exp(
      log(1-a) + normal_lpdf(Y[n] | 0, 1),
      log(a) + normal_lpdf(Y[n] | mu, 1)
    );
}

generated quantities {
  vector[N] log_likelihood;
  for(n in 1:N)
    log_likelihood[n] = log_sum_exp(
      log(1-a) + normal_lpdf(Y[n] | 0, 1),
      log(a) + normal_lpdf(Y[n] | mu, 1)
    );
}
  • 13行目: この部分が上記で述べたところになります。

次にモデルその2の説明です。ここでは1つの正規分布であてはめを行い、平均(mu)と標準偏差s)を推定することにします。Stanコードは以下です(model2a.stan)。

data {
  int<lower=1> N;
  vector[N] Y;
}

parameters {
  real mu;
  real<lower=0> s;
}

model {
  Y ~ normal(mu, s);
}

generated quantities {
  vector[N] log_likelihood;
  int index;
  real y_pred;
  for(n in 1:N)
    log_likelihood[n] = normal_lpdf(Y[n] | mu, s);
  y_pred = normal_rng(mu, s);
}

同様に逆温度が1/log(データ数)の時のStanコードは以下です(model2b.stan)。

data {
  int<lower=1> N;
  vector[N] Y;
}

parameters {
  real mu;
  real<lower=0> s;
}

model {
  for(n in 1:N)
    target += 1/log(N) * normal_lpdf(Y[n] | mu, s);
}

generated quantities {
  vector[N] log_likelihood;
  for(n in 1:N)
    log_likelihood[n] = normal_lpdf(Y[n] | mu, s);
}

ここからが本番です。これらのモデルを使って推定を行い、WAICとWBICを定義に沿って計算していくRコードは以下になりました。

library(rstan)

N <- 100
a_true <- 0.4
mean1 <- 0
mean2 <- 3
sd1 <- 1
sd2 <- 1
set.seed(1)
Y <- c(rnorm((1-a_true)*N, mean1, sd1), rnorm(a_true*N, mean2, sd2))

data <- list(N=N, Y=Y)
fit1a <- stan(file='model/model1a.stan', data=data, iter=11000, warmup=1000, seed=123)
fit1b <- stan(file='model/model1b.stan', data=data, iter=11000, warmup=1000, seed=123)
fit2a <- stan(file='model/model2a.stan', data=data, iter=11000, warmup=1000, seed=123)
fit2b <- stan(file='model/model2b.stan', data=data, iter=11000, warmup=1000, seed=123)
ms1a <- extract(fit1a)
ms1b <- extract(fit1b)
ms2a <- extract(fit2a)
ms2b <- extract(fit2b)

generalization_error <- function(ms) {
  dens <- density(ms$y_pred)
  f_pred <- approxfun(dens$x, dens$y, yleft=1e-18, yright=1e-18)
  f_true <- function(x) (1-a_true)*dnorm(x, mean1, sd1) + a_true*dnorm(x, mean2, sd2)
  f_ge <- function(x) f_true(x)*(-log(f_pred(x)))
  # f_en <- function(x) f_true(x)*(-log(f_true(x)))
  # entropy <- integrate(f_en, lower=-6, upper=9)$value
  ge <- integrate(f_ge, lower=-6, upper=9)$value
  return(ge)
}

waic <- function(log_likelihood) {
  training_error <- - mean(log(colMeans(exp(log_likelihood))))
  functional_variance_div_N <- mean(colMeans(log_likelihood^2) - colMeans(log_likelihood)^2)
  waic <- training_error + functional_variance_div_N
  return(waic)
}

wbic <- function(log_likelihood){
  wbic <- - mean(rowSums(log_likelihood))
  return(wbic)
}

ge1 <- generalization_error(ms1a)
waic1 <- waic(ms1a$log_likelihood)
wbic1 <- wbic(ms1b$log_likelihood)
ge2 <- generalization_error(ms2a)
waic2 <- waic(ms2a$log_likelihood)
wbic2 <- wbic(ms2b$log_likelihood)
  • 22~31行目: 今回は真のモデル(f_true)が分かっているので汎化誤差(ge)を求めることができます。23~24行目では予測分布(を近似した関数)を求めています。27~28行目はコメントアウトしていますが、エントロピーを求めたい場合に使います。
  • 33~38行目: WAICを求めています。渡辺ベイズ本のp.167参照。
  • 40~43行目: WBICを求めています。渡辺先生のWebページを参照。

結果は以下の表のようになりました。

汎化誤差WAICWBIC
モデル11.9311.913193.7
モデル21.9971.980201.1

今回はモデル1とモデル2の比較ではモデル1の方がよさそうと判断できそうです。ちなみにデータ生成部分のmean2を0にするとモデル2の方がWAICもWBICも低くなります。

将来はstanの中でWAICを算出する機能が出てくるとは思いますが、いつになるか分からないので当分は自分で計算しようと思います。計算コストも低いですし。

渡辺先生の「広く使える情報量規準(WAIC) 」というページの(注0)にBDA3と渡辺ベイズ本の定義が異なることがちゃんと書かれてありました。

Nをデータ数として、渡辺ベイズ本のWAICの定義式を2N倍したものが、BDA3での定義になっています。Aki Vehtari &Andrew Gelman (2014) "WAIC and cross-validation in Stan" 内の「lpd hat (computed log pointwise predictive density)」にマイナスをつけてデータ数Nで割ったものが、本記事のRコード内のtraining_errorと一致し、p_waic hatをデータ数Nで割ったものがfunctional_variance_div_Nと一致します。数値も一致することを確認しました。

個人的には、2N倍していない方、すなわちNを大きくしたときには汎化誤差に近づいた値を得た方がよいと思っていて、それはこれから何回もWAICの値そのものを見ていくうちに、「このモデル(分布)でこれぐらいの数値だとかなり悪そうだな」というような経験・手ごたえが蓄積できたらいいなぁと思っているためです。AICやDICとどうしても比べたい場合は2N倍した方を使うのがよいでしょう。