StatModeling Memorandum

StatModeling Memorandum

StanとRとPythonでベイズ統計モデリングします. たまに書評.

PythonのSymPyで変分ベイズの例題を理解する

この記事の続きです。

ここではPRMLの10.1.3項の一変数ガウス分布の例題(WikipediaVariational_Bayesian_methodsのA basic exampleと同じ)をSymPyで解きます。すなわちデータが

  Y_n \sim Normal(\mu, \tau^{-1})   n = 1,..,N

に従い*1 \mu \tauが、

  \mu \sim Normal(\mu_0, (\lambda_0 \tau)^{-1})

  \tau \sim Gamma(a_0, b_0)

に従うという状況です。ここでデータ Y_n n=1,...,N)が得られたとして事後分布 p(\mu, \tau | \boldsymbol{Y})を変分ベイズで求めます。

まずはじめに、上記の確率モデルから同時分布 p(\boldsymbol{Y}, \mu, \tau)を書き下しておきます。

  p(\boldsymbol{Y}, \mu, \tau) = p(\boldsymbol{Y} | \mu, \tau) p(\mu | \tau) p(\tau)

なので、

  p(\boldsymbol{Y}, \mu, \tau) = \prod_{n=1}^N Normal(Y_n | \mu, \tau^{-1}) \cdot Normal(\mu | \mu_0, (\lambda_0 \tau)^{-1}) \cdot Gamma(\tau | a_0, b_0)

となります。

この問題は単純なので事後分布は厳密に求まるのですが、ここでは変分ベイズで解きます。すなわち、事後分布 p(\mu, \tau | \boldsymbol{Y}) q(\mu, \tau)で近似します。さらに q(\mu, \tau) = q(\mu) q(\tau)と因子分解可能と仮定します。そして、前の記事の最後の2つの式を使って、 q(\mu) q(\tau)が収束するまで繰り返し交互に更新して求めるのでした。以下ではこれをSymPyでやります。

from sympy import *
from sympy.stats import *
init_printing(use_unicode=True)

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'
  • 3行目: 僕は基本的にはJupyter Notebookで実行しています。この行を追加することで、数式がMathJaxで綺麗に表示されます。
  • 5~6行目: セルの途中で出力しても数式が綺麗に表示されるようにしています。こちらの記事を参考にしました。
y, mu, mu0 = symbols('y mu mu0', real=True)
Y_vec = symbols('Y1:4', real=True)
tau, lambda0, a0, b0 = symbols('tau lambda0 a0 b0', positive=True)
  • 1行目: SymPyで使う変数はsymbols関数で作成しておく必要があります。real=Trueと指定することで、実数と仮定することができます。何も指定しなければ複素数になります。このように仮定を入れておかないと、のちの式変形や積分がうまくいかない場合があります。
  • 2行目: このように変数のリストを作成することもできます。
  • 3行目: positive=Trueと指定することで、正の実数だと仮定することができます。

なお、SymPyでは要素数やデータ数をNとするような一般の場合の式変形は基本的に難しいです。しかし具体的な値に決めれば実行できます。そこで、ここでは2行目でデータ数をY1,Y2,Y33個として先に進めます。あとで値を色々変えて試すとNの場合の見当がつくので、そこから一般の場合を証明することもできます。

p_y = density(Normal('', mu, 1/sqrt(tau)))(y)
p_mu = density(Normal('', mu0, 1/sqrt(lambda0*tau)))(mu)
p_tau = density(Gamma('', a0, 1/b0))(tau)

sympy.statsには確率分布の密度関数の式がありますので、それを使っています。ここではデータ1つあたりのyの分布とmutauの事前分布を定義しています。

SymPyの正規分布Normal(平均, 標準偏差)なので、精度であるtau1/sqrt(tau)を代入しています。また、PRMLWikipediaのガンマ分布はGamma(shape, rate)である一方*2、SymPyのガンマ分布はGamma(shape, scale)なので、1/b0を代入しています。

integrate(p_mu, (mu, -oo, oo))
simplify(integrate(p_tau, (tau, 0, oo)))

試しにmuの分布を -\inftyから \inftyまで積分してみましょう。期待通り1が返ります。tauの分布でも同様に積分すると1にならずに整理されていない式が返ってきますが、simplify関数で整理すると1になります。

同時分布の対数( log\ p)の準備

前の記事の最後の2つの式でやっていることを日本語で書くと以下です。

  • 同時分布の対数 log\ p(\boldsymbol{Y}, \mu, \tau) q(\tau)を掛けて \tau積分して、 \muの分布 q(\mu)を求める。
  • 同時分布の対数 log\ p(\boldsymbol{Y}, \mu, \tau) q(\mu)を掛けて \mu積分して、 \tauの分布 q(\tau)を求める。

そこでまず同時分布の対数を準備します。

