matplotlibでヒートマップを高速描画する簡単な方法
Pythonでのヒートマップの描き方を調べると、だいたい以下の2つの方法が出てくる。
- seabornライブラリのheatmap関数を使う方法
Seaborn でヒートマップを作成する – Python でデータサイエンス
- matplotlibライブラリのpcolor関数を使う方法
Python + matplotlib によるヒートマップ
これらの真っ当なヒートマップは、データの大きいと描画に時間がかかるのが欠点。特にseabornは時間がかかる。seabornは使い慣れたら便利そうだしグラフの見た目も綺麗なだけに残念。
速くしたい時は以下のコードがおすすめ。
結果:
画像描画のimshowを使い、引数を interpolation='nearest' とすることで、ヒートマップと同じ表示ができる。この引数を設定しないと、データ間を補間して表示してしまう(データ間の境界がはっきりしなくなる)。
100万画素のデータ(行数1000 x 列数1000)のデータの場合で計算時間を計測してみたら、以下のように圧倒的な速さだった。
- seaborn.heatmapを使った場合 : 35 sec
- matplotlib.pcolorを使った場合 : 6 sec
- matplotlib.imshowを使った場合 : 0.5 sec
欠点はグラフの縦横比が固定されるので、plt.figure(figsize=(*,*)) だけでは調節できないことくらい。これの対処法は以下の記事を参照。
ちなみに各行列のラベル表記はしない前提。というか行数・列数多くてラベルをつけることが無理なような大きいデータを使うから関係ないですね。