StatModeling Memorandum

StatModeling Memorandum

StanとRとPythonでベイズ統計モデリングします. たまに書評.

生存時間分析 - ハザード関数に時間相関の制約を入れる

今回のデータは以下のような、1日ごとに得られる観察打ち切りを含む何らかのイベント発生データです(この記事の最後のRコードで作成しています)。

timecens
11
11
341
341
30
371
431
481
501
501
500

time列はイベント発生の時刻、cens列は打ち切りの場合は0, イベント発生の場合は1です。

このようなデータには生存時間分析の手法がよく使われます。しかしこの記事では生存時間分析の知識を前提とします。すみません。個人的なおすすめの本は以下です。

Web上で閲覧できる情報としては以下がおすすめです。 [1] 千葉大学の汪先生による「生存時間解析入門」 (pdf file) [2] 林さんのスライド「比例ハザードモデルはとってもtricky!」 (Web link) [3] 金明哲先生の連載 「Rと生存時間分析」(Web link1, Web link2)

カプラン・マイヤー法や比例ハザードモデルの不満点は、説明変数である「処理」の効果が時間依存しないと仮定する点、ハザード関数が時刻について滑らかという前提が入っていない点などです。統計モデリングによる拡張としては、ベースラインハザードはそのままで処理の効果を時間依存にするとか色々考えられます。今回は処理ごとにベースラインハザードの形が変わるとして、ある1つの処理の離散的なハザード関数を推定するとします。(条件付き)瞬間死亡率であるハザードが滑らかに変化しないで各時刻ででこぼこのバラバラなのは気持ち悪いと考えて、ハザードの対数に2階差分のマルコフ場モデルを仮定しました。Stanコードは以下になります。

data {
   int<lower=1> N;
   int<lower=1> T;
   int<lower=1, upper=T> Time[N];
   int<lower=0, upper=1> Cens[N];
}

parameters {
   vector<upper=0>[T] log_hazard;
   real<lower=0> s_lh;
}

transformed parameters {
   vector[T+1] log_F;
   log_F[1] <- 0;
   for (t in 2:(T+1))
      log_F[t] <- log_F[t-1] + log1m_exp(log_hazard[t-1]);
}

model {
   for (t in 3:T)
      log_hazard[t] ~ normal(2*log_hazard[t-1] - log_hazard[t-2], s_lh);
   for (n in 1:N)
      increment_log_prob(
         if_else(Cens[n]==1, log_hazard[Time[n]] + log_F[Time[n]], log_F[Time[n]+1])
      );
}

generated quantities {
   vector[T] hazard;
   hazard <- exp(log_hazard);
}
  • 9行目: ハザード関数の対数です。
  • 14行目: Fは生存率関数で、その対数です。尤度の計算がラクになるために算出しています。
  • 21~22行目: 2階差分のマルコフ場モデルの部分です。
  • 24~26行目: 各々のイベントまでの時間・打ち切りまでの時間が互いに独立と考えると、打ち切りのハザード関数は個別に最適化できるので(時間によらない)定数と考えて推定を省きます。するとこのような尤度になります。書籍参照。 なお、対数をとっていないハザード関数そのものにマルコフ場モデルを使うのはうまくいきませんでした。[0,1]の範囲制限が厳しいのだと思います。

キックするRコードは以下になります。

library(rstan)
df <- read.delim("input/data.txt", sep="\t")
N <- nrow(df)
T <- max(df$time)
Time <- df$time
Cens <- df$cens

stan_fit <- stan(file="model/model1.stan", chains=0)
data <- list(N=N, T=T, Time=Time, Cens=Cens)
fit <- stan(
   fit=stan_fit,
   pars=c('hazard', 's_lh'), data=data,
   init=function() { list(log_hazard=rep(log(0.02), T), s_lh=1) },
   chains=3, iter=3500, warmup=500, thin=3, seed=11
)
  • 13行目: 2階差分のマルコフ場モデルは初期値が重要です。とはいえざっくり平らな初期値を与えれば大丈夫でした。計算時間はSurface Pro 3で1chainあたり1分ぐらいでした。

結果は以下になります。 まずデータに対して普通に{survival}survfit関数で推定した場合の生存率曲線は以下の通り。描画は@sinhrksさんの{ggfortify}を使っています。とても使いやすいです!!

f:id:StatModeling:20201114120508p:plain

次にStanの推定結果。 以下は生存率曲線の図です。

f:id:StatModeling:20201114120504p:plain

赤線はMCMCサンプルの中央値、薄いオレンジ帯は同じく95%信用区間、濃いオレンジ帯は同じく50%信用区間です。青の線はカプラン・マイヤー法によるハザードの推定値で、黒の+は打ち切りを表します。

以下はハザード関数の図です。

f:id:StatModeling:20201114120512p:plain

赤線はMCMCサンプルの中央値、薄いオレンジ帯は同じく95%信用区間、濃いオレンジ帯は同じく50%信用区間です。青の点と線はカプラン・マイヤー法によるハザードの推定値、黒の点線はこのデモデータを生成した時の真のハザード関数です。生存者が少なくなってくる後半は推定されたハザード関数と差が出やすいです。これは生存時間分析の難しいところで、1回の乱数から生成されたデータに引きずられるためです。久保さんが実験していた乱歩の時系列データと同じような状況になっているのだと思います。

最後に今回のデータを作成したRコードです。真のハザードはコード中のYに相当します。

set.seed(123)
X <- seq(from=-0.6, to=4.3, by=0.1)
Y <- (X^3 - 4.5*X^2 + X + 20)*0.003 - ifelse(X < 3, 0, 0.01*(X-3)^3)
T <- length(X)
N <- 200

n.risk <- c(N)
for (t in 1:T){
   n.alive <- rbinom(n=1, size=tail(n.risk,1), prob=1-Y[t])
   n.risk <- c(n.risk, n.alive)
}
n.die <- -diff(n.risk)
time.die <- as.vector(unlist(mapply(rep, 1:T, n.die)))
N.die <- sum(n.die)
df.die <- data.frame(time.die=c(time.die, rep(T, N-N.die)), cens.die=c(rep(1, N.die), rep(0, N-N.die)))

Cens.hazard <- 0.01
n.risk.c <- c(N)
for (t in 1:T){
   n.alive.c <- rbinom(n=1, size=tail(n.risk.c,1), prob=1-Cens.hazard)
   n.risk.c <- c(n.risk.c, n.alive.c)
}
n.cens <- -diff(n.risk.c)
time.cens <- as.vector(unlist(mapply(rep, 1:T, n.cens)))
N.cens <- sum(n.cens)
df.cens <- data.frame(time.cens=c(time.cens, rep(T, N-N.cens)), cens.cens=c(rep(1, N.cens), rep(0, N-N.cens)))

df <- cbind(df.die, df.cens[sample.int(N), ])
df <- transform(df,
   time = ifelse(time.cens < time.die, time.cens, time.die),
   cens = ifelse(time.cens < time.die | (cens.die == 0 & cens.cens == 0), 0, 1)
)

write.table(df, file='input/data.txt', sep='\t', quote=FALSE, row.names=FALSE)