log_p = sum([log(p_y.subs(y, x)) for x in Y_vec]) + log(p_mu) + log(p_tau)
log_p = simplify(log_p)
log_p

  - \frac{Y_1^{2} \tau}{2} + Y_1 \mu \tau - \frac{Y_2^{2} \tau}{2} + Y_2 \mu \tau - \frac{Y_3^{2} \tau}{2} + Y_3 \mu \tau + a_0 \log{\left (b_0 \right )} + a_0 \log{\left (\tau \right )} - b_0 \tau \\ - \frac{\lambda_0 \tau}{2} \mu^{2} + \lambda_0 \mu \mu_0 \tau - \frac{\lambda_0 \tau}{2} \mu_0^{2} - \frac{3 \tau}{2} \mu^{2} + \log{\left (a_0 \right )} + \frac{1}{2} \log{\left (\lambda_0 \right )} + \log{\left (\tau \right )} - \log{\left (\Gamma{\left(a_0 + 1 \right)} \right )} \\ - 2 \log{\left (\pi \right )} - 2 \log{\left (2 \right )}

  • 1行目: expr.subs(y, x)は式expryxを代入します。

次に積分にすすみます。

 \muを含まない項に \muを含まない分布 q(\tau)を掛けて \tau積分したところで、やはり \muに関係がない定数になります。定数は最後に規格化して求めればよいので、途中の計算はなるべく簡単になるように余計な項を取り除きます。これがSymPyで計算をうまくさせるポイントになります。

log_p_for_mu = integrate(diff(log_p, mu), mu)
log_p_for_mu = collect(log_p_for_mu, mu)
log_p_for_mu

  \mu^{2} \left(- \frac{\lambda_0 \tau}{2} - \frac{3 \tau}{2}\right) + \mu \left(Y_1 \tau + Y_2 \tau + Y_3 \tau + \lambda_0 \mu_0 \tau\right)

  • 1行目: log_pmu微分してmu積分することで、muを含まない項を取り除いています。
  • 2行目: collect関数はmuの関数として式をみたときに共通部分をくくります。
log_p_for_tau = integrate(diff(log_p, tau), tau)
log_p_for_tau = collect(log_p_for_tau, tau)
log_p_for_tau

  \tau \left(- \frac{Y_1^{2}}{2} + Y_1 \mu - \frac{Y_2^{2}}{2} + Y_2 \mu - \frac{Y_3^{2}}{2} + Y_3 \mu - b_0 - \frac{\lambda_0 \mu^{2}}{2} + \lambda_0 \mu \mu_0 - \frac{\lambda_0 \mu_0^{2}}{2} - \frac{3 \mu^{2}}{2}\right) + \left(a_0 + 1\right) \log{\left (\tau \right )}

 \mu積分して q(\tau)を求める方も同様なのでそうしておきます。

できるところまで解析的に求める

SymPyの練習のため、事前分布から積分を1回実行して q(\mu) q(\tau)を求めるところをやってみます。

log_q1_mu = integrate(log_p_for_mu * p_tau, (tau, 0, oo))
log_q1_mu
log_q1_mu = simplify(log_q1_mu)
log_q1_mu
log_q1_mu = collect(expand(log_q1_mu), mu)
log_q1_mu

  \mu^{2} \left(- \frac{a_0 \lambda_0}{2 b_0} - \frac{3 a_0}{2 b_0}\right) + \mu \left(\frac{Y_1 a_0}{b_0} + \frac{Y_2 a_0}{b_0} + \frac{Y_3 a_0}{b_0} + \frac{a_0 \mu_0}{b_0} \lambda_0\right)

  • 5行目: 3行目でsimplifyしていますが、解析者が意図しない形になることはよくあります。ここでは、expandしてcollectすることでmu多項式にしています。

log_q1_muの式は \muの二次関数のマイナスなので、このすぐあとのq1_mu正規分布になることが分かります。共役事前分布を使っているからです。規格化定数をもとめて規格化しましょう。

q1_mu = exp(log_q1_mu)
const = simplify(integrate(q1_mu, (mu, -oo, oo)))
const
q1_mu = 1/const * exp(log_q1_mu)
q1_mu

constが規格化定数になります。以下の部分です。

  \frac{\sqrt{2} \sqrt{\pi} \sqrt{b_0}}{\sqrt{a_0} \sqrt{\lambda_0 + 3}} e^{\frac{a_0 \left(Y_1 + Y_2 + Y_3 + \lambda_0 \mu_0\right)^{2}}{2 b_0 \left(\lambda_0 + 3\right)}}

q1_muは規格化された分布の q(\mu)です。以下になります。

  \frac{\sqrt{2} \sqrt{a_0} \sqrt{\lambda_0 + 3}}{2 \sqrt{\pi} \sqrt{b_0}} e^{- \frac{a_0 \left(Y_1 + Y_2 + Y_3 + \lambda_0 \mu_0\right)^{2}}{2 b_0 \left(\lambda_0 + 3\right)}} e^{\mu^{2} \left(- \frac{a_0 \lambda_0}{2 b_0} - \frac{3 a_0}{2 b_0}\right) + \mu \left(\frac{Y_1 a_0}{b_0} + \frac{Y_2 a_0}{b_0} + \frac{Y_3 a_0}{b_0} + \frac{a_0 \mu_0}{b_0} \lambda_0\right)}

