読者です 読者をやめる 読者になる 読者になる

StatModeling Memorandum

StanとRでベイズ統計モデリングします. たまに書評.

infer.netの例題シリーズ1 Bayes Point Machine

BUGS

BayesPointMachine(以下BPM)の元論文はこちら(R.Herbrich, T.Graepel, and C.Campbell, BayesPointMachines, JMLR, 2001)。

BPMはクラス分類を行うアルゴリズムSVMに似ています。SVMはマージンを最大にするwを求めますが、BayesPointMachineはモデルの尤度に応じたwの分布を求めます。あるクラスに属する確率の分布も求まります。

対応するinfer.netのページはこちら。例題で与えられたデータを以下のようなテーブルにしました。

Yincomeage
16338
01623
12840
15527
02218
02040
NA5836
NA1824
NA2237

ここで、Yは購入したか否かを表すクラス、incomeは収入、ageは年齢です。incomeageからYを予測せよ、という問題です。最初の6人がTrainingデータ、最後の3人はTestデータです。今回はinfer.netのサンプルコードに忠実に線形判別のモデル(SVMで線形カーネルを用いた場合に相当)をBUGSで実装しました。

BUGSコードは以下になります。たった10行です。

model {
   for (i in 1:N) {
      Y[i] ~ dbern(p[i])
      p[i] <- phi(x[i])
      x[i] ~ dnorm(mu[i], 10)
      mu[i] <- inprod(w[], Personal[i,])
   }
   w[1:N.w] ~ dmnorm(Zero[], tau[,])
   tau[1:N.w, 1:N.w] ~ dwish(R[,], N.w)
}
  • 3-4行目: infer.netでは不等号で0,1に変換していますが、BUGSではlogit()phi()を使うのが安定です。step()とか使うと制限がきつすぎるためか収束しません。
  • 8-9行目: infer.netのコードに習ってwの事前分布を多変量正規分布にしています。逆ウィシャート分布(dwish)の自由度パラメータはデータの次元以上である制限があります。値が低いほど無情報です。収束しない場合は値を増やしてもいいかもしれません。flyioさんの大変ためになるブログ記事を参照。

実行するRコードは以下のようになります。

source("R2WBwrapper.R")
d <- read.delim("input/data.txt", as.is=T, header=T, sep="\t")
N <- nrow(d)
Personal.m <- as.matrix(cbind(scale(d[,-1]), Intercept=rep(1, length=N)))
N.w <- ncol(Personal.m)

clear.data.param()
set.data("N", N)
set.data("Y", d$Y)
set.data("N.w", N.w)
set.data("Zero", rep(0, N.w))
set.data("R", diag(rep(0.5, N.w), N.w))
set.data("Personal", Personal.m)

set.param("x", d$Y-0.5, save=F)
set.param("w", rep(0, N.w))
set.param("tau", diag(rep(10, N.w), N.w))
set.param("p", NA)

post.bugs <- call.bugs(
   file = "model/model.bugs",
   n.iter = 62000, n.burnin = 2000, n.thin = 100
)
post.list <- to.list(post.bugs)
post.mcmc <- to.mcmc(post.bugs)
save.image("output/result.RData")

計算時間はおよそ15秒でした。結果は以下の通り。wは収束はしたもののsdが大きいですね。pはいい感じです。

meansd2.5%25%50%75%97.5%Rhatn.eff
w[1]1.982.42-0.120.561.172.448.751.03300
w[2]0.811.11-0.510.160.561.093.961.01630
w[3]0.121.05-1.51-0.360.030.472.571.031800
tau[1,1]4.043.870.241.372.895.4114.571.001800
tau[1,2]-1.263.04-8.11-2.80-1.010.384.361.001800
tau[1,3]-0.133.10-6.52-1.78-0.211.576.481.001800
tau[2,1]-1.263.04-8.11-2.80-1.010.384.361.001800
tau[2,2]6.494.810.762.985.338.7418.881.00680
tau[2,3]0.013.50-7.34-1.97-0.051.987.131.001800
tau[3,1]-0.133.10-6.52-1.78-0.211.576.481.001800
tau[3,2]0.013.50-7.34-1.97-0.051.987.131.001800
tau[3,3]7.045.150.823.185.809.5120.481.00980
p[1]0.900.170.400.860.991.001.001.011600
p[2]0.140.180.000.000.060.230.631.001400
p[3]0.620.250.120.430.630.821.001.01420
p[4]0.780.230.240.630.860.991.001.001300
p[5]0.160.200.000.000.060.260.701.00740
p[6]0.370.260.000.150.340.580.901.001300
p[7]0.870.190.330.800.971.001.001.001500
p[8]0.160.190.000.000.080.280.651.001300
p[9]0.360.250.000.140.340.550.871.04340
deviance4.132.490.472.283.775.589.741.02330

Testデータである最後の3人の購入する確率は以下のように算出されました(BUGSは分布のmedianを取りました)。

infer.netBUGS
Test10.960.97
Test20.160.08
Test30.290.34

さらに識別能力を上げるために2通りの方法があります。1つ目の方法は先ほどのBUGSコードにおいて、wの平均をゼロベクトルではなく、無情報な事前分布を導入することです。以下になります(変更部分だけ)。

w[1:N.w] ~ dmnorm(mn[], tau[,])
mn[1:N.w] ~ dmnorm(Zero[], Non.informative.tau[,])

実行するRコードも2行追加になります。

set.data("Non.informative.tau", diag(rep(1.0E-4, N.w), N.w))
set.param("mn", rep(0, N.w))

しかしながら今回はこれではデータ不足のため収束しませんでした。

2つ目の方法は線形カーネルではなくて他のカーネルを使うことです。MicrosoftのForumに2011/06/03に質問している方がいまして、詳しそうな方の返答によると以下になります。

minka replied on 05-04-2010 7:32 AM Infer.NET does not provide kernels, however you can simulate kernels using random feature expansion (Rahimi & Recht, "Random features for large-scale kernel machines", NIPS 2007). This technique replaces the kernel trick by mapping the data into a randomized feature space with finite dimension, such that the inner product of two randomly mapped data points is approximately the same as the value of a kernel function evaluated at the two data points. As a result, you obtain a random feature based classifier approximately equivalent to the kernelized classifier. Infer.NET does not automatically perform random feature expansion so you will have to write the code to generate these features, but the code should be quite simple.

とのことです。僕は機械学習アルゴリズムには疎いのでこのあたりで一時退却としました。