StatModeling Memorandum

StatModeling Memorandum

StanとRとPythonでベイズ統計モデリングします. たまに書評.

ガウス過程シリーズ 3 クラス分類(PRML下 Fig 6.12)

今回はGaussian Processで2値クラス分類を行います。2値なのでlogistic linkをかませばOKです。しかしながら、高速化ができなくなります。Stan manualの中にも登場しますがinfer.netの例題の中の「Short Examples: Gaussian Process classifier」にも登場します。

さてデータですが、PRML下のFig 6.12のデータを用います。以下のようなデータです。

XYclass
1.2089850.4214480
0.504542-0.285731
0.6305681.0547120
1.0563640.6018730
-0.490114-0.8411221
1.8470752.3623220
-0.2797030.7531961
1.953357-0.7466320

Stanコードは以下になります。

data {
   int D;
   int N;
   vector[D] X[N];
   int<lower=0, upper=1> Class[N];
   int N_new;
   vector[D] X_new[N_new];
   real Eta_sq;
   real Rho_sq;
   real Sigma_sq;
}

transformed data {
   vector[N] Mu;
   for (i in 1:N) Mu[i] <- 0;
}

parameters {
   vector[N] y;
}

transformed parameters {
   matrix[N,N] cov;
   for (i in 1:N)
      for (j in 1:N)
         cov[i,j] <- Eta_sq * exp(-Rho_sq * squared_distance(X[i], X[j])) + if_else(i==j, Sigma_sq, 0.0);
}

model {
   y ~ multi_normal(Mu, cov);
   Class ~ bernoulli_logit(y);
}

generated quantities {
   vector[N_new] y_new;
   vector[N_new] prob;
   vector[N_new] z;
   vector[N_new] Mu_new;
   matrix[N_new,N_new] L;
   {
      matrix[N_new,N_new] Omega;
      matrix[N,N_new] K;
      matrix[N_new,N] K_transpose_div_cov;
      matrix[N_new,N_new] Tau;

      for (i in 1:N_new)
         for (j in 1:N_new)
            Omega[i,j] <- Eta_sq * exp(-Rho_sq * squared_distance(X_new[i], X_new[j])) + if_else(i==j, Sigma_sq, 0.0);

      for (i in 1:N)
         for (j in 1:N_new)
            K[i,j] <- Eta_sq * exp(-Rho_sq * squared_distance(X[i], X_new[j]));

      K_transpose_div_cov <- K' / cov;  # ':transpose
      Mu_new <- K_transpose_div_cov * y;
      Tau <- Omega - K_transpose_div_cov * K;
      for (i in 1:N_new)
         for (j in (i+1):N_new)
            Tau[i,j] <- Tau[j,i];

      L <- cholesky_decompose(Tau);
   }

   for (j in 1:N_new)
      z[j] <- normal_rng(0, 1);
   y_new <- Mu_new + L*z;
   for (j in 1:N_new)
      prob[j] <- inv_logit(y_new[j]);
}

前の記事のGP2.stanとGP4.stanを発展させた形になっています。

  • 2,4,7,26,48,52行目: 入力がD次元ベクトル(今はD=2)になったことに伴う変更です。26,48,52行目で使われているsquared_distance関数は2つのベクトルの二乗距離を返します。
  • 30,31行目: 2値のクラス分類になったことに伴う変更です。
  • 34~69行目:Rからメッシュの値を渡して、各メッシュごとに所属クラスが1となる確率を算出しています。

今回データ点が200点で、200×200の分散共分散行列を使いますが、これは時間がかかる処理です。そのため、フルベイズは諦めてEta_sq,Rho_sq,Sigma_sqに値を与えています。結局所属クラスをロジスティック回帰で算出する際の潜在変数の値yだけを推定しています。

キックするRコードは以下になります。

library(rstan)
library(ggplot2)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

d <- read.delim('input/classification.txt', sep='\t')
N <- nrow(d)
X <- d[ ,c(1,2)]
Class <- d$class
m <- seq(from = -3, to = 3, by = 0.1)
X_new <- expand.grid(X = m, Y = m)
N_new <- nrow(X_new)

data <- list(
   D = ncol(X),
   N = N,
   X = X,
   Class = Class,
   N_new = N_new,
   X_new = X_new,
   Eta_sq = 3.0,
   Rho_sq = 0.5,
   Sigma_sq = 0.1
)

stanmodel <- stan_model(file='model/GP_classification.stan')
fit <- sampling(stanmodel, data = data, pars = c('prob'), iter=1300, warmup=300, seed=123, chains=3)

la <- extract(fit)
N.mcmc <- length(la$lp__)
z.prob <- apply(la$prob, 2, median)
d.test <- data.frame(X_new, Z=z.prob)
d.train <- data.frame(d, Z=0)

p <- ggplot(d.test, aes(x=X, y=Y, z=Z))
p <- p + theme(text=element_text(size=18), legend.position='none')
p <- p + scale_fill_gradient2(midpoint=0.5)
p <- p + geom_tile(aes(fill=Z)) + stat_contour(size=0.5, color='black')
p <- p + geom_point(data=subset(d.train, class==0), aes(shape=factor(4), color=factor(0)), alpha=4/5, size=2)
p <- p + geom_point(data=subset(d.train, class==1), aes(shape=factor(1), color=factor(1)), alpha=4/5, size=2)
p <- p + scale_shape_manual(values=c(1,4))
p <- p + scale_colour_manual(values=c('red','blue'))
p <- p + coord_cartesian(xlim=c(-3,3), ylim=c(-3, 3), expand=FALSE)
ggsave(file='output/PRML-Fig06_12.png', plot=p, dpi=300, width=4, height=3)
  • 10,11行目でメッシュに区切ってX_newを作っています(X_new=61×61=3721)。

フルベイズでないにもかかわらず計算時間はSurface Pro 3(core i5)で1chainあたり約3時間でした。PRML-Fig06_12.pngは以下のようになりました。PRMLの図がほぼ再現できています。

f:id:StatModeling:20201107072630p:plain

Stanにおけるガウス過程が遅いことについて定期的に話題になっていますので(例えば、これ)そのうち改善する方法が出てくるのかもしれません。Stan 2.9.0で使えるようになった変分ベイズ(ADVI)を使う選択肢もあるでしょう。