predict.py source code

python

Project: KagglePlanetPytorch    Author: Mctigger
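import os
from pydoc import locate

import sklearn.model_selection
import torch

# Project-local modules of KagglePlanetPytorch (import names assumed from usage below).
import labels
import paths
import util


def predict_kfold(model_name, pre_transforms=None):
    if pre_transforms is None:  # avoid a shared mutable default argument
        pre_transforms = []

    # Resolve the model factory and its seed from the model module by dotted name.
    model = locate(model_name + '.generate_model')()
    random_state = locate(model_name + '.random_state')
    print('Random state: {}'.format(random_state))

    # Same seed and shuffle=True reproduce the same 5 folds on every call.
    labels_df = labels.get_labels_df()
    kf = sklearn.model_selection.KFold(n_splits=5, shuffle=True, random_state=random_state)
    split = kf.split(labels_df)

    for i, (train_idx, val_idx) in enumerate(split):
        split_name = model_name + '-split_' + str(i)
        # Look up which saved epoch of this split to load (project helper).
        best_epoch = util.find_epoch_val(split_name)
        print('Using epoch {} for predictions'.format(best_epoch))
        epoch_name = split_name + '-epoch_' + str(best_epoch)
        # KFold yields positional indices; DataFrame.ix was removed in pandas 1.0.
        train = labels_df.iloc[train_idx]
        val = labels_df.iloc[val_idx]
        state = torch.load(os.path.join(paths.models, split_name, epoch_name))
        # predict_model is assumed to be defined elsewhere in predict.py.
        predict_model(model, state, train, val, output_file=split_name, pre_transforms=pre_transforms)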
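predict_kfold relies on one property: sklearn's KFold with shuffle=True and a fixed random_state yields the same splits every time it is called on data of the same length, so prediction can rebuild each split without storing the indices. The minimal sketch below checks that property and selects rows positionally with .iloc. It assumes a small hypothetical DataFrame in place of labels.get_labels_df() and a hypothetical model name "mymodel"; it does not load any checkpoints.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold

# Hypothetical stand-in for labels.get_labels_df(): 10 images, one tag column.
labels_df = pd.DataFrame({
    "image_name": ["train_{}".format(i) for i in range(10)],
    "tags": ["clear primary"] * 10,
})


def fold_indices(df, random_state, n_splits=5):
    """Return the (train_idx, val_idx) pairs KFold yields for df."""
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    return [(train_idx, val_idx) for train_idx, val_idx in kf.split(df)]


# Calling KFold twice with the same seed gives identical folds.
first = fold_indices(labels_df, random_state=42)
second = fold_indices(labels_df, random_state=42)
assert all(np.array_equal(a[1], b[1]) for a, b in zip(first, second))

for i, (train_idx, val_idx) in enumerate(first):
    # KFold yields positions, so .iloc selects the matching rows.
    train = labels_df.iloc[train_idx]
    val = labels_df.iloc[val_idx]
    split_name = "mymodel-split_{}".format(i)  # hypothetical model name
    print(split_name, len(train), len(val))
```

With 10 rows and 5 splits, each fold holds out 2 rows and trains on the other 8, and every row lands in exactly one validation fold.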