def validate(beam_searcher, dataset, logger=None, res_file=None):
    """Generate a caption for every image in ``dataset`` and score the result.

    Captions are produced one image at a time with ``beam_searcher``, written
    to a JSON results file and scored by ``evaluate`` against the ground-truth
    file ``captions_<data_split>.json`` located in ``dataset.data_dir``.

    Args:
        beam_searcher: object whose ``generate(inputs)`` returns a sequence
            of tokens that index ``dataset.vocab`` (presumably word ids —
            confirm against the searcher implementation).
        dataset: object exposing ``n_image``, ``iterate_batch()``, ``vocab``,
            ``data_dir`` and ``data_split``. Each ``iterate_batch()`` result
            is a sequence whose first element is used as the image id and
            whose remaining elements are passed to ``beam_searcher.generate``.
        logger: forwarded to ``evaluate``; a ``Logger(None)`` is created when
            omitted.
        res_file: path of the JSON results file. When omitted, ``tmp.json``
            is written in the current directory and removed after scoring
            (even if scoring fails). A caller-supplied path is never removed.

    Returns:
        ``(scores, running_time)`` where ``scores`` is whatever ``evaluate``
        returns and ``running_time`` is the average caption-generation time
        per image in seconds (0.0 when the dataset has no images).
    """
    if logger is None:
        logger = Logger(None)

    # time.clock() is platform dependent (CPU time on Unix, wall time on
    # Windows) and was removed in Python 3.8; use a wall-clock timer instead.
    timer = getattr(time, 'perf_counter', time.time)

    # Generate one caption per image.
    n_image = dataset.n_image
    all_candidates = []
    tic = timer()
    for i in range(n_image):
        data = dataset.iterate_batch()  # data: id, img, scene...
        sent = beam_searcher.generate(data[1:])
        cap = ' '.join([dataset.vocab[word] for word in sent])
        print('[{}], id={}, \t\t {}'.format(i, data[0], cap))
        all_candidates.append({'image_id': data[0], 'caption': cap})
    elapsed = timer() - tic
    # Average over the images actually captioned (previously hard-coded to
    # 5000, which was wrong for any other split size).
    running_time = elapsed / float(n_image) if n_image > 0 else 0.0

    # Only delete the results file when this function picked its name, so a
    # caller-supplied path is never removed.
    is_tmp = res_file is None
    if is_tmp:
        res_file = 'tmp.json'
    try:
        # Close (and therefore flush) the file before evaluate() reads it.
        with open(res_file, 'w') as f:
            json.dump(all_candidates, f)
        gt_file = osp.join(dataset.data_dir,
                           'captions_' + dataset.data_split + '.json')
        scores = evaluate(gt_file, res_file, logger)
    finally:
        # Clean up the temporary file even if evaluation raised.
        if is_tmp and osp.exists(res_file):
            os.remove(res_file)
    return scores, running_time
评论列表
文章目录