StatModeling Memorandum

StatModeling Memorandum

StanとRとPythonでベイズ統計モデリングします. たまに書評.

ノンパラベイズ(ディリクレ過程)の実装

BUGS bookの11章の8.1節のディリクレ過程の写経です。データは以下のサポートページ(11.8.1: Galaxy clustering: Dirichlet process mixture models)でWinBUGS用のodcファイルで配布されています。

WinBUGSをインストールしていない人のために.RDataにしたものをここに置いておきます。

ディリクレ過程については詳しくはここでは述べません。『続・わかりやすいパターン認識』が分かりやすいと思います。BUGS bookにある実装のアイデアはCongton先生の本『Bayesian Statistical Modelling』から来ているようです。

Bayesian Statistical Modelling (Wiley Series in Probability and Statistics)

Bayesian Statistical Modelling (Wiley Series in Probability and Statistics)

  • 作者:Congdon, Peter
  • 発売日: 2007/01/05
  • メディア: ハードカバー

他のBUGSでの適用例はOhlssen et al. (2007) "Flexible random-effects models using Bayesian semi-parametric models: Applications to institutional comparisons" が有名のようです。ググればpdfがヒットします。

JAGSコードはBUGS bookに載っているものでそのまま動きました。以下です。

model {
   for (i in 1:n) {
      velocity[i] ~ dnorm(mu[i], tau[i])
      mu[i] <- mu.mix[group[i]]
      tau[i] <- tau.mix[group[i]]
      group[i] ~ dcat(pi[])
      for (j in 1:C) {
         gind[i,j] <- equals(j, group[i])
      }
   }
   p[1] <- q[1]
   for (j in 2:C) {
      p[j] <- q[j]*(1 - q[j-1])*p[j-1]/q[j-1]
   }
   for (j in 1:C) {
      q[j] ~ dbeta(1, alpha)
      pi[j] <- p[j]/sum(p[])
      mu.mix[j] ~ dnorm(amu, mu.prec[j])
      mu.prec[j] <- bmu*tau.mix[j]
      tau.mix[j] ~ dgamma(aprec, bprec)
   }
   alpha <- 1
   #  Could replace constant alpha with a prior
   # alpha ~ dgamma(2, 4)
   # alpha ~ dunif(0.3, 10)
   amu ~ dnorm(0, 0.001)
   bmu ~ dgamma(0.5, 50)
   aprec <- 2
   bprec ~ dgamma(2, 1)
   K <- sum(cl[])
   for (j in 1:C) {
      sumind[j] <- sum(gind[,j])
      cl[j] <- step(sumind[j]-1+0.001) # cluster j used in this iteration
   }
   for (j in 1:ndens) {
      for (i in 1:C) {
         dens.cpt[i,j] <- pi[i]*sqrt(tau.mix[i] / (2*3.141592654))*
                           exp(-0.5*tau.mix[i]*(mu.mix[i] - dens.x[j])*(mu.mix[i] - dens.x[j]))
      }
      dens[j] <- sum(dens.cpt[,j])
   }
}
  • 11-14,16,17行目: stick-breaking priorを定義している部分になります。
  • 19,20,27,28,29行目: 今回やっていることは混合正規分布の拡張みたいなものです。その1つ1つの正規分布の分散を自由にしてしまうと収束しません。そこで色々苦労して事前分布を設定しているようですが、僕にはイマイチよく分かりませんでした。 BUGSコードの最後の方はディリクレ過程から推定されたクラスター数と混合正規分布による密度を算出しています。計算時間はSurface Pro 3でn.iter=20000, thin=10で3chainで5分ぐらいだったような。

次にStanでやりました。Stanでやる場合は離散パラメータを消去する必要がありますがそんなに難しくありません。しかし収束がJAGSより若干よくない気がするし、計算が速いわけでもありません。Stanのメーリングリストの2013年12月の議論ですが、Bobさんから以下のような意見があります。

If the data's not too huge and you can approximate the required Dirichlet process(es) with a fairly low dimensional Dirichlet, then you could use Stan. If not, I'd suggest asking Dunson and Xing what they used for the paper.

JAGSやStanでのディリクレ過程は拡張性が容易である一方、小規模問題向けであると言えると思います。Stanコードは以下になりました。

