find_best_threthold.py source file

python

Project: kaggle-planet  Author: ZijunDeng
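import os

import scipy.io as sio
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

# get_res152, MultipleClassImageFolder, predict, find_best_threthold,
# get_one_hot_prediction, evaluate, num_classes, ckpt_path, split_train_dir
# and split_val_dir come from elsewhere in the kaggle-planet project and are
# not shown in this excerpt.


def main():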
    training_batch_size = 352
    validation_batch_size = 352

    net = get_res152(num_classes=num_classes, snapshot_path=os.path.join(
        ckpt_path, 'epoch_15_validation_loss_0.0772_iter_1000.pth')).cuda()
    net.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.311, 0.340, 0.299], [0.167, 0.144, 0.138])
    ])
    criterion = nn.MultiLabelSoftMarginLoss().cuda()

    train_set = MultipleClassImageFolder(split_train_dir, transform)
    train_loader = DataLoader(train_set, batch_size=training_batch_size, num_workers=16)
    batch_outputs, batch_labels = predict(net, train_loader)
    loss = criterion(batch_outputs, batch_labels)
    print('training loss %.4f' % loss.cpu().data.numpy()[0])
    batch_outputs = batch_outputs.cpu().data.numpy()
    batch_labels = batch_labels.cpu().data.numpy()
    # Tune the decision thresholds on the training-set outputs.
    thretholds = find_best_threthold(batch_outputs, batch_labels)

    val_set = MultipleClassImageFolder(split_val_dir, transform)
    val_loader = DataLoader(val_set, batch_size=validation_batch_size, num_workers=16)
    batch_outputs, batch_labels = predict(net, val_loader)
    loss = criterion(batch_outputs, batch_labels)
    print('validation loss %.4f' % loss.cpu().data.numpy()[0])
    batch_outputs = batch_outputs.cpu().data.numpy()
    batch_labels = batch_labels.cpu().data.numpy()
    sio.savemat('./val_output.mat', {'outputs': batch_outputs, 'labels': batch_labels})
    # Apply the thresholds chosen on the training set to the validation outputs.
    prediction = get_one_hot_prediction(batch_outputs, thretholds)
    evaluation = evaluate(prediction, batch_labels)
    print('validation evaluation: accuracy %.4f, precision %.4f, recall %.4f, f2 %.4f' % (
        evaluation[0], evaluation[1], evaluation[2], evaluation[3]))
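
The excerpt calls find_best_threthold, get_one_hot_prediction and evaluate without showing them. The following is only a rough sketch of how such a threshold search might work, not the project's actual implementation. It assumes the network outputs are raw per-class scores of shape (N, C), that labels are 0/1 arrays of the same shape, and that F2 is averaged per sample. The names f2_score, apply_thresholds and search_thresholds are hypothetical and chosen for illustration.

```python
import numpy as np


def f2_score(pred, labels, eps=1e-12):
    """Sample-averaged F-beta score with beta = 2 for 0/1 arrays of shape (N, C)."""
    tp = (pred * labels).sum(axis=1)
    precision = tp / (pred.sum(axis=1) + eps)
    recall = tp / (labels.sum(axis=1) + eps)
    beta2 = 4.0  # beta squared; beta = 2 weights recall more than precision
    f2 = (1 + beta2) * precision * recall / (beta2 * precision + recall + eps)
    return float(f2.mean())


def apply_thresholds(outputs, thresholds):
    """Turn raw scores into 0/1 predictions, using one threshold per class."""
    return (outputs > thresholds).astype(np.float64)


def search_thresholds(outputs, labels, candidates):
    """Tune one class threshold at a time while the others stay fixed."""
    start = candidates[len(candidates) // 2]
    thresholds = np.full(outputs.shape[1], start, dtype=np.float64)
    for c in range(outputs.shape[1]):
        best_t, best_score = thresholds[c], -1.0
        for t in candidates:
            thresholds[c] = t
            score = f2_score(apply_thresholds(outputs, thresholds), labels)
            if score > best_score:  # keep the first candidate that is strictly better
                best_t, best_score = t, score
        thresholds[c] = best_t
    return thresholds
```

Under these assumptions, the thresholds would be searched on the training outputs and then reused unchanged on the validation outputs, which mirrors the order of the calls in main(). A single pass over the classes is a greedy choice. Repeating the loop until the thresholds stop changing is a common refinement.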