def get_caption_batch(loaded_data, data_dir, dataset='flowers', batch_size=64,
                      captions_per_image=5):
    """Sample a random batch of images together with one caption each.

    Image indices are drawn uniformly, with replacement, from
    ``[0, loaded_data['data_length'])`` using NumPy's global RNG. For each
    sampled image, one caption index is drawn with Python's ``random`` module.

    Args:
        loaded_data: Mapping with keys ``'max_caps_len'`` (row width of the
            returned caption matrix), ``'data_length'`` (number of images to
            sample from), ``'image_list'`` (sequence of image ids) and
            ``'captions'`` (mapping image id -> sequence of captions, each a
            numeric sequence).
        data_dir: Root data directory.
        dataset: Dataset sub-directory name; images are expected under
            ``<data_dir>/<dataset>/jpg/<image_id>``.
        batch_size: Number of (image, caption) pairs to sample.
        captions_per_image: Sample only among the first this-many captions of
            each image (default 5). It is clamped to the number of captions the
            image actually has.

    Returns:
        Tuple ``(captions, image_files, image_caps, image_ids, image_caps_ids)``:
        ``captions`` is a float array of shape ``(batch_size, max_caps_len)``
        holding each chosen caption truncated to ``max_caps_len`` and zero-padded
        if shorter; ``image_files`` is the list of image paths;
        ``image_caps`` is the list of chosen captions (untruncated);
        ``image_ids`` is the array of sampled image ids; ``image_caps_ids`` is
        the list of chosen caption indices.
    """
    max_len = loaded_data['max_caps_len']
    all_captions = loaded_data['captions']
    captions = np.zeros((batch_size, max_len))
    batch_idx = np.random.randint(0, loaded_data['data_length'],
                                  size=batch_size)
    image_ids = np.take(loaded_data['image_list'], batch_idx)

    image_files = []
    image_caps = []
    image_caps_ids = []
    for idx, image_id in enumerate(image_ids):
        # 'jpg' must be its own path component; concatenating it with the
        # file name would drop the directory separator.
        image_file = join(data_dir, dataset, 'jpg', image_id)

        caps_for_image = all_captions[image_id]
        # Clamp so images with fewer captions don't raise IndexError; with the
        # default of 5 and >= 5 captions this matches randint(0, 4) exactly.
        n_choices = min(captions_per_image, len(caps_for_image))
        random_caption = random.randint(0, n_choices - 1)
        caption = caps_for_image[random_caption]

        # Truncate to the row width; shorter captions keep the zero padding
        # from np.zeros instead of failing (or silently broadcasting).
        truncated = caption[0:max_len]
        captions[idx, :len(truncated)] = truncated

        image_caps_ids.append(random_caption)
        image_caps.append(caption)
        image_files.append(image_file)
    return captions, image_files, image_caps, image_ids, image_caps_ids
评论列表
文章目录