ゆるふわクオンツの日常

クロスバリデーションとすると汎化性能が良くなる仕組み覚書

データ分析をする際は、交差検証(クロスバリデーション)してパラメータを決めることが多いと思います。

クロスバリデーションすることでテストデータでの精度が上がる傾向があるのは経験的にも感覚的にもなんとなくわかるのですが、

それがどういう原理に基づいているのかを確認していきたいと思います。

用語の定義

とりあえず今回の議論は損失関数として対数尤度をマイナスしたものを想定します。

また、クロスバリデーションは「サンプルの中から一つだけ除外して推測を行い、除外したサンプルで評価を行う」クロスバリデーション(LOOCV)を想定します。

そして、サンプルは X_{1}, X_{2}, ... , X_{n} iid

観測できない何らかの真の分布に従っているとします。

  • クロスバリデーション損失 \displaystyle C_{n} = -\frac{1}{n}\sum_{i=1}^{n}lnE_{w}^{(i)}[p(X_{i},w)] と書けます。ここで、 E_{w}^{(i)}[・]は X_{i}を除いたサンプルで学習した事後分布による期待値をで表しており、 E_{w}^{(i)}[p(X_{i},w)]は X_{i}の予測分布( =p(X_{i}|X_{1},...X_{i-1},X_{i+1},...,X_{n}))になっています。

  • 経験損失 \displaystyle T_{n} = -\frac{1}{n}\sum_{i=1}^{n}lnE_{w}[p(X_{i},w)]と書けます。ここで、 E_{w}[・]は全サンプルで学習した事後分布による期待値です。

  • 汎化損失 \displaystyle G_{n} = -E_{X}[lnE_{w}[p(X,w)]と書けます。ここで、 E_{w}[・]は経験損失同様に n個のサンプルから学習した事後分布による期待値です。で、 E_{X}[f(X)]はサンプル外のデータ X(わかりにくければ X_{n+1}とみてもOK)と観測できない真の分布による期待値です。ポイントとしては、 G_{n} Xでの期待値をとっているんだけれど、実はこれは X_{1},...,X_{n}に依存した確率変数であるということです(予測分布が p(X|X_{1},...,X_{n})であることを考えればわかりますかね)。

クロスバリデーションと汎化損失の関係

以上のように定義された3つの確率変数ですが、

特にクロスバリデーション損失と汎化損失の間に成り立つ関係があります。

まず、 E[C_{n}]について少し考えると、

ログの部分を一旦無視すると、確率変数 E_{w}^{(i)}[p(X_{i},w)]に対する期待値となっています。

この確率変数は n-1個のサンプルから学習された事後分布を用いて生成された予測分布から

 i番目のサンプルがgenerateされる確率を表す確率変数となっています。

これの期待値をとっているわけですから、

 E[C_{n}]というのは、

 n-1個のサンプルから学習した分布から、未知の n番目のサンプルが発生する確率の平均値」

を表していると言えそうです。

こう考えたときに、

n-1個から学習して未知のn個目が発生する確率」

というのは汎化損失のところで出てきた考え方と同じであることがわかり、

 E[C_{n}]=E[G_{n-1}]

の等式が成立することがわかります。

細かく書くと

 E[C_{n} ]=-\frac{1}{n}\sum_{i=1}^{n}E[lnE_{w}^{(i)}[p(X_{i},w)] ]

 rhs = -E[lnE_{w}^{(n)}[p(X_{n},w)] ] = E_{X_{1},...,X_{n-1}}[-E_{X_{n}}[ lnE_{w}^{(n)}[p(X_{n},w)] ] ] = E[G_{n-1}]

って感じですかね。

これを見ると、

冒頭に述べたような、クロスバリデーションのロスを減らすようにパラメータ wを学習することは、

右辺の(期待)汎化損失を減らそうと頑張っていたことを意味します。

ちなみに、上記の期待値があのWAICの期待値に漸近的に一致する的な話もあります( E[G_{n-1}] = E[WAIC_{n-1}]+o(\frac{1}{n}))。

経験損失って?

ところで、

クロスバリデーションが汎化性能の意味で正義っぽいことは分かったとして、

じゃあ経験損失の存在意義って?ってなりますよね。

予測をする上で経験損失はあんまり使えないのですが笑、

小ネタとして以下が言えます。

 T_{n} \le C_{n}

詳しくみていくと、

 \displaystyle C_{n} = -\frac{1}{n}\sum_{i=1}^{n}ln\frac{\int p(X_{i},w) \prod_{j \neq i} p(X_{j}, w) \phi(w) dw}{\int  \prod_{j \neq i} p(X_{j}, w) \phi(w) dw}

 \displaystyle \ \ \ \  = -\frac{1}{n}\sum_{i=1}^{n}ln\frac{\int \prod_{j = 1}^{n} p(X_{j}, w) \phi(w) dw}{\int \frac{1}{p(X_{i},w)}  \prod_{j = 1}^{n} p(X_{j}, w) \phi(w) dw}

 \displaystyle \ \ \ \  = -\frac{1}{n}\sum_{i=1}^{n}ln\frac{1}{E_{w}[\frac{1}{p(X_{i},w)}]}

 \displaystyle \ \ \ \  = \frac{1}{n}\sum_{i=1}^{n}lnE_{w}[\frac{1}{p(X_{i},w)}]

となることに着目すると、

 \displaystyle C_{n}-T_{n} = \frac{1}{n}\sum_{i=1}^{n}ln(E_{w}[\frac{1}{p(X_{i},w)}]E[p(X_{i},w)])

ヘルダーの不等式より

 \displaystyle rhs \ge \frac{1}{n}\sum_{i=1}^{n}ln1 = 0

ということで C_{n} \ge T_{n}とわかりました。

というわけで今日はこの辺で。

今回の関連書籍

むずいですが、2年くらい眺め続けたら朧げにわかってきました。
ま、代数幾何はよくわからないのですが笑

WBICも載ってるのはこちら