えんじにあのじゆうちょう

勉強したことを中心にアウトプットしていきます。

matplotlibのグラフをアニメーションにする

はじめに

例えば、単純パーセプトロンの学習過程をアニメーション化したいという話は説明で使いたいときなどによくあると思います。
そんなときにどうやればよいかを簡単にまとめました。

実装

例題

今回ははじめにに書いたとおり、単純パーセプトロンの学習を例に取ります。

まず、学習結果と重みを受け取って結果を返す関数は以下のようにします。

def f(w, x):
    return w[0] + w[1] * x

また、学習率のパラメータ類と実際の学習は以下の通りに定義します。
学習中のパラメータはw_historyとしてイテレーションごとに記録します。

eta = 0.001 # 学習率
w = [-3,3] # 重み(初期値は適当)
max_iter = 10000 # 最大繰り返し回数

w_history = [w] # 重み記録用

for i in range(0, max_iter):
    y_hat = f(w, x[:,1])
    grad = np.dot((y - y_hat), x)
    w = w + eta * grad
    w_history.append(w)

インポート

通常、matplotlibは

import matplotlib.pyplot as plt

と言った感じの1行をインポートすれば概ね事足りますが、追加で以下のアニメーションの方を読み込ませる必要があります。

import matplotlib.animation as animation

描画領域と初期値の用意

いきなりアニメーションしようとしてもうまくいきません。
まずは最初に描画領域と初期値を突っ込んだ直線だったりを書いておく必要があります。

fig = plt.figure(1)

ax = fig.add_subplot()
ax.grid(which="both")

txt = ax.text(0, 0, "iteration: 0")
x_range = np.arange(0,10,0.1)
ax, = ax.plot(x_range,f(w_history[0], x_range))

とりあえず適当にiteration回数とw_historyの1番最初の要素で計算した結果を描画しておきます。

アップデート関数の用意

関数を用意する方法が個人的にはわかりやすかったので、その方法を説明します。
関数名は何でも良いのですが、一旦以下のようにします。

def update(i):
    ax.set_ydata(f(w_history[i], x_tick))
    txt.set_text("iteration: {}".format(i))

update関数は呼び出し時にiteration回数を受けとります。
それをキーにしてw_historyのi番目を呼び出すようにしています。

アニメーションの作成

では最後に、アニメーションにしていきます。

anim = animation.FuncAnimation(fig, update,100, interval=300)
ani.save('./perceptron.mp4')

animation.FuncAnimationを呼び出します。
第1引数は描画オブジェクト、第2引数がupdate用の関数、第3引数が何iteration実行するか、intervalが繰り返しの間のインターバル(ミリ秒)です。
そして最後にsaveメソッドを呼び出すことで指定したファイル名でアニメーションを保存できます。