読者です 読者をやめる 読者になる 読者になる

StatModeling Memorandum

StanとRでベイズ統計モデリングします. たまに書評.

MCMCサンプルを{dplyr}で操る

R

RからStanやJAGSを実行して得られるMCMCサンプルは、一般的に iterationの数×chainの数×パラメータの次元 のようなオブジェクトとなっており、凝った操作をしようとするとかなりややこしいです。

StanとRでベイズ統計モデリング (Wonderful R)』のなかでは、複雑なデータ加工部分は場合によりけりなので深入りしないで、GitHub上でソースコードを提供しています。そこでは、ユーザが新しく覚えることをなるべく少なくするため、Rの標準的な関数であるapply関数群を使っていろいろ算出しています。しかし、apply関数群は慣れていない人には習得しづらい欠点があります。

一方で、Rのデータ加工パッケージとして、%>%によるパイプ処理・{dplyr}パッケージ・{tidyr}パッケージがここ最近よく使われており、僕も重い腰を上げてやっと使い始めたのですが、これが凄く使いやすい。%>%selectfiltermutategroup_bysummarize*_joingatherspreadだけをまずは覚えればほとんど不自由しませんでした。これらがないともう他の言語に移れないレベルです。これらのパッケージの練習のおかげで、ややこしいMCMCサンプルの処理についても、こんな感じでやれば毎回ウンウン唸らずに統一的に操作できそうかなぁ、というところまで来ましたので簡単にメモします。

* * *

手始めに以下の図を描いてみます。

この図はパラメータごとにMCMCサンプルの中央値と95%CIを表示した図です。{ggmcmc}パッケージや{bayesplot}パッケージに含まれる関数を使うと一撃で描くこともできます。しかし、練習のため自分で算出して作図します。

library(rstan)
library(ggmcmc)
library(dplyr)

data <- list(J=8, y=c(28,  8, -3,  7, -1,  1, 18, 12), sigma=c(15, 10, 16, 11,  9, 11, 10, 18))
model_code <- readr::read_file(url('https://raw.githubusercontent.com/wiki/stan-dev/rstan/8schools.stan'))
fit <- stan(model_code=model_code, data=data, seed=1234)

d_mcmc <- ggs(fit)
d_qua <- d_mcmc %>%
  filter(grepl('^theta\\[\\d+\\]$', Parameter)) %>%
  group_by(Parameter) %>%
  summarize(`2.5%` = quantile(value, probs=.025),
            `50%`  = quantile(value, probs=.5),
            `97.5%`= quantile(value, probs=.975))

p <- ggplot() +
  geom_pointrange(data=d_qua, mapping=aes(x=forcats::fct_rev(Parameter), y=`50%`, ymin=`2.5%`, ymax=`97.5%`)) +
  coord_flip() +
  labs(x='Parameter', y='Value')
ggsave(p, file='fig1.png', dpi=300, w=4, h=3)
  • 6行目:Web上から8schools.stanを読み込んで文字列としています。RStanの公式ページの例題で使われているモデルファイルです。
  • 7行目:stan関数はmodel_code引数でモデルを書いた文字列も指定できます(本来はファイル名を直接指定できればよかったのですがよくわかりませんでした)。
  • 9行目:{ggmcmc}パッケージのggs関数でtidyなデータにしておきます。tidyなデータについては西原さんの記事を参照。{ggmcmc}パッケージに含まれる関数を使うとd_mcmcからいろいろな図が描けます。詳しくは『StanとRでベイズ統計モデリング (Wonderful R)』の4章に書きましたので読んでいただけるとうれしいです。

なお、d_mcmcは以下のようなデータフレームになります。

> d_mcmc
# A tibble: 72,000 × 4
   Iteration Chain Parameter   value
       <dbl> <int>    <fctr>   <dbl>
1          1     1        mu -1.3476
2          2     1        mu -0.9601
3          3     1        mu  7.0919
4          4     1        mu 15.0782
5          5     1        mu 20.0110
6          6     1        mu 20.4483
7          7     1        mu 13.0249
8          8     1        mu 11.8232
9          9     1        mu 15.9213
10        10     1        mu 17.9294
# ... with 71,990 more rows
  • 11行目:まずはモデルに含まれるtheta[数字]というパラメータだけ残しています。grepl関数でパラメータ名がマッチするか判定する際に正規表現を使う必要があります。ここが正規表現に慣れていない人は少し厳しいかもしれません。
  • 12~15行目:{dplyr}パッケージの典型的な使い方です。Parameter列ごとに要約量を算出します。列名が数字で始まる場合はバッククォートで囲む必要があります。1つずつ分位点を算出するのではなく、do関数で一行で算出する方法もあるのですが、分かりにくく、現状issueとして検討中のようです(ここここ)。
  • 18行目:ggplot2coord_flipすると下から上に向かってfactorが並びますので、{forcats}パッケージのfct_rev関数で逆順にしています。

* * *

次に以下の図を描いてみます。久保本11章の図に相当します。

library(rstan)
library(ggmcmc)
library(dplyr)

Y <- read.csv(url('https://raw.githubusercontent.com/MatsuuraKentaro/RStanBook/master/chap12/input/data-kubo11a.txt'))$Y
I <- length(Y)
d <- data.frame(X=1:I, Y=Y)
data <- list(I=I, Y=Y)
model_code <- readr::read_file(url('https://raw.githubusercontent.com/MatsuuraKentaro/RStanBook/master/chap12/model/model12-11.stan'))
fit <- stan(model_code=model_code, data=data, seed=1234)

d_mcmc <- ggs(fit)
d_qua <- d_mcmc %>%
  filter(grepl('^Y_mean\\[\\d+\\]$', Parameter)) %>%
  tidyr::separate(Parameter, into=c('Parameter', 'x'), sep='[\\[\\]]') %>%
  mutate(x=as.integer(x)) %>%
  group_by(Parameter, x) %>%
  summarize(`2.5%` = quantile(value, probs=.025),
            `10%`  = quantile(value, probs=.1),
            `50%`  = quantile(value, probs=.5),
            `90%`  = quantile(value, probs=.9),
            `97.5%`= quantile(value, probs=.975))

p <- ggplot() +
  geom_ribbon(data=d_qua, mapping=aes(x=x, ymin=`2.5%`, ymax=`97.5%`), alpha=1/6) +
  geom_ribbon(data=d_qua, mapping=aes(x=x, ymin=`10%`,  ymax=`90%`),   alpha=2/6) +
  geom_line(data=d_qua, mapping=aes(x=x, y=`50%`)) +
  geom_point(data=d, aes(x=X, y=Y), shape=1, size=2) +
  labs(x='i', y='Y[i]') +
  ylim(0, 22)
ggsave(p, file='fig2.png', dpi=300, w=4, h=3)
  • 15行目:{tidyr}パッケージのseparate関数を使って、Y_mean[20]Y_mean20という2つの列に分解しています。
  • 17行目:あとは集計の単位であるgroup_byの単位が場面によって多少変わるぐらいで、特に悩まずに色々な量が算出できます。

この記事ではMCMCChainごとに何かを算出することは取り上げませんでしたが、ggs関数で作ったd_mcmcChain列も含んでいますので自由自在です。