data {
   int N;
   int C;
   int M;
   vector[N] Velocity;
   vector[M] X_mesh;
}

parameters {
   vector<lower=0, upper=1>[C] q;
   real mu_mu;
   real<lower=0> s_mu;
   ordered[C] mu_mix;
   vector<lower=0>[C] sigma_mix;
   real<lower=0> b;
}

transformed parameters {
   vector<lower=0>[C] p;
   simplex[C] theta;

   p[1] <- q[1];
   for (j in 2:C)
      p[j] <- q[j]*(1 - q[j-1])*p[j-1]/q[j-1];
   for (j in 1:C)
      theta[j] <- p[j]/sum(p);
}

model {
   for (n in 1:N) {
      real ps[C];
      for (j in 1:C)
         ps[j] <- log(theta[j]) + normal_log(Velocity[n], mu_mix[j], sigma_mix[j]);
      increment_log_prob(log_sum_exp(ps));
   }
   q ~ beta(1, 1);

   for (j in 1:C) {
      mu_mix[j] ~ normal(mu_mu, s_mu);
      sigma_mix[j] ~ gamma(5, b);
   }
   b ~ gamma(0.01, 0.01);
}

generated quantities {
   vector[C] dens_cpt[M];
   vector[M] dens;
   for (i in 1:M) {
      for (j in 1:C)
         dens_cpt[i,j] <- theta[j] * 1/(sqrt(2*pi())*sigma_mix[j])*exp(-0.5*pow((X_mesh[i] - mu_mix[j])/sigma_mix[j], 2));
      dens[i] <- sum(dens_cpt[i]);
   }
}
  • 22~26,36行目: stick-breaking priorの部分です。log取った記述の方が安定するかなと思ったのですがあまり変わりませんでした。
  • 31~34行目: 離散パラメータをsumming outしています(この記事参照)。従ってStanではクラスター数に相当するものは推定できません。
  • 38~42行目: JAGSでよく分からなかった事前分布のところの代替案です。あまり代わり映えはしませんが…。
  • 13行目: 混合正規分布の平均パラメータをordered vector型にすることで少しだけ収束しやすくなりました。
  • 48~52行目: 混合正規分布による密度を算出しています。

キックするRコードは以下になります。

library(rstan)

stan_fit <- stan(file="model/model2.stan", chains=0)

load('input/data.RData')
N <- data$n
C <- data$C
M <- data$ndens
Velocity <- data$velocity/20
X_mesh <- data$dens.x/20
data <- list(N=N, C=C, M=M, Velocity=Velocity, X_mesh=X_mesh)

fit <- stan(
   fit=stan_fit,
   data=data,
   pars=c('dens', 'theta', 'mu_mu', 's_mu', 'mu_mix', 'sigma_mix', 'b'),
   chains=3, iter=3300, warmup=300, thin=3, seed=2
)

・9-10行目: スケールを1ぐらいにするために20で割っています。 計算時間はSurface Pro 3で1chainあたり数分でした。

結果は以下になります。

JAGSの推定結果

クラスター数Kの中央値は7、95%信用区間は4-10でした。各正規分布の重みpiは以下の図のようになりました。

f:id:StatModeling:20201114121449p:plain

点はMCMCサンプルの中央値で、線は95%信用区間です。また混合正規分布による密度densは以下のようになりました。

f:id:StatModeling:20201114121445p:plain

赤線はMCMCサンプルの中央値、薄いオレンジ帯は同じく95%信用区間、濃いオレンジ帯は同じく50%信用区間です。灰色の棒はデータのヒストグラム(binwidth=range/30)です。

Stanの推定結果

正規分布の重みthetaは以下の図のようになりました。

f:id:StatModeling:20201114121441p:plain

点はMCMCサンプルの中央値で、線は95%信用区間です。また混合正規分布による密度densは以下のようになりました。

f:id:StatModeling:20201114121437p:plain

赤線はMCMCサンプルの中央値、薄いオレンジ帯は同じく95%信用区間、濃いオレンジ帯は同じく50%信用区間です。灰色の棒はデータのヒストグラム(binwidth=range/30)です。

JAGSもStanも似たような結果になりました。