def test(epoch):
    """Run the network over the test loader and score the predictions.

    Puts ``Net`` in eval mode, feeds every batch yielded by
    ``testing_data_loader`` through the sub-network of ``Net.module`` selected
    by the module-level ``model`` name, and reduces each batch's output to a
    single index (argmax of the mean over axis 0). The indices are written one
    per line to a result file, after which ``compute_test_result.py`` is
    launched with that file's name.

    Relies on module-level names defined outside this function: ``Net``,
    ``model``, ``cuda``, ``training_size``, ``save_prefix``,
    ``testing_data_loader``, ``Variable`` and ``os``.

    Args:
        epoch: Value appended (via ``str``) to the result file name.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    supported_models = ('resnet101', 'inception_v3', 'inception_v4')
    if model not in supported_models:
        # An unknown name used to fall through the if/elif chain and only
        # surface later as an UnboundLocalError on the prediction variable.
        raise ValueError(
            'unsupported model {!r}; expected one of {}'.format(
                model, ', '.join(supported_models)))

    Net.eval()
    # The selected sub-network is the same for every batch; resolve it once.
    forward = getattr(Net.module, model)

    predictions = []
    for inputss, _labelss in testing_data_loader:
        inputss = Variable(inputss.view(-1, 3, training_size, training_size))
        if cuda:
            inputss = inputss.cuda()
        scores = forward(inputss).cpu().data.numpy()
        # NOTE(review): presumably each batch holds the frames/clips of one
        # video, so averaging over axis 0 yields a per-video score -- confirm
        # against the data loader.
        predictions.append(str(scores.mean(0).argmax()))

    # The same file name is used for the written file and for the scoring
    # script's argument; only the written file carries the directory prefix.
    result_name = ('{}result_{}_new_epoch'.format(save_prefix, model)
                   + str(epoch) + '.txt')
    result_path = ('/S2/MI/zqj/video_classification/data/ucf101/tmp_result/'
                   + result_name)
    with open(result_path, 'w') as fw:
        fw.write('\n'.join(predictions))

    os.system('python compute_test_result.py ' + result_name)
main.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录