Pythonのメモ帳

numpy, pandas, tensorflow を使いこなすための忘備録

matplotlibでヒートマップを高速描画する簡単な方法

Pythonでのヒートマップの描き方を調べると、だいたい以下の2つの方法が出てくる。

 

- seabornライブラリのheatmap関数を使う方法
   Seaborn でヒートマップを作成する – Python でデータサイエンス

 - matplotlibライブラリのpcolor関数を使う方法
   Python + matplotlib によるヒートマップ

 

これらの真っ当なヒートマップは、データの大きいと描画に時間がかかるのが欠点。特にseabornは時間がかかる。seabornは使い慣れたら便利そうだしグラフの見た目も綺麗なだけに残念。

  

速くしたい時は以下のコードがおすすめ。

import numpy as np
import matplotlib.pyplot as plt

# 10x10のダミーデータ作成
mat = np.random.rand(10,10)

# ヒートマップ表示
plt.figure()
plt.imshow(mat,interpolation='nearest',vmin=0,vmax=1,cmap='jet')
plt.colorbar()
plt.show()

結果:

 f:id:spcx8:20181006075821p:plain 

 

画像描画のimshowを使い、引数を interpolation='nearest' とすることで、ヒートマップと同じ表示ができる。この引数を設定しないと、データ間を補間して表示してしまう(データ間の境界がはっきりしなくなる)。

 

100万画素のデータ(行数1000 x 列数1000)のデータの場合で計算時間を計測してみたら、以下のように圧倒的な速さだった。

  • seaborn.heatmapを使った場合 : 35 sec
  • matplotlib.pcolorを使った場合 : 6 sec
  • matplotlib.imshowを使った場合 : 0.5 sec

 

欠点はグラフの縦横比が固定されるので、plt.figure(figsize=(*,*)) だけでは調節できないことくらい。これの対処法は以下の記事を参照。

spcx8.hatenablog.com

 

ちなみに各行列のラベル表記はしない前提。というか行数・列数多くてラベルをつけることが無理なような大きいデータを使うから関係ないですね。