同じようにq1_tauを求めます。変分ベイズの手順としては、上で求めたばかりの q(\mu)を掛けて \mu積分します。しかしSymPyではその計算は重くて実行できないので、ここでは \muの事前分布p_muを使ってq1_tauを求めてみます。

log_q1_tau = integrate(log_p_for_tau * p_mu, (mu, -oo, oo))
log_q1_tau
log_q1_tau = integrate(diff(log_q1_tau, tau), tau)
log_q1_tau
log_q1_tau = collect(log_q1_tau, tau)
log_q1_tau
  • 3行目: あとで規格化定数を求めればよいので定数項は取り除いておきます。

このすぐあとのq1_tauはガンマ分布になることが分かります。これも共役事前分布を使っているからです。

q1_tau = logcombine(exp(log_q1_tau))
q1_tau
# const = integrate(q1_tau, (tau, 0, oo))
# const
# q1_tau = 1/const * q1_tau
# q1_tau
  • 1行目: logcombine関数を使うことで exp(log(x)) xにします。simplify関数だとこの変形をやってくれないことがあります。
  • 3行目: これで素直に積分できればよいのですが、残念ながらできません。

q1_tauは以下です。

  \tau^{a_0 + 1} e^{\tau \left(- \frac{Y_1^{2}}{2} + Y_1 \mu_0 - \frac{Y_2^{2}}{2} + Y_2 \mu_0 - \frac{Y_3^{2}}{2} + Y_3 \mu_0 - b_0 - \frac{3 \mu_0^{2}}{2}\right)}

この expの肩にのっている \tauの係数が負だとSymPyが分からないから積分できないようです。ちなみにこのあたりはMathematicaの方が圧倒的に賢くて、例えば以下の入力できちんと積分できます。

Integrate[tau^(a+1)*Exp[tau * (-1/2*x^2 + x*mu - 1/2* y^2 + y*mu - b - mu^2)], {tau, 0, Infinity}, Assumptions -> {b > 0, a > 0, Element[x, Reals], Element[y, Reals], Element[mu, Reals]} ]

これをうまく積分させるには、 \tauの係数が負であることを確認してから変数で置き換えて実行します。

まず \tauの係数が負であることを確認します。

coef = collect(log_q1_tau, tau).coeff(tau)
coef
sol = solve(diff(coef, Y_vec[0]), Y_vec[0])[0]
sol #=> mu0
replacements = [(var, sol) for var in Y_vec]
coef.subs(replacements) #=> -b0
  • 1行目:  \tauの係数coefを取得しています。
  • 3行目: coefの最大値が負であることを示せばOKです。まずはY_vec[0]についてcoefが最大になる値を探します。それは微分して0(&2階微分が負)になる点を求めればOKです。Y_vec[0]と他のY_vec[*]は区別がある形ではないので、Y_vec[*]についても同じ点でcoefが最大となります。
  • 5~6行目: それをまとめて代入しています。最大値は-b0と分かるので、 \tauの係数は負であることがわかります。

次に変数で置き換えて積分します。

xi = symbols('xi', positive=True)
const = simplify(integrate(tau**(a0+1)*exp(-xi*tau), (tau, 0, oo)))
const = const.subs(xi, -coef)
const
q1_tau = 1/const * q1_tau
q1_tau

constが規格化定数になります。以下の部分です。

  \left(\frac{Y_1^{2}}{2} - Y_1 \mu_0 + \frac{Y_2^{2}}{2} - Y_2 \mu_0 + \frac{Y_3^{2}}{2} - Y_3 \mu_0 + b_0 + \frac{3 \mu_0^{2}}{2}\right)^{- a_0 - 2} \Gamma{\left(a_0 + 2 \right)}

q1_tauは正規化された分布の q(\tau)です。以下になります。

  \frac{\tau^{a_0 + 1}}{\Gamma{\left(a_0 + 2 \right)}} \left(\frac{Y_1^{2}}{2} - Y_1 \mu_0 + \frac{Y_2^{2}}{2} - Y_2 \mu_0 + \frac{Y_3^{2}}{2} - Y_3 \mu_0 + b_0 + \frac{3 \mu_0^{2}}{2}\right)^{a_0 + 2} e^{\tau \left(- \frac{Y_1^{2}}{2} + Y_1 \mu_0 - \frac{Y_2^{2}}{2} + Y_2 \mu_0 - \frac{Y_3^{2}}{2} + Y_3 \mu_0 - b_0 - \frac{3 \mu_0^{2}}{2}\right)}

このように解析解を求めることはコンセプトの理解に役立ちます。一方で、積分を繰り返して事後分布 q(\mu, \tau)が収束するか確認するようなことは数値的に求めた方が分かりやすいです。

数値的に求める

仮に得られたデータY_vec1.1,1.0,1.3とします。また、事前分布はa0 = 1, b0 = 1, mu0 = 0, lambda0 = 1とします。

