def __init__(self, root, split):
    """Load binarized 48x48 grayscale frames and their per-frame actions.

    Args:
        root: Dataset directory. PNG frames are read from ``root/train``
            and/or ``root/test``; actions are read from ``root/actions.txt``
            (one float per line).
        split: ``'train'``, ``'test'``, or ``'all'`` (train and test frames
            combined).

    Sets:
        self.images: float32 array of shape ``(N, 48, 48, 1)`` where ``N`` is
            the number of PNG frames found for the split.
        self.actions: float32 array of shape ``(M, 1)`` holding the first
            ``N`` values of ``actions.txt`` (``M < N`` if the file is shorter).

    Raises:
        ValueError: if ``split`` is not one of the accepted values.
    """
    if split not in ('train', 'test', 'all'):
        raise ValueError(
            f"split must be 'train', 'test' or 'all', got {split!r}")

    # Glob only the directories this split needs; 'all' has no directory of
    # its own, so it is the union of the train and test directories.
    subdirs = ('train', 'test') if split == 'all' else (split,)
    filenames = []
    for subdir in subdirs:
        filenames.extend(glob.glob(os.path.join(root, subdir, '*.png')))
    # Frames are named '<index>.png'; sort numerically so '10.png' comes
    # after '9.png' rather than after '1.png'.
    filenames.sort(key=lambda path: int(os.path.basename(path).split('.')[0]))

    images = []
    for path in filenames:
        img = plt.imread(path)
        # Binarize: every value that is not exactly 1.0 is forced to 0.
        img[img != 1] = 0
        images.append(resize(rgb2gray(img), [48, 48], mode='constant'))
    self.images = np.array(images, dtype=np.float32).reshape(
        [len(images), 48, 48, 1])

    with open(os.path.join(root, 'actions.txt')) as infile:
        actions = np.array([float(line) for line in infile])
    # NOTE(review): actions are always taken from the start of actions.txt,
    # whatever the split; for 'test' this presumes the file is aligned with
    # the test frames -- confirm against how actions.txt is produced.
    self.actions = actions[:len(self.images)].astype(np.float32)
    # Reshape by the truncated length. Using len(actions) here raised
    # whenever actions.txt had more lines than there are images.
    self.actions = self.actions.reshape(-1, 1)
# (Stray scraped web-page labels "comment list" / "table of contents" removed; they were not Python.)