Bayesian Lassoで特徴選択
Stanのマニュアルの「Point Estimation」の章を試しましたので記録を残します。
increment_log_prob
関数を使って重回帰をやります。その後、2通りのLassoで特徴選択をします。Stanでやる場合、ロジスティック回帰などにも簡単に組み込めますので拡張性が高いです。
データは以下の通りです。
Y | X.1 | X.2 | X.3 | … | X.28 | X.29 | X.30 |
---|---|---|---|---|---|---|---|
-3.33 | -0.56 | 2.20 | -0.07 | … | -0.60 | -1.56 | 1.00 |
2.87 | -0.23 | 1.31 | -1.17 | … | 0.91 | -1.14 | 0.55 |
4.45 | 1.56 | -0.27 | -0.63 | … | 1.66 | -0.72 | 1.75 |
5.23 | 0.07 | 0.54 | -0.03 | … | -0.04 | 0.53 | 0.62 |
-2.46 | 0.13 | -0.41 | 0.67 | … | -0.33 | -2.34 | 0.32 |
… | … | … | … | … | … | … | … |
6.76 | 2.00 | 1.50 | -1.16 | … | 0.76 | 0.07 | 1.59 |
0.38 | 0.60 | -0.77 | -0.13 | … | 0.63 | -0.30 | -2.37 |
1.61 | -1.25 | 0.85 | -1.94 | … | -0.21 | -0.18 | 0.02 |
-4.74 | -0.61 | -1.26 | 1.18 | … | 1.01 | -0.15 | 0.19 |
-7.20 | -1.19 | -0.35 | 1.86 | … | 1.44 | 0.66 | 0.82 |
データを生成したRコードは以下になります。
set.seed(123) N <- 200 K <- 30 K.important <- 4 X <- matrix(rnorm(N*K, 0, 1), N, K) bx.true <- c(seq(from=3, by=-2, length=K.important), rep(0,K-K.important)) s.true <- 1.0 Y <- X %*% bx.true + rnorm(N, 0, s.true) Y <- as.vector(Y) d <- data.frame(Y=Y, X=X)
説明変数Xは30個あって、データ数は200個です。その説明変数の各々の係数beta
のうち、最初の4つだけ真の値はノンゼロです。X
とbeta
を掛けたものにs.true=1
のノイズを加えてY
を作っています。
はじめにincrement_log_prob
を使った最尤推定をしてみます。Stanコードは以下になります。
data { int<lower=1> N; int<lower=1> K; matrix[N,K] X; vector[N] Y; } parameters { vector[K] beta; } transformed parameters { real<lower=0> squared_error; squared_error <- dot_self(Y - X * beta); } model { increment_log_prob(-squared_error); }
- 8行目:
beta
が求めるべき回帰係数のパラメータです。 - 11,12行目:
squared_error
を算出しています。 - 15行目:
increment_log_prob
におもむろにブチこむことでsquared_error
が小さいところを探しに行ってくれます。一種の最適化とみなすわけです。
キックするRコードは省略します。結果は以下になります。うーん、すでにかなりうまくいっていますね…。
mean | se_mean | sd | X2.5. | X25. | X50. | X75. | X97.5. | n_eff | Rhat | |
---|---|---|---|---|---|---|---|---|---|---|
beta[1] | 2.94 | 0.00 | 0.05 | 2.84 | 2.91 | 2.94 | 2.98 | 3.05 | 1500 | 1.00 |
beta[2] | 1.10 | 0.00 | 0.05 | 1.00 | 1.06 | 1.10 | 1.13 | 1.20 | 1500 | 1.00 |
beta[3] | -0.97 | 0.00 | 0.06 | -1.09 | -1.01 | -0.97 | -0.93 | -0.86 | 1500 | 1.00 |
beta[4] | -3.00 | 0.00 | 0.05 | -3.10 | -3.03 | -3.00 | -2.96 | -2.90 | 1500 | 1.00 |
beta[5] | -0.03 | 0.00 | 0.05 | -0.14 | -0.07 | -0.03 | 0.01 | 0.07 | 1500 | 1.00 |
beta[6] | 0.07 | 0.00 | 0.05 | -0.04 | 0.03 | 0.07 | 0.10 | 0.17 | 1500 | 1.00 |
beta[7] | 0.05 | 0.00 | 0.06 | -0.07 | 0.01 | 0.05 | 0.09 | 0.18 | 1500 | 1.00 |
beta[8] | 0.04 | 0.00 | 0.06 | -0.07 | 0.00 | 0.04 | 0.08 | 0.16 | 1500 | 1.00 |
beta[9] | -0.07 | 0.00 | 0.05 | -0.16 | -0.10 | -0.07 | -0.03 | 0.03 | 1500 | 1.00 |
beta[10] | -0.05 | 0.00 | 0.05 | -0.15 | -0.09 | -0.05 | -0.02 | 0.04 | 1500 | 1.00 |
beta[11] | 0.04 | 0.00 | 0.06 | -0.07 | 0.00 | 0.04 | 0.08 | 0.15 | 1500 | 1.00 |
beta[12] | -0.02 | 0.00 | 0.06 | -0.13 | -0.06 | -0.02 | 0.02 | 0.09 | 1500 | 1.00 |
beta[13] | -0.16 | 0.00 | 0.05 | -0.27 | -0.20 | -0.16 | -0.12 | -0.06 | 1500 | 1.00 |
beta[14] | -0.07 | 0.00 | 0.06 | -0.17 | -0.10 | -0.07 | -0.03 | 0.04 | 1500 | 1.00 |
beta[15] | -0.02 | 0.00 | 0.05 | -0.13 | -0.06 | -0.03 | 0.01 | 0.08 | 1500 | 1.00 |
beta[16] | -0.01 | 0.00 | 0.05 | -0.11 | -0.05 | -0.01 | 0.02 | 0.09 | 1500 | 1.00 |
beta[17] | -0.10 | 0.00 | 0.05 | -0.20 | -0.13 | -0.10 | -0.06 | 0.01 | 1500 | 1.00 |
beta[18] | 0.08 | 0.00 | 0.05 | -0.02 | 0.05 | 0.08 | 0.12 | 0.18 | 1500 | 1.00 |
beta[19] | -0.12 | 0.00 | 0.06 | -0.24 | -0.16 | -0.13 | -0.09 | -0.01 | 1500 | 1.00 |
beta[20] | 0.04 | 0.00 | 0.05 | -0.06 | 0.01 | 0.04 | 0.07 | 0.14 | 1500 | 1.00 |
beta[21] | 0.16 | 0.00 | 0.06 | 0.05 | 0.12 | 0.16 | 0.20 | 0.27 | 1471 | 1.00 |
beta[22] | -0.10 | 0.00 | 0.05 | -0.21 | -0.14 | -0.10 | -0.06 | 0.00 | 1500 | 1.00 |
beta[23] | 0.03 | 0.00 | 0.05 | -0.08 | -0.01 | 0.03 | 0.07 | 0.13 | 1500 | 1.00 |
beta[24] | 0.00 | 0.00 | 0.06 | -0.11 | -0.04 | 0.00 | 0.03 | 0.10 | 1500 | 1.00 |
beta[25] | -0.01 | 0.00 | 0.05 | -0.11 | -0.05 | -0.01 | 0.02 | 0.09 | 1500 | 1.00 |
beta[26] | -0.08 | 0.00 | 0.06 | -0.19 | -0.12 | -0.08 | -0.04 | 0.04 | 1500 | 1.00 |
beta[27] | -0.12 | 0.00 | 0.05 | -0.22 | -0.16 | -0.12 | -0.09 | -0.02 | 1500 | 1.00 |
beta[28] | 0.07 | 0.00 | 0.05 | -0.03 | 0.03 | 0.07 | 0.10 | 0.17 | 1500 | 1.00 |
beta[29] | -0.07 | 0.00 | 0.05 | -0.18 | -0.11 | -0.07 | -0.03 | 0.03 | 1500 | 1.00 |
beta[30] | -0.07 | 0.00 | 0.05 | -0.17 | -0.10 | -0.07 | -0.03 | 0.04 | 1500 | 1.00 |
squared_error | 207.06 | 0.19 | 3.97 | 200.53 | 204.09 | 206.69 | 209.52 | 215.79 | 428 | 1.00 |
lp__ | -207.06 | 0.19 | 3.97 | -215.79 | -209.52 | -206.69 | -204.09 | -200.53 | 428 | 1.00 |
気を取り直して次にこれをLassoで解きます。係数beta
の絶対値が大きいところにペナルティを課します。Stanコードは以下になります。
data { int<lower=1> N; int<lower=1> K; matrix[N,K] X; vector[N] Y; real<lower=0> Lambda; } parameters { vector[K] beta; } transformed parameters { real<lower=0> squared_error; squared_error <- dot_self(Y - X * beta); } model { increment_log_prob(-squared_error); for (k in 1:K) increment_log_prob(-Lambda * abs(beta[k])); }
- 6行目:
Lambda
はペナルティとsquared_error
の間の重みを変えるパラメータです。ここではデータとしてLambda=100を与えました。大きければ大きいほどノンゼロのbeta
の数が減ります。 ・18行目: ペナルティの分をincrement_log_prob
で足しこんでいます。
結果は以下になります。
mean | se_mean | sd | X2.5. | X25. | X50. | X75. | X97.5. | n_eff | Rhat | |
---|---|---|---|---|---|---|---|---|---|---|
beta[1] | 2.66 | 0.00 | 0.05 | 2.57 | 2.62 | 2.66 | 2.70 | 2.76 | 941 | 1.00 |
beta[2] | 0.81 | 0.00 | 0.05 | 0.71 | 0.77 | 0.81 | 0.84 | 0.91 | 929 | 1.00 |
beta[3] | -0.77 | 0.00 | 0.05 | -0.87 | -0.80 | -0.77 | -0.74 | -0.67 | 1131 | 1.00 |
beta[4] | -2.77 | 0.00 | 0.05 | -2.87 | -2.80 | -2.77 | -2.73 | -2.67 | 1236 | 1.00 |
beta[5] | 0.00 | 0.00 | 0.01 | -0.03 | -0.01 | 0.00 | 0.01 | 0.03 | 763 | 1.00 |
beta[6] | 0.01 | 0.00 | 0.01 | -0.02 | 0.00 | 0.00 | 0.01 | 0.04 | 921 | 1.00 |
beta[7] | 0.00 | 0.00 | 0.01 | -0.02 | 0.00 | 0.00 | 0.01 | 0.04 | 1029 | 1.00 |
beta[8] | 0.00 | 0.00 | 0.01 | -0.02 | -0.01 | 0.00 | 0.01 | 0.03 | 961 | 1.00 |
beta[9] | -0.01 | 0.00 | 0.01 | -0.04 | -0.01 | 0.00 | 0.00 | 0.02 | 916 | 1.00 |
beta[10] | 0.00 | 0.00 | 0.01 | -0.03 | -0.01 | 0.00 | 0.00 | 0.02 | 552 | 1.01 |
beta[11] | 0.00 | 0.00 | 0.01 | -0.03 | 0.00 | 0.00 | 0.01 | 0.03 | 798 | 1.00 |
beta[12] | 0.00 | 0.00 | 0.01 | -0.03 | -0.01 | 0.00 | 0.00 | 0.02 | 1018 | 1.00 |
beta[13] | -0.02 | 0.00 | 0.02 | -0.08 | -0.03 | -0.01 | 0.00 | 0.01 | 530 | 1.00 |
beta[14] | -0.01 | 0.00 | 0.01 | -0.04 | -0.01 | 0.00 | 0.00 | 0.02 | 698 | 1.00 |
beta[15] | 0.00 | 0.00 | 0.01 | -0.03 | -0.01 | 0.00 | 0.01 | 0.02 | 1035 | 1.00 |
beta[16] | 0.00 | 0.00 | 0.01 | -0.03 | -0.01 | 0.00 | 0.01 | 0.03 | 935 | 1.00 |
beta[17] | -0.01 | 0.00 | 0.01 | -0.04 | -0.01 | 0.00 | 0.00 | 0.02 | 719 | 1.00 |
beta[18] | 0.00 | 0.00 | 0.01 | -0.02 | 0.00 | 0.00 | 0.01 | 0.03 | 831 | 1.00 |
beta[19] | -0.01 | 0.00 | 0.02 | -0.07 | -0.02 | -0.01 | 0.00 | 0.01 | 418 | 1.01 |
beta[20] | 0.00 | 0.00 | 0.01 | -0.02 | -0.01 | 0.00 | 0.01 | 0.03 | 1044 | 1.00 |
beta[21] | 0.02 | 0.00 | 0.02 | -0.01 | 0.00 | 0.02 | 0.03 | 0.08 | 767 | 1.00 |
beta[22] | -0.01 | 0.00 | 0.02 | -0.06 | -0.02 | -0.01 | 0.00 | 0.01 | 750 | 1.00 |
beta[23] | 0.00 | 0.00 | 0.01 | -0.04 | -0.01 | 0.00 | 0.00 | 0.02 | 897 | 1.01 |
beta[24] | 0.00 | 0.00 | 0.01 | -0.03 | -0.01 | 0.00 | 0.01 | 0.03 | 1064 | 1.00 |
beta[25] | 0.00 | 0.00 | 0.01 | -0.03 | -0.01 | 0.00 | 0.01 | 0.03 | 998 | 1.00 |
beta[26] | 0.00 | 0.00 | 0.01 | -0.03 | -0.01 | 0.00 | 0.01 | 0.02 | 585 | 1.00 |
beta[27] | -0.01 | 0.00 | 0.01 | -0.04 | -0.01 | 0.00 | 0.00 | 0.02 | 1060 | 1.00 |
beta[28] | 0.00 | 0.00 | 0.01 | -0.02 | 0.00 | 0.00 | 0.01 | 0.04 | 882 | 1.00 |
beta[29] | -0.01 | 0.00 | 0.02 | -0.05 | -0.02 | -0.01 | 0.00 | 0.02 | 501 | 1.00 |
beta[30] | -0.01 | 0.00 | 0.01 | -0.04 | -0.01 | 0.00 | 0.00 | 0.02 | 810 | 1.00 |
squared_error | 270.81 | 0.30 | 10.02 | 252.24 | 263.65 | 270.69 | 277.38 | 291.43 | 1125 | 1.00 |
lp__ | -1001.60 | 0.28 | 5.07 | -1012.58 | -1004.81 | -1001.02 | -997.96 | -993.31 | 322 | 1.01 |
最後にオーソドックスにbetaの事前分布にLaplace分布(double exponential分布)を設定することでLassoを行う例を紹介します。色々な場面で応用がききます。Stanコードは以下になります。
data { int<lower=1> N; int<lower=1> K; matrix[N,K] X; vector[N] Y; real<lower=0> S_beta; } parameters { vector[K] beta; real<lower=0> s_y; } model { beta ~ double_exponential(0, S_beta); Y ~ normal(X * beta, s_y); }
- 6行目: double_exponentialの引数です。ここではデータとして
S_beta
=0.03を与えました。小さければ小さいほどノンゼロのbeta
の数が減ります。
結果は以下になります。
mean | se_mean | sd | X2.5. | X25. | X50. | X75. | X97.5. | n_eff | Rhat | |
---|---|---|---|---|---|---|---|---|---|---|
beta[1] | 2.69 | 0.00 | 0.09 | 2.51 | 2.63 | 2.69 | 2.75 | 2.87 | 937 | 1.00 |
beta[2] | 0.84 | 0.00 | 0.09 | 0.64 | 0.78 | 0.84 | 0.90 | 1.01 | 1021 | 1.00 |
beta[3] | -0.79 | 0.00 | 0.09 | -0.96 | -0.85 | -0.79 | -0.73 | -0.61 | 1211 | 1.00 |
beta[4] | -2.79 | 0.00 | 0.08 | -2.95 | -2.85 | -2.79 | -2.73 | -2.63 | 1168 | 1.00 |
beta[5] | 0.00 | 0.00 | 0.03 | -0.07 | -0.02 | 0.00 | 0.02 | 0.07 | 799 | 1.01 |
beta[6] | 0.01 | 0.00 | 0.04 | -0.05 | -0.01 | 0.01 | 0.03 | 0.09 | 938 | 1.00 |
beta[7] | 0.01 | 0.00 | 0.04 | -0.07 | -0.01 | 0.01 | 0.03 | 0.10 | 965 | 1.00 |
beta[8] | 0.00 | 0.00 | 0.04 | -0.07 | -0.02 | 0.00 | 0.02 | 0.09 | 1077 | 1.00 |
beta[9] | -0.02 | 0.00 | 0.04 | -0.11 | -0.04 | -0.01 | 0.00 | 0.05 | 795 | 1.00 |
beta[10] | -0.01 | 0.00 | 0.04 | -0.09 | -0.03 | -0.01 | 0.01 | 0.06 | 1034 | 1.00 |
beta[11] | 0.01 | 0.00 | 0.04 | -0.07 | -0.01 | 0.00 | 0.02 | 0.09 | 1058 | 1.00 |
beta[12] | -0.01 | 0.00 | 0.03 | -0.08 | -0.03 | 0.00 | 0.01 | 0.06 | 1067 | 1.00 |
beta[13] | -0.04 | 0.00 | 0.05 | -0.15 | -0.07 | -0.03 | 0.00 | 0.03 | 920 | 1.00 |
beta[14] | -0.02 | 0.00 | 0.04 | -0.10 | -0.03 | -0.01 | 0.01 | 0.05 | 1033 | 1.00 |
beta[15] | 0.00 | 0.00 | 0.03 | -0.08 | -0.02 | 0.00 | 0.01 | 0.07 | 1419 | 1.00 |
beta[16] | 0.00 | 0.00 | 0.03 | -0.07 | -0.02 | 0.00 | 0.02 | 0.07 | 792 | 1.00 |
beta[17] | -0.01 | 0.00 | 0.04 | -0.10 | -0.04 | -0.01 | 0.01 | 0.06 | 1062 | 1.00 |
beta[18] | 0.00 | 0.00 | 0.03 | -0.07 | -0.01 | 0.00 | 0.02 | 0.08 | 1118 | 1.00 |
beta[19] | -0.03 | 0.00 | 0.04 | -0.13 | -0.05 | -0.02 | 0.00 | 0.04 | 713 | 1.01 |
beta[20] | 0.00 | 0.00 | 0.03 | -0.06 | -0.01 | 0.00 | 0.02 | 0.08 | 1175 | 1.00 |
beta[21] | 0.05 | 0.00 | 0.05 | -0.03 | 0.01 | 0.04 | 0.08 | 0.17 | 736 | 1.00 |
beta[22] | -0.03 | 0.00 | 0.04 | -0.13 | -0.05 | -0.02 | 0.00 | 0.05 | 807 | 1.00 |
beta[23] | -0.01 | 0.00 | 0.03 | -0.08 | -0.02 | -0.01 | 0.01 | 0.05 | 901 | 1.01 |
beta[24] | 0.00 | 0.00 | 0.03 | -0.08 | -0.02 | 0.00 | 0.01 | 0.06 | 806 | 1.00 |
beta[25] | 0.00 | 0.00 | 0.03 | -0.08 | -0.02 | 0.00 | 0.01 | 0.07 | 759 | 1.00 |
beta[26] | -0.01 | 0.00 | 0.03 | -0.08 | -0.02 | 0.00 | 0.01 | 0.06 | 1161 | 1.00 |
beta[27] | -0.02 | 0.00 | 0.04 | -0.11 | -0.04 | -0.01 | 0.00 | 0.05 | 763 | 1.00 |
beta[28] | 0.01 | 0.00 | 0.04 | -0.05 | -0.01 | 0.01 | 0.03 | 0.10 | 980 | 1.00 |
beta[29] | -0.02 | 0.00 | 0.04 | -0.12 | -0.04 | -0.02 | 0.00 | 0.04 | 998 | 1.00 |
beta[30] | -0.01 | 0.00 | 0.04 | -0.09 | -0.03 | -0.01 | 0.01 | 0.06 | 1111 | 1.00 |
s_y | 1.15 | 0.00 | 0.07 | 1.03 | 1.11 | 1.15 | 1.20 | 1.31 | 905 | 1.00 |
lp__ | -390.43 | 0.25 | 5.07 | -400.78 | -393.84 | -390.21 | -386.89 | -381.46 | 425 | 1.00 |