def validate(val_loader, net, criterion, use_cuda=True):
    """Compute the loss of ``net`` over the entire validation set.

    Every batch from ``val_loader`` is run through ``net`` in eval mode. The
    outputs and labels of all batches are concatenated and ``criterion`` is
    applied once to the full set, so all outputs are held in memory (on the
    GPU when ``use_cuda`` is true) until the loss is computed. The network is
    switched back to train mode before returning, even if an error occurs.

    Args:
        val_loader: iterable yielding ``(inputs, labels)`` tensor pairs.
        net: the model; called as ``net(inputs)``.
        criterion: loss callable taking ``(outputs, labels)``. Labels are
            cast to float before being passed in.
        use_cuda: move inputs and labels to the GPU (the default, matching
            the historical behaviour). Pass ``False`` to validate on CPU.

    Returns:
        The validation loss as a Python number.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    net.eval()
    try:
        batch_outputs = []
        batch_labels = []
        for inputs, labels in val_loader:
            inputs = Variable(inputs, volatile=True)
            labels = Variable(labels.float(), volatile=True)
            if use_cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()
            # volatile has no effect on PyTorch >= 0.4; detach() keeps each
            # batch's autograd graph from being retained until the loss below.
            batch_outputs.append(net(inputs).detach())
            batch_labels.append(labels)
        if not batch_outputs:
            raise ValueError('val_loader yielded no batches')
        val_loss = criterion(torch.cat(batch_outputs), torch.cat(batch_labels))
        # Indexing a 0-dim loss raises IndexError on recent PyTorch; .item()
        # (PyTorch >= 0.4) is the supported way to get the Python scalar.
        if hasattr(val_loss, 'item'):
            val_loss = val_loss.item()
        else:
            val_loss = val_loss.data[0]
    finally:
        # Always hand the model back in training mode, even on failure.
        net.train()
    print('--------------------------------------------------------')
    print('[val_loss %.4f]' % val_loss)
    return val_loss
评论列表
文章目录