ひつじTips

技術系いろいろつまみ食います。

jupyter+matplotlibで3Dグラフを書くとき,ipywidgetsを使ってインタラクティブに数値を変える方法

f:id:mu-777:20200209150435g:plain

jupyter+matplotlibでグラフを書くときに,ipywidgetsのinteract関数を使って数値をインタラクティブに変更する方法がよく紹介されている(下記リンク参照)のですが,Axes3Dでグラフを作るときに結構ハマったので書いておく.

結論

上のgifを実現してるコードを載せる.

%matplotlib inline
# %matplotlib notebook  # notebookでは画が出ない

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact

fig = plt.figure()
ax = fig.gca(projection='3d')

def scatter(num_data):
    x = range(num_data)
    y = [np.sin(t/10.0) for t in x]
    z = [np.cos(t/10.0) for t in x]
    plt.gcf().gca(projection='3d').plot(x, y, z, 'o',color='C3')
    plt.show()

interact(scatter, num_data=(1, 300, 1))

キモは,interact関数に渡しているコールバック関数内で,pltから現在のAxes3Dを取得しなおしているところ.

interact関数を実行する前に作ったAxes3Dのaxは,コールバック関数内に渡せない(のか?? 下記「試してうまくいかなかったこと」では渡すことができてるようにも見えるが...).コールバック関数はstatic的に記述できるようにしておかねばならない.

なので,以下のように global 宣言を使っても期待通り機能する.

%matplotlib inline
# %matplotlib notebook

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact


def scatter(num_data):
    fig = plt.figure()
    global ax
    ax = fig.gca(projection='3d')
    x = range(num_data)
    y = [np.sin(t/10.0) for t in x]
    z = [np.cos(t/10.0) for t in x]
    ax.plot(x, y, z, 'o',color='C3')
    plt.show()

interact(scatter, num_data=(1, 300, 1))

試してうまくいかなかったこと

クラス化

これを実行すると,最初の初期値でのグラフが表示されるが,スライダーを動かしたら消える..

print(self._name, num_data, self._ax.get_zlabel())scatter関数に入っているが,ここは期待通りに,Interacterクラスの__init__で定義した "test" とか "aaaaa" が表示されているので,selfが渡っていないわけではないが..

%matplotlib inline
# %matplotlib notebook

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact


class Interacter(object):
    def __init__(self):
        self._name = "test"
        self._fig = plt.figure()
        self._ax = self._fig.gca(projection='3d') 
        self._ax.set_zlabel("aaaaa")

    def scatter(self, num_data):
        print(self._name, num_data, self._ax.get_zlabel())
        x = range(num_data)
        y = [np.sin(t/10.0) for t in x]
        z = [np.cos(t/10.0) for t in x]
        self._ax.plot3D(x, y, z,  'o',color='C3')
        plt.show()
        
interacter = Interacter()
interact(interacter.scatter, num_data=(1, 300, 1))

これ地味に解せない...

2Dグラフのサンプルの拡張

これも上と同様に,最初はグラフが表示されるが,スライダーを動かしたら消える. print(ax.get_xlabel(), num_data) では,期待通りに ax.set_xlabel("aaa") と設定した文字列"aaa"が表示されるにも関わらず...

%matplotlib inline
# %matplotlib notebook

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact

fig = plt.figure()
ax = fig.gca(projection='3d') 
ax.set_xlabel("aaa")

def scatter(num_data):
    x = range(num_data)
    y = [np.sin(t/10.0) for t in x]
    z = [np.cos(t/10.0) for t in x]
    print(ax.get_xlabel(), num_data)
    ax.plot3D(x, y, z,  'o',color='C3')
    plt.show()

interact(scatter, num_data=(1, 300, 1))

感想

matplotlib 難しすぎて使うたびにググってる気がする...ソース読めばこの挙動の理由がわかる気もするがめんどい...