train.py source code

python

Project: TAC-GAN  Author: dashayushman

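import random
from os.path import join

import numpy as np


def get_val_caps_batch(batch_size, loaded_data, data_set, data_dir):
    """Sample a random batch of validation captions for the given dataset.

    Returns the caption arrays, the matching image paths, the raw caption
    strings and the sampled image ids.
    """
    if data_set == 'flowers':
        # One row per sample, max_caps_len values wide.
        captions = np.zeros((batch_size, loaded_data['max_caps_len']))

        # Draw batch_size validation indices (with replacement; the upper
        # bound of np.random.randint is exclusive).
        batch_idx = np.random.randint(0, loaded_data['val_data_len'],
                                      size=batch_size)
        image_ids = np.take(loaded_data['val_img_list'], batch_idx)
        image_files = []
        image_caps = []
        for idx, image_id in enumerate(image_ids):
            image_file = join(data_dir, 'flowers/jpg/' + image_id)
            # Pick one of the first five captions (random.randint is inclusive).
            random_caption = random.randint(0, 4)
            # Copy the first max_caps_len values of the chosen caption.
            captions[idx, :] = \
                loaded_data['val_captions'][image_id][random_caption][
                    0:loaded_data['max_caps_len']]

            image_caps.append(loaded_data['str_captions']
                              [image_id][random_caption])
            image_files.append(image_file)

        return captions, image_files, image_caps, image_ids
    else:
        raise ValueError('Dataset not found: {}'.format(data_set))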
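For illustration, here is a seeded variant of the same sampling logic, run on a toy `loaded_data` dict. The dict layout (the keys the function above reads, captions stored as numeric vectors, five captions per image) is an assumption for this sketch, not the project's actual preprocessing output, and `sample_val_caps` and `n_caps` are hypothetical names. It draws from a `numpy.random.Generator`, so a fixed seed reproduces the same batch. Note that `Generator.integers` excludes its upper bound, whereas `random.randint(0, 4)` in the original includes 4.

```python
from os.path import join

import numpy as np


def sample_val_caps(rng, batch_size, loaded_data, data_dir, n_caps=5):
    """Hypothetical seeded variant of get_val_caps_batch for 'flowers'.

    A single numpy Generator drives both the image draw and the caption
    pick, so the same seed yields the same validation batch.
    """
    max_len = loaded_data['max_caps_len']
    captions = np.zeros((batch_size, max_len))
    # Indices in [0, val_data_len), sampled with replacement.
    batch_idx = rng.integers(0, loaded_data['val_data_len'], size=batch_size)
    image_ids = np.take(loaded_data['val_img_list'], batch_idx)
    image_files, image_caps = [], []
    for idx, image_id in enumerate(image_ids):
        cap_idx = int(rng.integers(0, n_caps))  # upper bound is exclusive
        captions[idx, :] = loaded_data['val_captions'][image_id][cap_idx][:max_len]
        image_caps.append(loaded_data['str_captions'][image_id][cap_idx])
        image_files.append(join(data_dir, 'flowers', 'jpg', str(image_id)))
    return captions, image_files, image_caps, image_ids


# Toy data with the keys the function reads; all values are made up.
ids = ['image_00001.jpg', 'image_00002.jpg']
toy = {
    'max_caps_len': 4,
    'val_data_len': len(ids),
    'val_img_list': ids,
    'val_captions': {i: [np.arange(6, dtype=float) + k for k in range(5)]
                     for i in ids},
    'str_captions': {i: ['caption %d of %s' % (k, i) for k in range(5)]
                     for i in ids},
}

caps, files, strs, batch_ids = sample_val_caps(
    np.random.default_rng(0), 3, toy, '/data')
print(caps.shape)  # (3, 4)
```

Each caption vector longer than `max_caps_len` is truncated by the slice; a shorter one would fail on assignment, so the sketch assumes captions are stored at least `max_caps_len` long.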