def predict_kfold(model_name, pre_transforms=None):
    """Run predictions for every cross-validation fold of a trained model.

    The model factory and the shuffle seed are resolved dynamically from the
    dotted path ``model_name``: ``<model_name>.generate_model`` is called to
    build the model, and ``<model_name>.random_state`` is used as the KFold
    seed. The labels dataframe is split with a shuffled 5-fold ``KFold``.
    NOTE(review): this only reproduces the training folds if training used
    the same fold count and seed — confirm against the training script.

    For each fold ``i``:

    - the epoch is chosen by ``util.find_epoch_val``;
    - the checkpoint
      ``<paths.models>/<model_name>-split_<i>/<model_name>-split_<i>-epoch_<e>``
      is loaded with ``torch.load``;
    - ``predict_model`` is called with the fold's train/validation rows and
      ``output_file`` set to the split name.

    Args:
        model_name: Dotted path prefix of the model definition. It is also
            the prefix of the per-split checkpoint directories and outputs.
        pre_transforms: Optional sequence of transforms forwarded unchanged
            to ``predict_model``. ``None`` (the default) means an empty list.
    """
    # Create a fresh list per call. A mutable default would be shared across
    # calls, so any mutation downstream would leak into later invocations.
    if pre_transforms is None:
        pre_transforms = []

    model = locate(model_name + '.generate_model')()
    random_state = locate(model_name + '.random_state')
    print('Random state: {}'.format(random_state))

    labels_df = labels.get_labels_df()
    kf = sklearn.model_selection.KFold(n_splits=5, shuffle=True, random_state=random_state)

    for i, (train_idx, val_idx) in enumerate(kf.split(labels_df)):
        split_name = model_name + '-split_' + str(i)
        # NOTE(review): presumably the epoch with the best validation metric;
        # confirm in util.find_epoch_val.
        best_epoch = util.find_epoch_val(split_name)
        print('Using epoch {} for predictions'.format(best_epoch))
        epoch_name = split_name + '-epoch_' + str(best_epoch)

        # KFold.split yields *positional* indices, so select rows by position.
        # The removed ``.ix`` accessor is label-based on integer indexes and
        # would pick the wrong rows for a non-default integer index.
        train = labels_df.iloc[train_idx]
        val = labels_df.iloc[val_idx]

        state = torch.load(os.path.join(paths.models, split_name, epoch_name))
        predict_model(model, state, train, val, output_file=split_name, pre_transforms=pre_transforms)
评论列表
文章目录