Pythonのメモ帳

numpy, pandas, tensorflow を使いこなすための忘備録

配列から重複なく要素をランダムに抽出し2組に分ける(データセットをシャッフルしつつ2つに分割)

ディープラーニングの前処理としてよくあるのが、データをシャッフルしつつ2つに分ける作業。言い換えると、データセット(配列)から学習データとテストデータを重複なくランダムに取り出す作業。

処理としては、ランダムな配列を作り、それをインデックスとして使ってデータを抽出する。

import numpy as np

num_test  = 10000
num_train = 10000
num_all   = num_train+num_test

id_all   = np.random.choice(num_all, num_all, replace=False)
id_test  = id_all[0:num_test]
id_train = id_all[num_test:num_all]
test_data  = mydataset[id_test]
train_data = mydataset[id_train]

 

下記のようにデータセット自体をシャッフルする方法でも、同様の処理はできる。直感的にわかりやすいけど、処理は遅かった。

import numpy as np

num_test  = 10000
num_train = 10000
num_all   = num_train+num_test

np.random.shuffle(mydataset)
test_data  = mydataset[0:num_test]
train_data = mydataset[num_test:num_all]