ゆるふわクオンツの日常

カルマンフィルタのPython実装とnumbaによる高速化のお話

時系列系の分析をしていく中で、状態空間モデルを使うこともあるかと思います。

その際、シンプルなカルマンフィルタであってもなかなか複雑だったりして

思うようにコーディングできなかったりします。

そこで、Python実装時のforループ劇遅問題含めてまとめてみました。

pythonでカルマンフィルタといえばPykalmanもありますが、

最近メンテされてなさそうなので...)

カルマンフィルタとは

線形・ガウス型の状態空間モデルといわれるやつで、

状態空間モデルの中でも比較的シンプルなものです。

モデルとしては、状態方程式と観測方程式の2本の式から構成されます。

 x_t = Tx_{t-1} + \epsilon_t

 y_t = Hx_t+\eta_t

ここで、 x_tは観測できない状態変数を表しており、

観測できるのは y_tだけというモデルになります。

この設定において、入手できたデータ y_0, y_1,...,y_nから

 \eta, \epsilonの分散や T, Hといったパラメータ推定を行ったり、

パラメータが分かっていればその状態変数を求めたり(フィルタリング)

など、いろいろな活用法があります。

どの処理をコーディングしたか

(下記関連図書を踏襲して)予測とフィルタリングを行う関数を実装しました。

引数は

  1. 前期の状態変数のフィルタリング分布( x_{t-1|t-1}, V_{t-1|t-1} )

  2. パラメータ( T, V[ \epsilon ], H, V[ \eta ])

  3. 今期の観測変数 y_t

で、アウトプットは

  1. 今期の観測変数の予測値 y_{t|t-1}

  2. 状態変数のフィルタリング分布( x_{t|t}, V_{t|t} )

となっています。

これを逐次的にデータに適用することで、

所定のパラメータ下の尤度を計算できるので、

scipy.minimizeなどで最適化することで 、

最尤法でパラメータ推定などもできます。

計算ロジックの概要

上記のコーディングを行うにあたっての計算ロジックは以下の通りです。

入力

 x_{t-1|t-1}, V_{t-1|t-1}, T, H, V[ \epsilon ], V[ \eta ], y_t

step.1 今期の状態変数の予測分布

 x_{t|t-1} = Tx_{t-1|t-1}

 V_{t|t-1} = TV_{t-1|t-1}T^{T}+V[ \epsilon ]

step.2 今期の観測変数の予測値

 y_{t|t-1} = H x_{t|t-1}

 y_{t|t-1}の分散 =HV_{t|t-1}H^{T}+V[ \eta ]

step.3 今期の状態変数のフィルタリング分布

カルマンゲインという数値を以下のように設定し
 K_{t} = V_{t|t-1}H^{T}(HV_{t|t-1}H^{T}+V[ \eta ])^{-1}

状態変数に関して以下が得られます
 x_{t|t}=x_{t|t-1}+K_{t}(y_{t}-y_{t|t-1})

 V_{t|t}=(I-K_{t}H)V_{t|t-1}

出力

 x_{t|t}, V_{t|t},y_{t|t-1}, HV_{t|t-1}H^{T}+V[ \eta ]

コード

可読性が劣悪ですが、下記のようなコードを作成しました。

おそらく、この関数をfor文などでループ処理することが多いと思いますので

関数のコンパイルをよしなにやってくれるnumbaを使って高速化してあります。

関数の上に @njit を付けると早くなります。

(numbaはnumpyに特化したコンパイラです。対応してない関数もあるけど...)

import numpy as np
from numba import njit


@njit    
def Num_Dot(a: np.ndarray, b: np.ndarray):
    """
    np.dotを使えば良いけど、numbaの練習がてら
    """
    c = np.zeros((a.shape[0], b.shape[1]))
    for i in range(a.shape[0]):
        for j in range(b.shape[1]):
            for k in range(a.shape[1]):
                c[i][j] += a[i][k] * b[k][j]
    return c


@njit
def kalman_filter(
    state_variable_mean_pre: np.ndarray, state_variable_covariance_matrix_pre: np.ndarray, observation_matrix: np.ndarray,
    observation_error_covariance_matrix: np.ndarray, state_transition_matrix: np.ndarray, state_error_covariance_matrix: np.ndarray,
    observations: np.ndarray
    ):
    """
    t-1のフィルタリング分布(平均と分散共分散行列)を受け取り,t時点のフィルタリング分布(平均と分散共分散行列)を返します.
    """
    state_variable_dim = state_variable_mean_pre.shape[0]
    
    # 1期先予測状態変数
    state_variable_mean = Num_Dot(state_transition_matrix, state_variable_mean_pre)
    state_variable_covariance_matrix = Num_Dot(Num_Dot(state_transition_matrix, state_variable_covariance_matrix_pre), state_transition_matrix.T) + state_error_covariance_matrix
    
    # 1期先予測観測変数
    observation_variable_mean = Num_Dot(observation_matrix, state_variable_mean)
    observation_variable_covariance_matrix = Num_Dot(Num_Dot(observation_matrix, state_variable_covariance_matrix), observation_matrix.T) + observation_error_covariance_matrix
    
    # kalman gain
    kalman_gain = Num_Dot(Num_Dot(state_variable_covariance_matrix, observation_matrix.T), np.linalg.inv(observation_variable_covariance_matrix))
    
    # フィルタリング状態変数
    filtered_state_variable_mean = state_variable_mean + Num_Dot(kalman_gain, (observations - observation_variable_mean))
    filtered_state_variable_covariance_matrix = Num_Dot((np.eye(state_variable_dim) - Num_Dot(kalman_gain, observation_matrix)), state_variable_covariance_matrix)
    
    return filtered_state_variable_mean, filtered_state_variable_covariance_matrix, observations - observation_variable_mean, observation_variable_mean, observation_variable_covariance_matrix

今回の関連図書

Rではありますが、カルマンフィルタ部分の実装などが細かく載っており参考になりました。
またパラメタの最尤推定時のコードも載っており、method=BFGSを用いるのも参考になりました。

こちらは理論面に重点を置きながらときおりコードを交えて説明してくれています。

pythonでの状態空間モデルやHMMなどを含む時系列モデリングについて載っています。
Pykalmanについても載っています。
時系列はRのものが多いのでPythonのものは珍しい印象です。