読者です 読者をやめる 読者になる 読者になる

StatModeling Memorandum

StanとRとPythonでベイズ統計モデリングします. たまに書評.

Zero-Inflated Poisson分布を使った来店人数などのモデリング

東京R勉強会(#TokyoR)で「100人のための統計解析 - 和食レストラン編」というタイトルで発表してきました。スライドは以下になります。

前半の散布図行列に関しては別途記事を書きましたのでそちらを参照してください。この記事では後半のBUGSコード&Stanコードを説明します。

まずデータですがここからダウンロードできます。あり/なしを1/0に置き換えて、列名を適当に英語化したものは以下のようになっています(case1.txt)。1人1行で1000行あります。

cIDagegenderfamily
.structure
CMnum.visitamountset
.flag
noodle
.flag
ricebowl
.flag
dessert
.flag
sidemenu
.flag
softdrink
.flag
alcohol
.flag
554113001010000
946112001000000
1127113001100010
1235113122001010010
2120113000011000
83936023001100000
90833023001110000
9653502367001110001
97946024120001100101
98734023001110000

まずは個体差(個人差)なしモデルから。BUGSコードは以下になります(model1.bugs)。

model {
   for (i in 1:N){
      Y.visit[i]  ~ dpois(m.visit[i])
      Y.amount[i] ~ dpois(m.amount[i])
      m.visit[i]  <- group[i] * lambda.visit[i]
      m.amount[i] <- group[i] * lambda.amount[i]
      group[i] ~ dbern(theta[i])

      logit(theta[i])       <- alpha1 + inprod(X[i,], beta1[])
      log(lambda.visit[i])  <- alpha2 + inprod(X[i,], beta2[])
      log(lambda.amount[i]) <- alpha3 + inprod(X[i,], beta3[])
   }

   alpha1 ~ dnorm(0, 1.0E-4)
   alpha2 ~ dnorm(0, 1.0E-4)
   alpha3 ~ dnorm(0, 1.0E-4)
   for (k in 1:K){
      beta1[k] ~ dnorm(0, 1.0E-4)
      beta2[k] ~ dnorm(0, 1.0E-4)
      beta3[k] ~ dnorm(0, 1.0E-4)
   }
}
  • 3,5,7行目: Zero-Inflated Poisson分布をBUGSを表現する場合はこのようにします。ただし、いわゆるZero-Inflated Poisson分布の定義のtheta1-thetaはいれかえています。これは結果の解釈を分かりやすくするためです(thetaが来店する確率になる)。The BUGS Bookの11章を参照。パラメータが0のPoisson分布は0を生成する確率1の分布になりますのでこれでOKです。
  • 9-11行目: 回帰に相当する部分です。
  • 9行目: Bernoulli分布のパラメータは[0,1]の値を取るのでlogitをつけています。
  • 10,11行目: Poisson分布のパラメータは0以上なのでlogをつけています。
  • 14-21行目: 回帰係数(切片含む)の事前分布を無情報事前分布に設定しています。

キックするRコードは以下の通りです。

library(R2WinBUGS)
source("R2WBwrapper.R")

d <- read.delim("input/case1.txt", sep="\t")
d$age <- d$age/100
d <- transform(d,
   amount=amount/100,
   married=as.integer(family.structure==2),
   has.child=as.integer(family.structure==3)
)
N <- nrow(d)

Expvar <- c('age', 'gender', 'married', 'has.child', 'CM', 'set.flag', 'noodle.flag', 'ricebowl.flag', 'dessert.flag', 'sidemenu.flag', 'softdrink.flag', 'alcohol.flag')
X <- d[ , Expvar]
K <- ncol(X)

clear.data.param()
set.data("N", N)
set.data("Y.visit", d$num.visit)
set.data("Y.amount", d$amount)
set.data("K", K)
set.data("X", as.matrix(X))

set.param("alpha1", rnorm(1, 0, 0.1))
set.param("alpha2", rnorm(1, 0, 0.1))
set.param("alpha3", rnorm(1, 0, 0.1))
set.param("beta1", rnorm(K, 0, 0.1))
set.param("beta2", rnorm(K, 0, 0.1))
set.param("beta3", rnorm(K, 0, 0.1))
set.param("theta", NA)
set.param("lambda.visit", NA)
set.param("lambda.amount", NA)

post.bugs <- call.bugs(
   debug = FALSE,
   file = "model/model1.bugs",
   n.iter = 41000, n.burnin = 1000, n.thin = 40
)
post.list <- to.list(post.bugs)
post.mcmc <- to.mcmc(post.bugs)
save.image("output/model1.bugs.RData")
  • 5,7行目: 100で割ることでスケールをなるべく1のオーダーに近づけます。BUGSやStanでは収束を速めるためのtipsになります。
  • 8~9行目: family.structureは1,2,3の値を取るカテゴリ変数であり、それをダミー変数化しています。
  • 13行目: 説明変数に使う列名を指定しています。
  • 30~32行目: deterministicに決まるBUGS内の変数でもこのようにセットしておけばサンプリングしてくれます。

計算時間は並列化しないで約9時間(1chainあたり約3時間)でした。回帰に使う説明変数の数が10個を超えたあたりからMCMCサンプルの自己相関が高くなり、それを消すためにはBUGSではかなりiterationを増やさなくてはならないのが原因です。

次に個体差(個人差)なしモデルのStanコードは以下になります(model1.stan)。

data {
   int<lower=1> N;
   int<lower=0> Y_visit[N];
   int<lower=0> Y_amount[N];
   int<lower=1> K;
   matrix[N, K] X;
}
parameters {
   real alpha1;
   real alpha2;
   real alpha3;
   vector[K] beta1;
   vector[K] beta2;
   vector[K] beta3;
}
transformed parameters {
   vector<lower=0, upper=1>[N] theta;
   vector<lower=0>[N] lambda_visit;
   vector<lower=0>[N] lambda_amount;

   for (i in 1:N){
      theta[i] <- inv_logit(alpha1 + X[i]*beta1);
      lambda_visit[i] <- exp(alpha2 + X[i]*beta2);
      lambda_amount[i] <- exp(alpha3 + X[i]*beta3);
   }
}
model {
   for (i in 1:N) {
      # Y_visit
      if (Y_visit[i] == 0) {
         # Bernoulli(0|theta) + Bernoulli(1|theta) * Poisson(0|lambda)
         increment_log_prob(
            log_sum_exp(
               bernoulli_log(0, theta[i]),
               bernoulli_log(1, theta[i]) + poisson_log(0, lambda_visit[i])
            )
         );
      } else {
         # Bernoulli(1|theta) * Poisson(y|lambda)
         increment_log_prob(
            bernoulli_log(1, theta[i]) + poisson_log(Y_visit[i], lambda_visit[i])
         );
      }

      # Y_amount
      if (Y_amount[i] == 0) {
         # Bernoulli(0|theta) + Bernoulli(1|theta) * Poisson(0|lambda)
         increment_log_prob(
            log_sum_exp(
               bernoulli_log(0, theta[i]),
               bernoulli_log(1, theta[i]) + poisson_log(0, lambda_amount[i])
            )
         );
      } else {
         # Bernoulli(1|theta) * Poisson(y|lambda)
         increment_log_prob(
            bernoulli_log(1, theta[i]) + poisson_log(Y_amount[i], lambda_amount[i])
         );
      }
   }
}
generated quantities {
   vector<lower=0>[N] lambda_total_amount;
   lambda_total_amount <- lambda_visit .* lambda_amount;
}
  • 30~43行: StanでのZero-Inflated Poisson分布の表現になります。thetaからBernoulli分布で0/1を生成するところで離散パラメータを使うためsumming outしてincrement_log_prob()が必要になります(summing outについてはこの記事参照)。Stanのマニュアル(2.9.0)の「10.5 Zero-Inflated and Hurdle Models」の節に書かれています。しかしながらマニュアルと比べてbernoulli_log(0, theta)の中の0/1が逆になっていますが、これは今回はthetaを「夜間に来店する確率」にしたかったためです。

キックするRコードは以下になります。

library(rstan)

d <- read.delim("input/case1.txt", sep="\t")
d$age <- d$age/100
d <- transform(d,
   amount=amount/100,
   married=as.integer(family.structure==2),
   has.child=as.integer(family.structure==3)
)
N <- nrow(d)

Expvar <- c('age', 'gender', 'married', 'has.child', 'CM', 'set.flag', 'noodle.flag', 'ricebowl.flag', 'dessert.flag', 'sidemenu.flag', 'softdrink.flag', 'alcohol.flag')
X <- d[ , Expvar]
K <- ncol(X)

data <- list(
   N = N,
   Y_visit = d$num.visit,
   Y_amount = d$amount,
   K = K,
   X = X
)

fit <- stan(
   file='model/model1.stan',
   data=data,
   iter=1200,
   warmup=200,
   seed=123,
   thin=1,
   chains=3
)

save.image("output/model1.stan.Rdata")

BUGSの時とほとんど変わりません。計算時間は並列化しないで約20min(1chainあたり約7min)でした。

以下は結果です。BUGSとStanの結果はほぼ一致しましたので以下ではStanの結果を述べます。 theta, lambda_visit, lambda_amountについてはスライドを見てください。書籍で定義したところの総利用金額=来店回数×利用金額の予測があまりよくないのは来店回数(もしくは利用金額)が0か否か、すなわちthetaの予測が難しいことに起因していると思われます。参考までにスライドにはなかった回帰係数の結果を載せます。

meanse_meansdX2.5.X25.X50.X75.X97.5.n_effRhat
alpha1-4.19 0.01 0.47 -5.12 -4.49 -4.18 -3.87 -3.31 2484 1.00
alpha21.75 0.00 0.24 1.28 1.59 1.76 1.92 2.20 2605 1.00
alpha32.49 0.00 0.10 2.28 2.42 2.49 2.56 2.70 3000 1.00
beta1[1]2.66 0.01 0.53 1.62 2.31 2.66 3.04 3.68 3000 1.00
beta1[2]-0.97 0.00 0.11 -1.19 -1.04 -0.97 -0.90 -0.74 3000 1.00
beta1[3]-0.04 0.00 0.19 -0.41 -0.17 -0.04 0.10 0.33 3000 1.00
beta1[4]-0.19 0.00 0.13 -0.44 -0.27 -0.19 -0.10 0.06 3000 1.00
beta1[5]1.35 0.00 0.11 1.15 1.28 1.35 1.43 1.57 3000 1.00
beta1[6]-1.06 0.00 0.25 -1.54 -1.23 -1.06 -0.89 -0.59 3000 1.00
beta1[7]-0.28 0.00 0.11 -0.50 -0.36 -0.28 -0.20 -0.05 3000 1.00
beta1[8]-0.17 0.00 0.11 -0.39 -0.24 -0.16 -0.09 0.05 3000 1.00
beta1[9]0.03 0.00 0.21 -0.38 -0.12 0.04 0.17 0.42 3000 1.00
beta1[10]-0.08 0.00 0.17 -0.42 -0.20 -0.08 0.04 0.25 3000 1.00
beta1[11]-0.23 0.00 0.15 -0.54 -0.33 -0.22 -0.13 0.06 3000 1.00
beta1[12]0.18 0.00 0.14 -0.10 0.08 0.18 0.27 0.47 3000 1.00
beta2[1]2.57 0.01 0.33 1.95 2.34 2.58 2.80 3.20 3000 1.00
beta2[2]-0.37 0.00 0.07 -0.51 -0.41 -0.37 -0.32 -0.23 3000 1.00
beta2[3]-0.39 0.00 0.11 -0.60 -0.46 -0.39 -0.31 -0.17 3000 1.00
beta2[4]-0.13 0.00 0.07 -0.27 -0.18 -0.13 -0.09 -0.00 3000 1.00
beta2[5]-0.34 0.00 0.04 -0.42 -0.37 -0.34 -0.31 -0.25 3000 1.00
beta2[6]-0.10 0.00 0.12 -0.34 -0.19 -0.11 -0.02 0.15 3000 1.00
beta2[7]-0.30 0.00 0.06 -0.43 -0.35 -0.31 -0.26 -0.18 3000 1.00
beta2[8]0.04 0.00 0.06 -0.08 -0.00 0.04 0.09 0.16 3000 1.00
beta2[9]0.36 0.00 0.11 0.14 0.29 0.36 0.43 0.56 3000 1.00
beta2[10]-0.08 0.00 0.10 -0.27 -0.14 -0.08 -0.01 0.11 3000 1.00
beta2[11]0.16 0.00 0.09 -0.02 0.10 0.16 0.22 0.32 3000 1.00
beta2[12]0.19 0.00 0.07 0.05 0.14 0.19 0.24 0.33 3000 1.00
beta3[1]-1.43 0.00 0.16 -1.75 -1.53 -1.43 -1.32 -1.13 3000 1.00
beta3[2]0.42 0.00 0.03 0.36 0.40 0.42 0.44 0.48 3000 1.00
beta3[3]0.62 0.00 0.04 0.54 0.59 0.62 0.65 0.71 3000 1.00
beta3[4]0.39 0.00 0.03 0.32 0.36 0.39 0.41 0.46 3000 1.00
beta3[5]0.13 0.00 0.02 0.09 0.12 0.13 0.14 0.16 3000 1.00
beta3[6]-0.24 0.00 0.05 -0.34 -0.27 -0.24 -0.21 -0.15 3000 1.00
beta3[7]0.04 0.00 0.03 -0.02 0.02 0.04 0.06 0.10 3000 1.00
beta3[8]0.13 0.00 0.03 0.07 0.11 0.13 0.15 0.19 3000 1.00
beta3[9]-0.09 0.00 0.05 -0.19 -0.12 -0.09 -0.05 0.02 3000 1.00
beta3[10]0.23 0.00 0.04 0.15 0.20 0.23 0.26 0.31 3000 1.00
beta3[11]-0.11 0.00 0.04 -0.20 -0.13 -0.11 -0.08 -0.02 3000 1.00
beta3[12]0.30 0.00 0.03 0.24 0.28 0.30 0.33 0.37 3000 1.00

最後に参考までに個体差(個人差)ありモデルを紹介します。各パラメータに個人差の項を入れることで個人差がどれくらいの大きさなのかを見て現象の理解につなげます。個人差が大きいと、新しい人のデータの予測値の95%信用区間は広くなり、予測するのが難しいということになります。ちなみにBUGSではTrap66が出て止まりませんでしたのでStanでやりました。Stanコードは以下になります(model2.stan)。

data {
   int<lower=1> N;
   int<lower=0> Y_visit[N];
   int<lower=0> Y_amount[N];
   int<lower=1> K;
   matrix[N, K] X;
}
parameters {
   real alpha1;
   real alpha2;
   real alpha3;
   vector[K] beta1;
   vector[K] beta2;
   vector[K] beta3;
   vector[N] r_theta;
   vector[N] r_visit;
   vector[N] r_amount;
   real<lower=0> s_theta;
   real<lower=0> s_visit;
   real<lower=0> s_amount;
}
transformed parameters {
   vector<lower=0, upper=1>[N] theta;
   vector<lower=0>[N] lambda_visit;
   vector<lower=0>[N] lambda_amount;

   for (i in 1:N){
      theta[i] <- inv_logit(alpha1 + X[i]*beta1 + r_theta[i]);
      lambda_visit[i] <- exp(alpha2 + X[i]*beta2 + r_visit[i]);
      lambda_amount[i] <- exp(alpha3 + X[i]*beta3 + r_amount[i]);
   }
}
model {
   s_theta ~ gamma(10, 1);
   r_theta ~ normal(0, s_theta);
   r_visit ~ normal(0, s_visit);
   r_amount ~ normal(0, s_amount);

   for (i in 1:N) {
      # Y_visit
      if (Y_visit[i] == 0) {
         # Bernoulli(0|theta) + Bernoulli(1|theta) * Poisson(0|lambda)
         increment_log_prob(
            log_sum_exp(
               bernoulli_log(0, theta[i]),
               bernoulli_log(1, theta[i]) + poisson_log(0, lambda_visit[i])
            )
         );
      } else {
         # Bernoulli(1|theta) * Poisson(y|lambda)
         increment_log_prob(
            bernoulli_log(1, theta[i]) + poisson_log(Y_visit[i], lambda_visit[i])
         );
      }

      # Y_amount
      if (Y_amount[i] == 0) {
         # Bernoulli(0|theta) + Bernoulli(1|theta) * Poisson(0|lambda)
         increment_log_prob(
            log_sum_exp(
               bernoulli_log(0, theta[i]),
               bernoulli_log(1, theta[i]) + poisson_log(0, lambda_amount[i])
            )
         );
      } else {
         # Bernoulli(1|theta) * Poisson(y|lambda)
         increment_log_prob(
            bernoulli_log(1, theta[i]) + poisson_log(Y_amount[i], lambda_amount[i])
         );
      }
   }
}
generated quantities {
   vector<lower=0>[N] lambda_total_amount;
   lambda_total_amount <- lambda_visit .* lambda_amount;
}

34行目に注意してください。s_thetaには無情報事前分布ではなく、gamma分布でかなり情報の入った事前分布を与えています。これをはずすと収束しません。割とよく使われる対策だと思いますが、根本的にはモデルがあまりよくないのだと思われます。BUGSやStanではlogitの世界に色々な項を付け加えるのは基本的に相性が悪いです。理由はlogitが0付近での1動くのと10付近で1動くのではthetaの変化が全然違うためです。この差をなくすために0付近の動きと0を離れたところでの動きが同じぐらいになるようにうまくreparameterizationする方法がもしかしたら存在するのかもしれませんが、不勉強で今の時点では分かりません。

キックするRコードは省略します。計算時間は並列化しないで約3時間(1chainあたり約60min)でした。参考までにスライドにはなかった回帰係数の結果を載せます。個体差なしの時とかなり一致しています。残念ながらs_thetaの収束はあまりよくありません。

meanse_meansdX2.5.X25.X50.X75.X97.5.n_effRhat
alpha1-87.06 1.93 18.44 -124.94 -98.93 -85.87 -74.07 -53.57 91 1.01
alpha21.46 0.01 0.39 0.70 1.21 1.47 1.71 2.23 1662 1.00
alpha32.48 0.01 0.24 1.99 2.33 2.49 2.65 2.95 1170 1.00
beta1[1]63.66 1.39 18.75 30.71 50.36 62.62 75.71 103.19 181 1.01
beta1[2]-22.08 0.53 4.75 -32.13 -25.12 -21.75 -18.75 -13.70 79 1.03
beta1[3]-1.79 0.22 5.58 -13.04 -5.38 -1.72 1.79 9.23 633 1.00
beta1[4]-4.28 0.16 3.83 -12.13 -6.82 -4.05 -1.67 2.60 579 1.00
beta1[5]27.25 0.60 4.92 18.70 23.68 26.74 30.49 37.54 66 1.02
beta1[6]-22.97 0.55 8.20 -40.50 -28.22 -22.41 -17.28 -8.08 221 1.01
beta1[7]-7.21 0.22 3.87 -15.38 -9.68 -7.07 -4.57 0.08 323 1.00
beta1[8]-3.38 0.13 3.38 -10.32 -5.51 -3.29 -1.11 3.10 637 1.00
beta1[9]1.47 0.25 6.55 -11.55 -2.97 1.43 5.75 14.73 696 1.00
beta1[10]-2.20 0.21 5.47 -13.16 -5.93 -1.90 1.43 8.38 674 1.00
beta1[11]-4.91 0.16 4.62 -14.33 -7.87 -4.74 -1.77 3.67 799 1.00
beta1[12]4.02 0.20 4.63 -4.83 0.90 3.96 7.02 13.50 554 1.01
beta2[1]2.62 0.01 0.55 1.53 2.24 2.62 3.00 3.67 1836 1.00
beta2[2]-0.30 0.00 0.12 -0.53 -0.38 -0.30 -0.22 -0.08 2248 1.00
beta2[3]-0.35 0.00 0.18 -0.71 -0.47 -0.35 -0.23 0.01 2300 1.00
beta2[4]-0.12 0.00 0.12 -0.34 -0.20 -0.12 -0.04 0.11 1983 1.00
beta2[5]-0.34 0.00 0.07 -0.49 -0.38 -0.33 -0.29 -0.19 2128 1.00
beta2[6]-0.11 0.00 0.20 -0.48 -0.26 -0.12 0.02 0.28 1777 1.00
beta2[7]-0.21 0.00 0.11 -0.42 -0.28 -0.21 -0.13 0.01 1832 1.00
beta2[8]0.02 0.00 0.11 -0.20 -0.05 0.02 0.09 0.23 2052 1.00
beta2[9]0.26 0.00 0.20 -0.14 0.13 0.27 0.40 0.65 2056 1.00
beta2[10]0.03 0.00 0.16 -0.28 -0.07 0.03 0.14 0.36 2288 1.00
beta2[11]0.10 0.00 0.16 -0.21 -0.00 0.10 0.21 0.40 1843 1.00
beta2[12]0.15 0.00 0.13 -0.11 0.06 0.15 0.24 0.41 2108 1.00
beta3[1]-1.34 0.01 0.35 -2.03 -1.58 -1.34 -1.09 -0.65 1157 1.00
beta3[2]0.38 0.00 0.07 0.25 0.34 0.38 0.43 0.52 1522 1.00
beta3[3]0.63 0.00 0.11 0.41 0.56 0.63 0.70 0.84 1353 1.00
beta3[4]0.45 0.00 0.08 0.30 0.40 0.44 0.50 0.60 1381 1.00
beta3[5]0.10 0.00 0.04 0.01 0.07 0.10 0.13 0.19 1398 1.00
beta3[6]-0.22 0.00 0.12 -0.47 -0.30 -0.22 -0.14 0.03 1457 1.00
beta3[7]-0.00 0.00 0.07 -0.14 -0.05 -0.00 0.05 0.14 1409 1.00
beta3[8]0.05 0.00 0.07 -0.08 0.00 0.05 0.10 0.18 1341 1.00
beta3[9]-0.01 0.00 0.13 -0.25 -0.10 -0.01 0.08 0.24 1672 1.00
beta3[10]0.15 0.00 0.10 -0.04 0.09 0.15 0.22 0.35 1346 1.00
beta3[11]-0.11 0.00 0.10 -0.29 -0.18 -0.11 -0.04 0.09 1813 1.00
beta3[12]0.24 0.00 0.08 0.07 0.18 0.24 0.29 0.40 1088 1.00
s_theta37.44 0.84 5.38 27.34 33.56 37.13 41.34 48.12 41 1.04
s_visit0.69 0.00 0.05 0.61 0.66 0.69 0.72 0.79 197 1.00
s_amount0.48 0.00 0.03 0.43 0.47 0.48 0.50 0.54 325 1.01