replacements = [(a0, 1), (b0, 1), (mu0, 0), (lambda0, 1)]
data_vec = [1.1, 1.0, 1.3]
replacements.extend([(var, val) for var, val in zip(Y_vec, data_vec)])
log_p_for_mu_subs = log_p_for_mu.subs(replacements)
log_p_for_tau_subs = log_p_for_tau.subs(replacements)
[log_p_for_mu_subs, log_p_for_tau_subs]

  \left [ - 2 \mu^{2} \tau + 3.4 \mu \tau, \quad \tau \left(- 2 \mu^{2} + 3.4 \mu - 2.95\right) + 2 \log{\left (\tau \right )}\right ]

  • 1行目: 事前分布の分の代入を作っています。
  • 2~3行目: データの分の代入を追加しています。

 \tauの初期分布をp_tauとして、 q(\mu)を求める→ q(\tau)を求める→ q(\mu)を求める→...と7回ほど繰り返してみます。

q_tau = N(p_tau.subs(replacements))
q_tau

for i in range(7):
    log_q_mu = N(integrate(log_p_for_mu_subs * q_tau, (tau, 0, oo)))
    const = N(integrate(exp(log_q_mu), (mu, -oo, oo)))
    q_mu = 1/const * exp(log_q_mu)

    log_q_tau = N(integrate(log_p_for_tau_subs * q_mu, (mu, -oo, oo)))
    const = N(integrate(exp(log_q_tau), (tau, 0, oo)))
    q_tau = 1/const * exp(log_q_tau)

    [q_mu, q_tau]

  \left [ 0.188098154753774 e^{- 2.0 \mu^{2} + 3.4 \mu}, \quad 4.03007506250001 \tau^{2.0} e^{- 2.005 \tau}\right ]   \left [ 0.112320150163227 e^{- 2.99251870324189 \mu^{2} + 5.08728179551122 \mu}, \quad 3.11052191637731 \tau^{2.0} e^{- 1.83916666666667 \tau}\right ]   \left [ 0.0965024138432034 e^{- 3.26234707748074 \mu^{2} + 5.54599003171727 \mu}, \quad 2.97238456804457 \tau^{2.0} e^{- 1.81152777777778 \tau}\right ]   \left [ 0.0938011432750369 e^{- 3.31212144445296 \mu^{2} + 5.63060645557004 \mu}, \quad 2.94976700750461 \tau^{2.0} e^{- 1.8069212962963 \tau}\right ]   \left [ 0.0933494031271016 e^{- 3.32056521349236 \mu^{2} + 5.64496086293701 \mu}, \quad 2.94600860516469 \tau^{2.0} e^{- 1.80615354938272 \tau}\right ]   \left [ 0.0932740721319143 e^{- 3.32197669575248 \mu^{2} + 5.64736038277921 \mu}, \quad 2.94538251532282 \tau^{2.0} e^{- 1.80602559156379 \tau}\right ]   \left [ 0.0932615158350226 e^{- 3.32221205946743 \mu^{2} + 5.64776050109463 \mu}, \quad 2.94527817564072 \tau^{2.0} e^{- 1.80600426526063 \tau}\right ]

  • 1行目: N関数は数値による近似を求める関数です。

7回ほどの繰り返しのあとでほぼ収束していそうなことがわかります。

最後に求めた事後分布(の近似) q(\mu, \tau) = q(\mu) q(\tau)を可視化してみましょう。SymPyにもsympy.plottingsympy.plotting.plotが存在するのですが、ちょっと凝った図を書こうとするとすぐ厳しくなってしまいます。そこで、得られた事後分布をlambdify関数で関数化し、NumPyとMatplotlibで描くのが拡張性が高くてオススメです。

from sympy.utilities.lambdify import lambdify
import numpy as np
import matplotlib.pyplot as plt

delta = 0.05
x = np.arange(-1.0, 3.0, delta)
y = np.arange(0.0, 6.0, delta)
X, Y = np.meshgrid(x, y)
func = lambdify((mu, tau), q_mu * q_tau, 'numpy')
Z = func(X, Y)

plt.figure()
CS = plt.contour(X, Y, Z)
plt.clabel(CS, inline=1, fontsize=10)

まとめ

  • SymPyはデータサイエンスや機械学習の書籍や論文を読み進める上で、非常に有用な補助ツールです。
  • 現状では細かいところでMathematicaにまだ負けていると思います。プロにはMathematicaがオススメ。オープンソース重視の人やPython好きな人にはSymPyがオススメ。
  • 式変形には「一般的な場合のようにコンセプトが重要で深く理解しなければならない式変形」と「SymPyなどの数式処理ソフトで追えれば十分であるような式変形」があると個人的に思っています。専門書や技術書を執筆する場合は、その二つを区別すると読者にとって親切かなぁと思いました。

Enjoy!

謝辞

北大電子研の佐藤勝彦氏に感謝します。僕が院生の頃に輪読していた ニコリス プリゴジーヌ『散逸構造』の例題をMathematicaで10分ぐらいで一般解を求めるという衝撃のデモを見せてもらい、その後もたまにMathematicaを教えてもらい、数式処理を学ぶきっかけをもらいました。

