train.py source file

python

Project: planet-pytorch  Author: kefth
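import torch
from sklearn.metrics import fbeta_score

# `utils` (project helper module) and the boolean `cuda` flag are defined
# elsewhere in train.py and are not shown in this excerpt.


def validate(net, loader, criterion):
    """Return mean loss, mean accuracy and F2 score over the validation loader."""
    net.eval()
    running_loss = 0.0
    running_accuracy = 0.0
    targets = torch.empty(0, 17)  # accumulated labels (17 classes) for the F-score
    predictions = torch.empty(0, 17)  # accumulated network outputs
    # torch.no_grad() replaces Variable(X, volatile=True), removed in PyTorch 0.4
    with torch.no_grad():
        for X, y in loader:
            if cuda:
                X, y = X.cuda(), y.cuda()
            output = net(X)
            loss = criterion(output, y)
            acc = utils.get_multilabel_accuracy(output, y)
            targets = torch.cat((targets, y.cpu()), 0)
            predictions = torch.cat((predictions, output.cpu()), 0)
            running_loss += loss.item()  # loss.data[0] fails on 0-dim tensors
            running_accuracy += acc
    # Binarize outputs at the 0.23 threshold, then average the per-sample F2 scores
    fscore = fbeta_score(targets.numpy(), predictions.numpy() > 0.23,
                         beta=2, average='samples')
    return running_loss/len(loader), running_accuracy/len(loader), fscore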
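
The F-score step at the end of validate can be illustrated without any dependencies. The following is a minimal pure-Python sketch of a sample-averaged F-beta score applied after thresholding, not scikit-learn's implementation. The helper names `threshold` and `samples_fbeta` and the toy data are hypothetical, and scoring a row with no true and no predicted labels as 0 is an assumption of this sketch.

```python
def threshold(outputs, cutoff=0.23):
    """Turn per-class scores into 0/1 predictions (hypothetical helper)."""
    return [[1 if v > cutoff else 0 for v in row] for row in outputs]


def samples_fbeta(y_true, y_pred, beta=2.0):
    """Compute F-beta for each sample (row), then average over samples.

    Per row: F = (1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP).
    Rows with an empty denominator score 0.0 (assumption of this sketch).
    """
    b2 = beta * beta
    scores = []
    for true_row, pred_row in zip(y_true, y_pred):
        pairs = list(zip(true_row, pred_row))
        tp = sum(1 for t, p in pairs if t and p)
        fp = sum(1 for t, p in pairs if not t and p)
        fn = sum(1 for t, p in pairs if t and not p)
        denom = (1 + b2) * tp + b2 * fn + fp
        scores.append((1 + b2) * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Toy data: 2 samples, 3 labels (made up for illustration)
targets = [[1, 0, 1], [0, 1, 0]]
outputs = [[0.9, 0.3, 0.1], [0.2, 0.8, 0.1]]

preds = threshold(outputs)             # [[1, 1, 0], [0, 1, 0]]
score = samples_fbeta(targets, preds)  # row scores 0.5 and 1.0, mean 0.75
print(score)
```

With beta=2 a missed label (false negative) costs four times as much as a spurious one, which is consistent with using a cutoff as low as 0.23 rather than 0.5.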