def get_dataset(data_dir='data/'):
    """Build a shuffled, label-encoded index of the images under ``data_dir``.

    Every sub-directory of ``data_dir`` is treated as one class: the
    directory name is the class label and each entry inside it is taken to
    be an image file. The collected rows are shuffled with
    ``np.random.shuffle`` and then "cleaned": each image is loaded with
    ``misc.imread`` and dropped if the decoded array has fewer than three
    dimensions (i.e. it has no channel axis, as a grayscale image would).
    The number of rows before and after cleaning is printed.

    Args:
        data_dir: Root directory holding one sub-directory per class.
            Defaults to ``'data/'``.

    Returns:
        A 3-tuple ``(list_images, n_classes, label)``:

        * ``list_images`` -- 2-column string array. Column 0 is the image
          path relative to ``data_dir`` in the form ``"<label>/<file>"``;
          column 1 is the encoded class index produced by ``LabelEncoder``
          (stored as text, because the array is string-typed).
        * ``n_classes`` -- number of distinct labels among the kept rows.
        * ``label`` -- the distinct label names as returned by
          ``np.unique`` (sorted). NOTE(review): this lines up with the
          encoded indices only if ``LabelEncoder`` orders classes the same
          way, as scikit-learn's does -- confirm against the file's import.

        When no image survives cleaning, an empty ``(0, 2)`` array, ``0``
        and ``[]`` are returned.
    """
    # One [relative_path, label] row per file; the folder name is the label.
    rows = []
    for folder in os.listdir(data_dir):
        folder_path = os.path.join(data_dir, folder)
        # Stray files sitting next to the class folders (e.g. ``.DS_Store``)
        # would make the nested ``os.listdir`` raise, so skip anything that
        # is not a directory.
        if not os.path.isdir(folder_path):
            continue
        for filename in os.listdir(folder_path):
            # Keep the '/'-joined form: it is part of the returned data.
            rows.append([folder + '/' + filename, folder])

    list_images = np.array(rows)
    # Shuffles rows in place (along the first axis) before cleaning.
    np.random.shuffle(list_images)
    # A single parenthesised argument prints the same text under both the
    # Python 2 print statement and the Python 3 print function.
    print("before cleaning got: " + str(list_images.shape[0]) + " data")

    # Keep only images whose decoded array has a channel axis.
    kept = []
    for row in list_images:
        image = misc.imread(os.path.join(data_dir, row[0]))
        if len(image.shape) >= 3:
            kept.append(row.tolist())
    print("after cleaning got: " + str(len(kept)) + " data")

    if not kept:
        # ``np.array([])`` is 1-D, so the column indexing below would fail
        # with an opaque IndexError; hand back a well-formed empty result.
        return np.empty((0, 2), dtype=str), 0, []

    list_images = np.array(kept)
    # Capture the label names before column 1 is overwritten with indices.
    label = np.unique(list_images[:, 1]).tolist()
    list_images[:, 1] = LabelEncoder().fit_transform(list_images[:, 1])
    return list_images, len(label), label
train.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录