*1:いつもはStanとの相性を考えて Normal(平均, 標準偏差)で書いてますが、この記事では Normal(平均, 分散)で書きます。

*2:Stanもね。

統計モデリングで癌の5年生存率データから良い病院を探す

概要

2017年8月9日に国立がん研究センターは、がん治療拠点の約半数にあたる全国188の病院について、癌患者の5年後の生存率データを初めて公表しました(毎日新聞の記事)。報告書は国立がん研究センターが運営するウェブサイトからダウンロードできます(ここ)。報告書をダウンロードしようとすると注意点を記したポップアップが表示されます。大切な部分を抜粋すると以下です。

本報告書には、施設別の生存率を表示していますが、進行がんの多い少ない、高齢者の多い少ないなど、施設毎に治療している患者さんの構成が異なります。そのため、単純に生存率を比較して、その施設の治療成績の良し悪しを論ずることはできません。

一般に高齢者が多い病院ほど、進行癌(ステージが進んだ癌)が多い病院ほど、その病院の生存率は下がるわけです。それならば、統計モデリングで年齢と進行度(ステージ)の影響を取り除いて(専門的な言葉で言えば「調整して」)病院の良し悪しを論じてみようというのがこの記事の内容になります。しかし、あくまでも「手法」の節で書いた仮定のもとでの推定結果であり、真実として断定するものではありませんのでその点はご理解ください。

なお、影響を取り除く前の(カプランマイヤー法で算出した)実測生存率のヒストグラムは以下になります。病院によってかなりばらつきがありそうに見えます。何もしないで比べると「国立研究開発法人国立がん研究センター中央病院」や「がん研有明病院」など大きな病院の生存率が高いです。患者の平均年齢と平均的な進行度が低いためと思われます。

結果

それでは結果から先に述べます。

以下では病院ごとの生存率や手術率を比較するために、癌種t・病院hにおける男性比率を0.5・平均年齢を60・平均進行度を2.5(おおよそステージIIに相当)に仮に固定して議論をすすめます。

=== 2018.2.6 追記 ===

この部分が分かりにくかったようなので補足します。今知りたいことは「あなたが60歳でステージIIだったらどこへ行くのが5年生存率が高いか」です。しかし、病院によっては若い人が多かったり、重症な人が多かったりして、そのままの生存率を単純に比較できません。そこで、統計モデリングによってすべての病院の平均年齢と平均進行度を仮想的に揃えることで、病院にだけ依存する生存率の差が残るわけです。ここでは一例として平均年齢を60・平均進行度を2.5に固定して算出しています。とにかく揃えればよいので、仮に平均年齢を50・平均進行度を1.5に固定してもよいです(この場合、生存率は全体的に高い方へ少しずれますがランキングは変わりません)。

癌種ごとの生存率のヒストグラム

生存率の推定値(MCMCサンプルの中央値)のヒストグラムは以下になります。

驚くべきことに胃・大腸・肺においてはほとんど病院の差がありません。がん診療連携拠点病院、さすがです。肝癌と乳癌は少し病院によって差がありますので、それぞれbest10を紹介します。ただし、病院による差は提供されたデータ以外の要因(他の病気をもつ患者が多い少ないなど)も含まれます。検討によってそういう情報が重要と分かれば、今後データとして収集して公表する必要があると思います。

 肝癌における生存率best10の病院

黒点は(MCMCサンプルの)中央値、横に伸びる線は80%ベイズ信頼区間です。中央値の高い順に並び替えています。ベイズ信頼区間を見ると分かる通り、病院による差はそこまで明らかではありません。

大垣市民病院」があまりに良いので病院のコメントを読むと「4.肝(C22)は切除症例に限る。」と書いてありました。これだと切除していない患者さんとその死亡数が考慮されていないので、フェアではなく正しいランクとは言えません。ただし、他の癌種における結果を見ると、全てを考慮してもかなり良い生存率である可能性はあります。

仮に「大垣市民病院」を除くと「信州大学医学部附属病院」が1位となります。

 乳癌における生存率best10の病院

グラフの見方は肝癌の場合と同じです。

1位は「愛知県がんセンター中央病院」です。圧倒的な手術率もさることながら、手術以外の影響(病院による効果)も1位です。謎の民間療法に頼らず、こういうところに入院したいものです。

生存率への年齢・進行度・手術率の影響(オッズ比)

点は中央値、横に伸びる線は80%ベイズ信頼区間です。例えば、胃癌の年齢_10のOdds ratioの中央値は約0.5ですが、これは「ある病院における患者の平均年齢が10上がると『生存率/死亡率』が約0.5倍になる」ことを意味します。性別_0.1は男性比率が0.1上がることを意味します。他も同様です。

結果を見ると、予想通り性別の影響はあまりなくて年齢や進行度の影響が強いです。特に胃癌においては年齢の影響が強く、乳癌においては進行度の影響が強いことが示唆されています。

手術率への年齢・進行度の影響(オッズ比)

グラフの見方は上の場合と同様です。肝癌において進行度の影響が低そうなのが意外に思いました。このあたり、病院によって生存率がばらついていることの原因があるのかもしれません。

