StatModeling Memorandum

StatModeling Memorandum

StanとRとPythonでベイズ統計モデリングします. たまに書評.

階層ベイズモデルとWAIC

この記事では階層ベイズモデルの場合のWAICとは何か、またその場合のWAICの高速な算出方法について書きます。

背景

以下の2つの資料を参照してください。[1]に二種類の実装が載っています。[2]に明快な理論的補足が載っています。

モデル1

資料[1]にあるモデルを扱います。すなわち、

f:id:StatModeling:20201106183418p:plain

ここで N は人数、 n は人のインデックスです。 r[n] は個人差を表す値になります。このモデルにおいては r[n] を解析的に積分消去することができて、負の二項分布を使う以下のモデル式と等価になります。

f:id:StatModeling:20201106183423p:plain

ここでは予測として(WAICとして)2通り考えてみましょう。 以降では事後分布による平均を  \mathbb{E}[\,] 、分散を  \mathbb{V}[\,] と書くことにします。

(1)  r[n] を持つ n が、追加で新しく1つのサンプルを得る場合

この場合には新しいデータ y の予測分布は以下になります。

f:id:StatModeling:20201106183428p:plain

WAICは n ごとに算出され、以下になります。

f:id:StatModeling:20201106183432p:plain

(2) 別の新しい人が新しく1つのサンプルを得る場合

この場合には次のモデルを考えていることに相当します。

f:id:StatModeling:20201106183436p:plain

そして、新しいデータ y の予測分布は以下になります。

f:id:StatModeling:20201106183439p:plain

WAICは以下になります。

f:id:StatModeling:20201106183443p:plain

ソースコード

 n ごとにWAICを算出することや、WAIC内の和(シグマ)はR側で処理します。

(1)のStanコード

(2)に対応する負の二項分布を使ったStanコード

(2)のStanコード

数値積分をR側かStan側のどちらかで実行する必要があります。資料[1]ではR側で行っており、これが多大な時間がかかる原因となっています。ここでは合成シンプソン公式(とlog_sum_exp関数)を使ってStan側で数値積分をして高速化します。

これはこちらのコードをメモ化によって高速化したものになっています。どちらのコードでも6~17行目でシンプソンの公式を使って数値積分をしています。

Rコード

結果

waic1_byG waic2 waic3
2.332 3.244 2.841 ... 2.987 3.14 3.143

計算時間は(1)の場合は、Surface Pro 3で1chainあたり5秒ぐらいです。(3)の場合でもメモ化がバッチリ効いて1chainあたり12秒ぐらいです。

waic1_byGにおいて、r[n]の大きなnr[n]の小さなnと比べて、ガンマ分布の裾部分の確率密度に由来する可能性が高く、(1)の予測が悪くなる(WAICが大きくなる)ことが予想されるでしょう。ここでは図示しませんが調べるとそうなっています。

waic2waic3は理論的に一致するはずですが、Stanコードの違いがMCMCサンプルの違いになり、その影響でわずかにズレます。

なお、資料[1]のp.47-48のソースコードだと(1)の場合のWAICを n ごとに算出したあとに、それらの和をとって N で割った値になります。WAICの和は「各 n が、追加でそれぞれ新しく1つのサンプルを得る場合」の予測に対応します。それを N で割った値が対応する予測はよく分かりません。

また、WAICはMCMCサンプルによって値が変わるので、乱数の種の影響をわずかにうけることに注意です。

モデル2

資料[2]にあるモデルと似たモデルを扱います。すなわち、

f:id:StatModeling:20201106183448p:plain

ここで G はクラス数(グループ数)、 g はそのインデックスです。 N は人数、 n はそのインデックスです。 N2G[n]  n が所属している g を返します。 ここでは予測として(WAICとして)3通り考えてみましょう。

(1) あるクラス g に、新しく1人が加わる場合

この場合には新しいデータ y の予測分布は以下になります。

f:id:StatModeling:20201106183453p:plain

WAICは g ごとに算出され、以下になります。

f:id:StatModeling:20201106183457p:plain

ここで G2N[g] はクラス g に含まれる n のインデックスすべてです。

(2) 別の新しいクラスがまるごとできる場合

この場合には新しいクラス全体のデータ y^n の予測分布は以下になります。

f:id:StatModeling:20201106183501p:plain

 (\,)^n の記法は資料[2]を参照してください。

WAICは以下になります。

f:id:StatModeling:20201106183505p:plain

(3) 別の新しいクラスができて、新しく1人が加わる場合

この場合には新しいデータ y の予測分布は以下になります。

f:id:StatModeling:20201106183509p:plain

WAICは以下になります。

f:id:StatModeling:20201106183513p:plain

ソースコード

(1)のStanコード

(2)のStanコード

グループ差や個人差が正規分布から生成される場合には、-5SDから+5SDぐらいまでを数値積分すればかなりよい近似になります。

(3)のStanコード

これはこちらのコードYによって変わらない部分をはじめに計算して保持しておいて、Yによって変わる部分だけをループで計算することで高速化したものになっています。

Rコード

結果

waic1_byG waic2 waic3
2.537 2.390 2.750 ... 2.496 142.1 3.679

計算時間は(1)(2)(3)の場合がそれぞれ、1chainあたり0.4秒・1秒・12秒ぐらいです。こちらはメモ化ほど高速化が効きませんが、それでも高速化しない場合と比べると1.5倍ぐらい早くなっています。

waic1_byGにおいて、クラスあたりの人数(NbyG)の多いクラスの方がWAICは小さくなるかなと思ったのですが、そこまできれいな関係ではありませんでした。ただし、人数が5人のクラスはWAICは目に見えて高くなっています。

あわせて読みたい

statmodeling.hatenablog.com statmodeling.hatenablog.com