valid_time.py source code

python

Project: monogreedy  Author: jinjunqi
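import json
import os
import os.path as osp
import time

# Logger and evaluate are defined elsewhere in the monogreedy project.


def validate(beam_searcher, dataset, logger=None, res_file=None):
    if logger is None:
        logger = Logger(None)
    # Generate one caption per image and time the whole loop.
    all_candidates = []
    tic = time.perf_counter()  # time.clock() was removed in Python 3.8
    for i in range(dataset.n_image):
        data = dataset.iterate_batch()  # data: id, img, scene...
        sent = beam_searcher.generate(data[1:])
        # Map word indices back to vocabulary tokens.
        cap = ' '.join(dataset.vocab[word] for word in sent)
        print('[{}], id={}, \t\t {}'.format(i, data[0], cap))
        all_candidates.append({'image_id': data[0], 'caption': cap})
    toc = time.perf_counter() - tic
    # Per-image time; note the divisor is hard-coded to 5000 rather than dataset.n_image.
    running_time = toc / 5000.0

    # Write the candidates to a result file, defaulting to a temporary one.
    if res_file is None:
        res_file = 'tmp.json'
    with open(res_file, 'w') as f:
        json.dump(all_candidates, f)
    gt_file = osp.join(dataset.data_dir, 'captions_' + dataset.data_split + '.json')
    scores = evaluate(gt_file, res_file, logger)
    # Clean up the temporary result file.
    if res_file == 'tmp.json':
        os.remove(res_file)

    return scores, running_time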
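
The function above depends on a `dataset` and a `beam_searcher` whose implementations are not shown in this file. The following is a minimal sketch, under stated assumptions, of the interface the caption-generation loop appears to need: `dataset.n_image`, `dataset.vocab`, `dataset.iterate_batch()` and `beam_searcher.generate()`. `ToyDataset`, `ToySearcher` and `collect_captions` are hypothetical names invented for illustration and are not part of monogreedy. The scoring step (`evaluate`) is left out because its implementation is not shown here.

```python
import time


class ToyDataset(object):
    """Hypothetical stand-in for the dataset object passed to validate()."""

    def __init__(self):
        # vocab maps a word index to its token.
        self.vocab = ['a', 'dog', 'runs', 'cat', 'sleeps']
        # Each record mimics (id, img, scene); img and scene are placeholders.
        self._records = [(101, 'img0', 'scene0'), (102, 'img1', 'scene1')]
        self.n_image = len(self._records)
        self._pos = 0

    def iterate_batch(self):
        # Return one record per call, wrapping around at the end.
        record = self._records[self._pos % self.n_image]
        self._pos += 1
        return record


class ToySearcher(object):
    """Hypothetical searcher that returns a fixed word-index sequence."""

    def generate(self, inputs):
        img = inputs[0]
        return [0, 1, 2] if img == 'img0' else [0, 3, 4]


def collect_captions(beam_searcher, dataset):
    """Mirror the generation loop: build captions and a mean per-image time."""
    candidates = []
    tic = time.time()
    for _ in range(dataset.n_image):
        data = dataset.iterate_batch()
        sent = beam_searcher.generate(data[1:])
        cap = ' '.join(dataset.vocab[word] for word in sent)
        candidates.append({'image_id': data[0], 'caption': cap})
    # Divide by the number of images actually processed.
    mean_time = (time.time() - tic) / max(dataset.n_image, 1)
    return candidates, mean_time


candidates, mean_time = collect_captions(ToySearcher(), ToyDataset())
```

In this sketch the per-image time is divided by `dataset.n_image` instead of a fixed constant, so the average stays meaningful for splits of any size. Whether that matches the author's intent is an assumption.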