その他

ここでは記しませんが、他にも色々知ることができました。例を挙げます。

  • 手術の好きな病院・嫌いな病院
  • 手術率が比較的低いわりに生存率がまあまあ高い病院
  • 生存率worst10の病院

手法

データの抽出

このような報告書を公表するのは英断であり、 “がん診療連携拠点病院ががん患者さんの治療に透明性を確保し、拠点病院全体として責任をもって取り組んでいる意気込み” (報告書からの抜粋)をまさに感じることができました(なかには難癖つけて公表を拒否している病院があり最後にまとめました)。一点、ちょっと残念なのは報告書がpdf(580ページほど)であり、利活用が大変しにくいことでした。今回は素敵なRによるスクレイピング入門を出している株式会社ホクソエムの市川さんと牧山代表取締役に担当してもらい、Rのtabulizerパッケージを使って抽出しました。

データの難所

この報告書には188の病院が含まれていて、各病院について5つの癌(胃、大腸、肝、肺、女性乳房)のデータが含まれています。1つの病院の1つの癌のデータ例を以下に示します(数値は架空のものです)。

個人情報保護の観点から、「1人以上10人以下」の場合に-(ハイフン)に置き換えられています。これがこのデータの解析の難易度を大幅に引き上げています。

ハイフンへの対策は以下のようにしました。

  • 合計から数値が計算できる場合は計算して置き換える。
    • 例: 「観血的治療の実施」(=外科的手術の実施)の有と無の合計が「対象数」(=患者数)となるため、上記の例では「観血的治療の実施_無」は100-91=9と算出できる。
    • 例: 「(100.0 - 生存状況把握割合(%))*対象数」を四捨五入して「打ち切り数」が算出できる。
  • 場合の数を列挙し、統計モデリングに組み込む。

生存率はどの値を使うか

打ち切りを考慮してカプランマイヤー法で算出した生存率を使う方法がまず考えられますが、対象数の大小に由来する生存率の推定幅をうまく組み込むのが容易ではありません。また、がん診療連携拠点病院の意気込みが凄くて、現時点で生存状況把握割合は95%以上が多く、今後は100%に近づきそうです。これらを踏まえて、対象数と5年後生存数を用いた二項ロジスティック回帰にしました。*1

対象数が少ない場合への対応

元の報告書では、対象数が50人未満の場合、定された生存率の信頼性が低くなるため公表しないとしてハイフンになっています。本解析においても対象数が少ない場合、年齢や進行度のデータがほぼすべてハイフンとなって、列挙する場合の数が非常に多くなります。そこに推定の時間を取られるのは本意ではないため、同様に50人未満の場合は解析対象から除外しました。

統計モデル

使用したモデルは以下になります。

 

 

 

 

 

 

 

 

 

 

 

 

 

 

青色の項は女性乳房以外の癌において存在し、赤色の項は肺癌のみに存在します。inv_logitはロジスティック関数です。ここでデータは以下です。

  •  t: 癌種のインデックス(tissue)。 Tが癌種の数で、ここでは T=5
  •  h: 病院のインデックス(hospital)。 Hが病院の数で、ここでは報告書にならって患者数が50以上のみを解析対象としたため、 H=183
  •  Surv: 生存数
  •  N: 対象数(患者数)
  •  Ope: 観血的治療の実施有の数(外科的手術をした人数)
  •  Age: 各年代に属する人数
  •  Stage: 各進行度に属する人数
  •  Male: 性別が男性の人数
  •  SC: 肺癌において、部位が小細胞である人数
  •  Cutoff_{age}: 平均 \mu_{age}正規分布を0.5, 0.6, 0.7, 0.8で切って5つの領域の面積をそれぞれ p_{age}に割り当てます。その閾値です。
  •  Cutoff_{stage}: 平均 \mu_{stage}正規分布を0.2, 0.3, 0.4で切って4つの領域の面積をそれぞれ p_{stage}に割り当てます。その閾値です。

ちなみに、発見経緯のデータは進行度に反映されると仮定して今回は使用していません。

推定するパラメータは以下です。推定にはStanとRを用いました。

  •  p_{surv}: 生存率
  •  p_{ope}: 手術率
  •  q_{male}: 男性比率
  •  q_{SC}: 肺癌における小細胞率
  •  \mu_{age}: 癌 t, 病院 hにおける「平均年齢」。*2
  •  \mu_{stage}: 癌 t, 病院 hにおける「平均進行度」。*3
  •  p_{age}: 各年代の比率
  •  p_{stage}: 各進行度の比率
  •  r_{surv}: 生存率に対する説明変数以外の影響。病院差。
  •  r_{ope}: 手術率に対する説明変数以外の影響。病院差。
  •  b,  a: 回帰係数
  •  \sigma_{r_{surv}}: 生存率の病院差の標準偏差。データから推定します(階層モデル)。
  •  \sigma_{r_{ope}}: 手術率の病院差の標準偏差。データから推定します(階層モデル)。

