StatModeling Memorandum

StatModeling Memorandum

StanとRとPythonでベイズ統計モデリングします. たまに書評.

Bayesian GPLVMをStanで実装してみた

この記事の続きです。PRML下の12章に出てくるOil Flowのデータ(データ点1000個×特徴量12個)に対してBayesian GPLVMで2次元(または3次元)の潜在変数空間にマッピングして綺麗に分離されるか見てみます。

まずはPRMLにもあるように普通の主成分分析でやると以下になります。綺麗には分離されません。

f:id:StatModeling:20201106181957p:plain

次にBayesian GPLVMでやってみます。Stanコードは以下になります。

  • 2~4行目: NKDはそれぞれ、データ点の数・特徴量の数・最終的に落とし込む潜在空間の次元です。
  • 14行目: 潜在変数です。
  • 15行目: カーネルに含まれるパラメータです。僕が理解したところだと特徴量ごとにガウス過程が存在するのでKごとに異なる値を持つようにしています。→ 2017.07.02追記 Kごとに異なる値にするのではなく、1つだけにし、スケーリングしてから適用することで情報を圧縮させる方がふつうのようです。詳しくはMLPシリーズ『ガウス過程と機械学習』参照。
  • 19行目: 同様に特徴量ごとに分散共分散行列があります。
  • 20・22~23行目: カーネルの定義で効率的な行列演算をするため、matrix型をvector型の配列に持ち替えます。
  • 24~28行目: カーネルの定義です。ここではGaussian+bias+linear+white noiseにしました。カーネルについてはGP summser school 2015のKernel Designの講義資料 (pdf)The Kernel Cookbookなどを参照してください。
  • 29行目: 潜在空間に対する縛りです。代わりにparametersブロックでlowerupperを定めてもOKです。
  • 30~31行目: カーネルに含まれるパラメータの事前分布です。軽くしばっています。ある範囲の一様分布にするなど他にも色々考えられると思います。
  • 32~33行目: ガウス過程の部分です。この書き方だと各特徴量は独立になっていますが、さらに特徴量間の相関を考慮したモデル(例えばCoregionalized Regression Model)もあります。ここではすでに計算量が膨大なので独立としました。

Rからの実行方法は以下です。計算が重いのでADVIを使いました。

  • 5行目: データがmatlabのファイルで与えられていたのでそれを読み込んでいます。結果はリストになります。
  • 8行目: 潜在変数の次元Dを与えています。最終的に2次元にプロットしたいです。余裕をもって次元を設定し、主成分分析のように寄与の大きい次元トップ2だけを抽出することで2次元に射影する方法もあるようです。ここでははじめから2次元の空間とします。また3次元の空間も試して3Dプロットしてみます。
  • 10~16行目: まずはふつうの主成分分析しています。前述の図はここで出力しているresult-pca.pngになります。またGPLVMを実行する際の初期値にする意味もあります。
  • 18~28行目: Bayesian GPLVMを実行しています。21行目ではinitオプションでPCAの結果を初期値として設定しています。22行目では時間短縮のためetaを求めないで1で与えています(手元のモデルとデータではチューニングの結果がいつもeta = 1だったという理由があります)。

計算時間はAWS EC2のc4.xlargeを使っても3時間半ぐらいかかりました。かなり遅いです。ADVIの代わりにStanのoptimizing関数を使って推定した方がよいかもしれません。また汎用的な確率的プログラミング言語ではガウス過程に特化した専用ライブラリにはかないません。Stanのモデルはユーザの問題にあわせた拡張が簡単なので、その点で使う価値はあると思います。特にモデルを拡張する予定がないならば、もしくはデータが巨大ならば、Pythonガウス過程に特化したライブラリであるGPyなどの使い方を学ぶべきと思います。

結果

D = 2の場合

潜在変数xの乱数サンプルの中央値を使って2次元の散布図を描くと以下になりました(result-bgplvm.png)。それなりに綺麗に分離していると思います。乱数の種の影響も見ましたが、おおよそ似たような結果になりました。

f:id:StatModeling:20201106181953p:plain

D = 3の場合

同様に3次元の散布図を描きました。2次元の場合より若干綺麗に分離していそうです。

f:id:StatModeling:20201106181948p:plain