このシリーズははじめの2ステップ(NB→UM→LDA)がとっつきにくいですがそこまで理解すれば後のモデルの拡張はそんなに難しくは感じませんでした。そのためNBから順にしっかり理解することが重要と思います。またNBとUMは文書のトピックが与えられているかそうでないかの違いしかなく、BUGSコードは全く同一のまま動きます(Stanでは離散パラメータを含みますので多少面倒になります)。今回はNBの分かりやすい説明を試みたのち、実際にStanでの実装と結果を見ていきたいと思います。
はじめにこの記事の表記から。以下になっています。
右2列は定数については数値を、そうでないものについてはR内の変数名を書いています。与えられているデータ(前回の記事の data1
のw.1
)は以下の図のようになっています。
文書が1-100(M)まであり、その各文書に144(V)種類の単語のいずれかが出現しています。このデータをどうモデル化するか。NBの場合は以下のようなグラフィカルモデルに従って単語が出現するとみなす、というわけです。
グラフィカルモデルはモデルに登場する各変数間の関係を表したものです。●に白文字は与えられたデータ、○に黒文字は推定すべき変数、■に白文字は与えられたハイパーパラメータです。矢印が値を決める関係を表します。大きな四角はプレートと呼ばれ、そこに繰り返しがあることを示しています。
グラフィカルモデルだけでは分かりにくいと感じたので、以下の吹き出しの順に説明していきます。
①
ここではハイパーパラメータからディリクレ分布に従って『1つだけ』が生成されます。このは以下のような『文書ごとの』トピックを決める、いびつなK面サイコロです。
は「トピック分布」と呼ばれます。またディリクレ分布のパラメータであるはK次元のベクトルです。サイコロの形が丸っぽくなるのか(≒どのトピックも同じような出現確率になる)、それとも凄くいびつな形になるのか(≒一部のトピックだけ出やすくて、残りはほとんどでなくなる; 将棋の駒を想像してみて下さい)を調節するパラメータになります。だと丸っぽくなる可能性が高くなり、だといびつな形になる可能性が高くなり、だとどの形にも均等になりやすいと言えます(一様分布の多変数版に相当)。どのトピックが出やすいかは前もってわからないので今回はを採用します。ディリクレ分布についてはflyioさんのブログのこの記事がイメージしやすいです。
②
ここではからトピックを選ぶ過程です。いびつなK面のサイコロをふって『文書ごとに』トピックを決めていきます。全て同じサイコロで決めることに留意。
③
ここではハイパーパラメータからディリクレ分布に従ってトピックの数(K)だけが生成されます。このは以下のような出現する単語を決める、いびつなV面のサイコロです。
は「単語分布」と呼ばれます。文書内の単語は一般的に一部の単語だけが出やすい状況が多いのでを使っています。持橋先生もこれぐらいの値ではじめてみようと統数研の講座でおっしゃっていましたので従います。もっと極端だと分かっている場合にはとかにしましょう。
④
ここでは各文書について、③で選ばれたトピックに従って、その色のV面サイコロを振り続けて単語を決めていきます。文書が変わればトピックも変わり、サイコロも変わります。
まとめますと以下の図のようになります。
以上の①~④のプロセスで文書x単語が生成されたと考えることになります。お気づきかもしれませんが、このモデルでは単語の順番や単語間の関係は単語の出現をモデル化するうえで大切ではないと考えて切り捨てていることに相当します。これでOKです。モデルとは「ある知りたいことを理解するために構築するもので現象を再現するのに本質的な仮定をいくつかあわせた仮説」であると思っています。同じ現象を見ていても、知りたいことが異なればモデルは変わります。モデルは現象を再現する十分条件を目指します。その一方で条件をどこまでそぎ落としてシンプルにできるか(必要十分条件に近づけるか)、どこまで精度が上がるか(予測があたるか)が大切で、この相反する二つをいかに両立させるかがモデル化の腕の見せ所となっています。物理系の分野だと前者に注力し、情報系の分野だと後者に注力する傾向があるように感じています。
さて次にStanでの実装にうつります。
data { int<lower=1> K; # num topics int<lower=1> M; # num docs int<lower=1> V; # num words int<lower=1> N; # total word instances int<lower=1,upper=K> Z[M]; # topic for doc m int<lower=1,upper=V> W[N]; # word n int<lower=1,upper=N> Offset[M,2]; # range of word index per doc # hyperparameter vector<lower=0>[K] Alpha; # topic prior vector<lower=0>[V] Beta; # word prior } parameters { simplex[K] theta; # topic prevalence simplex[V] phi[K]; # word dist for topic k } model { # prior theta ~ dirichlet(Alpha); for (k in 1:K) phi[k] ~ dirichlet(Beta); # likelihood for (m in 1:M){ Z[m] ~ categorical(theta); for (n in Offset[m,1]:Offset[m,2]) W[n] ~ categorical(phi[Z[m]]); } }
- 14~15行目:
K
面サイコロtheta
を1個、V
面サイコロphi
をK
個、定義しています。サイコロ(目の出る確率の合計が1)はsimplex
で定義します。 - 25行目:サイコロを振って文書のトピックを決めていることに相当します。
- 26行目:ある文書内の単語だけ繰り返したいためにOffsetを使っています。前回の記事を参照してください。
- 27行目:文書のトピック
Z[m]
の色を持つV面サイコロphi
を振って出現する単語を決めています。結局StanではデータからK
面サイコロtheta
(1個)、V
面サイコロphi
(K
個)の形を推定することになります。
これをキックするRのコードは以下になります。読み込んでいるデータは前回の記事を参照してください。
library(rstan) load("input/201402_data1.RData") data <- list( K=K, M=M, V=V, N=N.1, Z=z.for.doc, W=w.1$Word, Offset=offset.1, Alpha=rep(1, K), Beta=rep(0.5, V) ) fit <- stan( file='model/NB.stan', data=data, iter=1000, chains=1 )
21行目にchains=1を指定しているのは次回のUMにあわせるためです。次回に詳しく述べます。
結果は以下になります。
まずはthetaの推定値から(左: data1
, 右: data2
)。
点はMCMCサンプルの中央値で範囲は80%信頼区間です。横軸はトピックのインデックス・縦軸は確率です。サンプルデータはLDAを模したデータを作り方をしたのでNBのthetaは真の値と直接比較はできません。文書あたりの単語数が少ない場合でも多い場合とほぼ同様の結果になっています。これはトピックが与えられているので納得できる結果です。次にphiの推定値(左: data1, 右: data2)。
点はMCMCサンプルの中央値で範囲は80%信頼区間です。横軸は単語のインデックス・縦軸は確率です。真の値を黒の横棒で表しています。真の頻度が低いところを高めに推定し、真の頻度が高いところを低めに推定しているという結果になりました。はじめはphiのハイパーパラメータであるの(データを作成した)真の値がで一部の単語の頻度が高くなるように設定したのに対し、推定する際に与えた値がだったのでそれにひっぱられたのかなと思っていましたが、真の値を事前分布として与えても同じ結果になりました。NBでは1文書に1トピックだけ割り当てられるので、今回のサンプルデータをそれにあてはめようとするとトピック内の単語頻度が平均化されてしまうと見るのが正しそうです。
ちなみに無情報事前分布を設定してデータだけからハイパーパラメータも決めるフルベイズをやりたいところですが、トピックモデル全般的にパラメータが多めでデータが少なめな状況が多く、フルベイズはことごとく失敗しました。
最後にモデルのパラメータ数(Stanで出力される数です。ホントはsimplexの制限があるので実質的な数は少し減ります)や推定にかかった計算時間などは以下の通りです。
項目 | 文書あたりの単語数 | NB | UM | LDA | LDA(Freq) | PAM | GaP |
---|---|---|---|---|---|---|---|
パラメータ数 | 少/多 | 1450 | ? | ? | ? | ? | ? |
計算時間 | 少 | 5.5m | ? | ? | ? | ? | ? |
収束具合 | 少 | ○ | ? | ? | ? | ? | ? |
lp__ | 少 | -13892 | ? | ? | ? | ? | ? |
計算時間 | 多 | 30.8m | ? | ? | ? | ? | ? |
収束具合 | 多 | ○ | ? | ? | ? | ? | ? |
lp__ | 多 | -74154 | ? | ? | ? | ? | ? |