年齢と癌の進行度の影響を取り除くため、これらを説明変数とした二項ロジスティック回帰にするのがポイントです。さらに手術率も年齢や進行度の影響をうけると考えられるので上記のモデルとなっています。 r_{surv} r_{ope}がそれぞれ別の多変量正規分布から生成されるモデルも試してみましたが、現段階では少しデータが足りないようで、MCMCが収束しませんでした。また探索的な解析からは、年齢や進行度によって生存率や手術率が指数関数的に落ちるのではなく線形に近いことが示唆されたため、今回は非線形な関数を含まないモデルとしました。実際のコードは上記の数式に加えてハイフンの処理が入るのでかなり複雑になります。今回はコードの解説を省略します。

最後にモデルの事後予測チェックの図を以下に載せておきます。生存数の実測値と予測値(点は中央値、縦に伸びる線は80%ベイズ信頼区間)は以下です。

生存率のカプランマイヤー法による実測値と予測値(点は中央値、縦に伸びる線は80%ベイズ信頼区間)は以下です。肝癌は少しあてはまりが悪いですが、全体的に問題なく推定できていると思いました。

手術件数の実測値と予測値(点は中央値、縦に伸びる線は80%ベイズ信頼区間)は以下です。

その他にも癌 t, 病院 hにおける「平均年齢」と「平均進行度」の推定などがうまくいっているかもチェック済みです。

後記

謝辞

がん診療連携拠点病院国立がん研究センターがん対策情報センターがん登録センター院内がん登録室に感謝します。相談に乗ってくれた@Med_KUさんと@happyningenさんに感謝します。

展望

全ての癌種の最新かつ時系列のデータで再解析がしたいです。もし患者レベルのデータが利用できれば、さらに良い解析ができます。いつか解析できることを楽しみにしています。

難癖つけて公表しなかった残念な病院一覧

治療のレベルはお察し。毎回更新したいです。

*1:打ち切りがある場合は5年後生存数を正確に得ることはできませんが、生存状況把握割合が非常に高いので、打ち切りに含まれる生存数を「打ち切り数*既知の生存数/対象数」で近似的に算出しました。

*2:本来は患者の年齢と生存か否かが結び付けれればさらに良い解析ができますが、現段階の情報ではこうやって多少強引に平均年齢を算出する他なさそうです。年代の人数は割と山型の分布なので、これで良さそうです。

*3:年代と同様です。しかし、癌種によっては「ステージIとIVは多いけどIIとIIIは少ない」ということがあり、進行度を要約するために平均+正規分布を使うのではなく、他の指標と分布を使った方がよいかもしれません。検討の余地があります。

逆温度1の事後分布のサンプルからWBICを計算する

この記事は以下のツイートを拝見してやってみようと思いました。

ツイートで言及されている渡辺先生の論文は以下です。

  • S Watanabe (2013) "A widely applicable Bayesian information criterion" Journal of Machine Learning Research 14 (Mar), 867-897 (pdf file)

この記事では、以前WAICとLOOCVの比較をした時に使った3つのモデル(重回帰、ロジスティック回帰、非線形回帰)において、「定義通りに算出したオリジナルのWBIC」と「近似式(上記論文の(20)式)で求めたWBIC」を比較してみました。

手法

case 1 重回帰

真のモデルは以下です。

  Y \sim Normal(1.3 - 3.1 X_1 + 0.7 X_2, 2.5)

あてはめたモデルは以下です。

  Y \sim Normal(b_1 + b_2 X_1 + b_3 X_2, \sigma)

データ点の数Nについては20,100を試しました。例としてN = 20の場合を説明します。まず乱数でデータX(すなわち X_1, X_2)を生成します。次にそのXの値を使ってYを生成しますが、以下の二つの場合について計算しました。

  • 1) MCMCサンプルの出方の違いによる影響: Yを1つだけ生成して固定し、MCMCの乱数の種を変えて200回推定を行い、 WBIC_{original} WBIC_{approx}のそれぞれの分布を確認した
  • 2) データの出方の違いによる影響: Yを1000通り生成し、 WBIC_{approx} - WBIC_{original}および相対的な差である (WBIC_{approx} - WBIC_{original})/WBIC_{original}の分布を確認した

事後分布の推定はStanで行いました。iter=11000, warmup=1000, chains=4で実行して合計40000個のMCMCサンプルを得ています。

case 2 ロジスティック回帰

手順は重回帰の場合と同じです。使用したモデルだけが異なります。真のモデルは以下です。

  Y \sim Bernoulli(inv\_logit(0.8 - 1.1 X_1 + 0.1 X_2))

あてはめたモデルは以下です。

  Y \sim Bernoulli(inv\_logit(b_1 + b_2 X_1 + b_3 X_2))

  b_1,b_2,b_3 \sim Student\_t(4,0,1)

case 3 非線形回帰 ミカエリス・メンテン型

手順は重回帰の場合と同じです。使用したモデルだけが異なります。真のモデルは以下です。

  Y \sim Normal(10.0 X / (2.0 + X), 0.8)

