def get_val_caps_batch(batch_size, loaded_data, data_set, data_dir):
    """Sample a random batch of validation captions and their image paths.

    Only the ``'flowers'`` dataset is supported. ``batch_size`` validation
    images are drawn independently (with replacement), and for each one a
    single caption is chosen at random from its first five captions (or from
    all of them when it has fewer than five).

    Args:
        batch_size: Number of samples to draw.
        loaded_data: Dict providing ``'max_caps_len'``, ``'val_data_len'``,
            ``'val_img_list'``, ``'val_captions'`` (image id -> sequence of
            caption vectors) and ``'str_captions'`` (image id -> sequence of
            caption strings, parallel to ``'val_captions'``).
        data_set: Dataset name; must be ``'flowers'``.
        data_dir: Root directory under which ``flowers/jpg/<image_id>``
            paths are built.

    Returns:
        Tuple ``(captions, image_files, image_caps, image_ids)``:
        ``captions`` is a ``(batch_size, max_caps_len)`` float array of the
        chosen caption vectors (truncated, or zero-padded if shorter);
        ``image_files`` is a list of image paths; ``image_caps`` is a list
        of the matching caption strings; ``image_ids`` is the array of
        sampled image ids.

    Raises:
        ValueError: If ``data_set`` is not ``'flowers'``.
    """
    if data_set != 'flowers':
        # ValueError subclasses Exception, so callers catching Exception
        # keep working.
        raise ValueError('Dataset not found')

    max_len = loaded_data['max_caps_len']
    captions = np.zeros((batch_size, max_len))
    # Independent draws, so the same image may appear more than once.
    batch_idx = np.random.randint(0, loaded_data['val_data_len'],
                                  size=batch_size)
    image_ids = np.take(loaded_data['val_img_list'], batch_idx)
    image_files = []
    image_caps = []
    for idx, image_id in enumerate(image_ids):
        image_files.append(join(data_dir, 'flowers/jpg/' + image_id))
        vectors = loaded_data['val_captions'][image_id]
        # Sample from the first five captions as before, but never past the
        # last available one (fewer captions previously raised IndexError).
        caption_idx = random.randint(0, min(4, len(vectors) - 1))
        vector = np.asarray(vectors[caption_idx])[:max_len]
        # Zero-pad short vectors instead of failing on shape mismatch.
        captions[idx, :len(vector)] = vector
        image_caps.append(loaded_data['str_captions'][image_id][caption_idx])
    return captions, image_files, image_caps, image_ids
评论列表
文章目录