カレーちゃんブログ

Kaggleや競技プログラミングなどのこと

PyStan入門してみた(2)

前回から時間が経ちましたが、PyStan入門してみたシリーズの2回目です。 今回は1回目の内容に加えて、スタンモデル等の保存とモデルの読み込みをについての軽い内容です。 1回目の内容を含んだものをgistにあげました。これを見たほうが早いという方も多いでしょう。

前回の復習と今回の内容

sm = pystan.StanModel(file='liner.stan')
fit = sm.sampling(data=stan_data, iter=3000, chains=3,thin=1)

PyStanのMCMCは、上記のようにして

  • 1行目で.stanファイルのコンパイルを行い
  • 2行目で、コンパイルしたStanModelに、辞書型のデータを入れてサンプリングをする

という流れでした。

コンパイルもサンプリングも、モデルや環境にもよりますが時間がかかるので、毎度毎度やるのは大変です。
そこで、コンパイルしたStanModel及びサンプリング結果を保存しておいて、あとで好きな時に呼び出す仕組みが用意されています。 公式では、Avoiding recompilation of Stan models — PyStan 2.17.0.0 documentationに書いてあるので、正確には公式を参照ください。

今回の本題

コンパイルしたモデルの保存

import pickle
# データの書き込み
with open('sumple.pickle', 'wb') as f:
    pickle.dump(sm, f)
    pickle.dump(fit, f)

Pythonの標準ライブラリにあるpickleモジュールを使い、コンパイル済みのスタンモデル(sm)や、サンプリング結果(fit)を、保存します。

  • 上記のコードでは'sample.pickle'ファイルに'sm'と'fit'オブジェクトを、保存している。

注意点としては、pickleでの保存は上書き出来ないので、再度保存する場合は、sumple.plckleファイルを削除するか、ファイル名を変更(上記のコードのsample.picklesample.pickle2に変更する意味)するなどする必要があります。

公式ドキュメントに記載がある、上記の注意点を気にしなくても良い方法は下部に記載しています。

コンパイルしたモデルの読み込み

with open('sumple.pickle', 'rb') as f:
    sm = pickle.load(f)
    fit = pickle.load(f)

sumple.pickleに保存してある、smオブジェクトとfitオブジェクトを呼び出す例です。後日使いたい時にいつでもすぐに呼び出せるので非常に楽です。

公式に記載がある注意点を気にしなくても良い方法

以下のように書くと、.stanコードが違えば、別の名前をつけてくれるし、既に同じ.stanコードをコンパイルしたものがある場合は、"Using cached StanModel"と表示がされます。
このやり方は、コンパイルする際にこの関数('StanModel_cache')を呼び出せば、自動的に保存までされるので、常にこの方法を使うのが楽そうです。

import pystan
import pickle
from hashlib import md5

def StanModel_cache(model_code, model_name=None, **kwargs):
    """Use just as you would `stan`"""
    code_hash = md5(model_code.encode('utf-8')).hexdigest() # 公式の'ascii'ではうまくいかないので、'utf-8'に変更
    if model_name is None:
        cache_fn = 'cached-model-{}.pkl'.format(code_hash)
    else:
        cache_fn = 'cached-{}-{}.pkl'.format(model_name, code_hash)

    try:
        sm = pickle.load(open(cache_fn, 'rb'))
    except: #tryで例外が発生すれば、smを書き込み
        sm = pystan.StanModel(model_code=model_code)
        with open(cache_fn, 'wb') as f:
            pickle.dump(sm, f)
    else: # tryで例外が発生しなければ、"Using cached StanModel"をプリント
        print("Using cached StanModel")
    return sm

sm = StanModel_cache(model_code=model_code)

まとめ

難しいことはなく、毎回コピペして実行すれば保存と読み込みは楽勝です。 次回は、pythonからRのパッケージを便利に使うにしたいと思います。