あてはめたモデルは以下です。

  Y \sim Normal(m X / (k + X), \sigma)

  k \sim Uniform(0, 12)

  m \sim Uniform(0, 20)

case 3b 真のモデルが含まれない場合

あてはめたモデルが以下の場合も試しました。

  Y \sim Normal(a + b X, \sigma)

結果

計算速度

Stanを使う場合、近似式のモデルの方がサンプリングが速いので、計算速度は近似式の方が少し速いです。どれだけ速くなるかはモデル依存で場合によりけりです。

MCMCサンプルの出方の違いによる影響

近似式の方はMCMCサンプルの出方によってかなりばらつくようです。また、少し値が低くなっています。

データの出方の違いによる影響

横軸は WBIC_{approx} - WBIC_{original}です。少しマイナスに偏った分布になりました。これが近似で捨てた項の影響なのか、Stanによるサンプリングの影響なのかは分かりません。

相対的な差

横軸は相対的な差である (WBIC_{approx} - WBIC_{original})/WBIC_{original} * 100です。ロジスティック回帰のN = 20の場合は、しばしば WBIC_{original}が0に近くなるので尾を引いています。Nが増えるに従って相対的な差は小さくなり、N = 100では±5%ぐらいに収まりそうです。

まとめ

Nが大きいときは近似式でスピードを重視しても大丈夫そう。でもNが小さいときは定義通り計算した方が無難に思えます。

ソースコード

case 1ソースコードを以下に載せます。

オリジナルのWBICを算出するためのStanコード

model/model1-ori.stanというファイル名とします。

data {
  int D;
  int N;
  matrix[N,D] X;
  vector[N] Y;
}

parameters {
  vector[D] b;
  real<lower=0> sigma;
}

transformed parameters {
  vector[N] mu;
  mu = X*b;
}

model {
  target += 1/log(N) * normal_lpdf(Y | mu, sigma);
}

generated quantities {
  vector[N] log_lik;
  for (n in 1:N)
    log_lik[n] = normal_lpdf(Y[n] | mu[n], sigma);
}

近似式でWBICを算出するためのStanコード

model/model1-apx.stanというファイル名とします。

data {
  int D;
  int N;
  matrix[N,D] X;
  vector[N] Y;
}

parameters {
  vector[D] b;
  real<lower=0> sigma;
}

transformed parameters {
  vector[N] mu;
  mu = X*b;
}

model {
  Y ~ normal(mu, sigma);
}

generated quantities {
  vector[N] log_lik;
  for (n in 1:N)
    log_lik[n] = normal_lpdf(Y[n] | mu[n], sigma);
}

各WBICを算出するRコード

library(rstan)

wbic_original <- function(log_lik) {
  wbic <- - mean(rowSums(log_lik))
  return(wbic)
}

wbic_approx <- function(log_lik) {
  b2 <- 1.0/log(ncol(log_lik))
  b1 <- 1.0
  log_denominator <- statnet.common::log_sum_exp(-(b2-b1)*(rowSums(-log_lik)))
  log_numerator   <- statnet.common::log_sum_exp(-(b2-b1)*(rowSums(-log_lik)) + log(rowSums(-log_lik)))
  wbic <- exp(log_numerator - log_denominator)
  return(wbic)
}

set.seed(123)
D <- 3
b <- c(1.3, -3.1, 0.7)
SD <- 2.5
N <- 100

X <- cbind(1, matrix(runif(N*(D-1), -3, 3), N, (D-1)))
Mu <- X %*% b
Y <- rnorm(N, Mu, SD)
data <- list(N=N, D=D, X=X, Y=Y)

sm_ori <- stan_model(file='model/model1-ori.stan')
sm_apx <- stan_model(file='model/model1-apx.stan')
fit_ori <- sampling(sm_ori, pars='log_lik', data=data, iter=11000, warmup=1000, seed=123)
fit_apx <- sampling(sm_apx, pars='log_lik', data=data, iter=11000, warmup=1000, seed=123)
wbic_ori <- wbic_original(rstan::extract(fit_ori)$log_lik)
wbic_apx <- wbic_approx(rstan::extract(fit_apx)$log_lik)
c(wbic_ori=wbic_ori, wbic_apx=wbic_apx)
  • 3~6行目:  WBIC_{original}を算出します。この記事参照。
  • 8~15行目:  WBIC_{approx}を算出します。途中で{statnet.common}パッケージのlog_sum_exp関数を使っています。前の記事のようにStanに含まれるlog_sum_exp関数を使っても構いません(全く同じ数値になります)。
  • 9行目: log_likN_mcmc×N(データの数)のmatrix型ですのでncol(log_lik)Nを取得しています。

渡辺先生の論文の(20)式の通りに計算しようとすると、expの内側が50ぐらい以上の数値になるため、計算が不安定になります。そのため(20)式の両辺の対数をとって計算して、最後にexpをかませて戻します。MCMCを使っている場合、(20)式の左辺の対数は以下のように式変形できます。

   

なので、まず易しい分母の方から計算すると、

 

 

 

 

 

分子の方も同様に、

 

 

これをそのまま実装しています。

Enjoy!