def get_train_test_fold_filenames(true_iob_dir, use_pickle=True):
    """Return the train/test filename split for each of 5 grouped CV folds.

    The result is cached in the pickle file ``_train_test_fold_fnames.pkl``
    in the current working directory. The cache file is (re)written on every
    call that has to compute the folds, regardless of ``use_pickle``.

    Args:
        true_iob_dir: Directory that is passed to ``collect_crf_data`` both
            as the IOB directory and as the (pretend) feature directory.
        use_pickle: If true, try to return the cached result from the pickle
            file first; if the file cannot be opened or read, the folds are
            computed from scratch.

    Returns:
        A list with one ``(train_fnames, test_fnames)`` tuple per fold, where
        each element is an array of unique filenames as returned by
        ``np.unique``. When loaded from the cache, whatever object the pickle
        file contains is returned as-is.
    """
    pickle_fname = '_train_test_fold_fnames.pkl'

    if use_pickle:
        try:
            # Context manager guarantees the handle is closed even if
            # unpickling fails. Only load cache files this function wrote.
            with open(pickle_fname, 'rb') as pickle_file:
                return pickle.load(pickle_file)
        except IOError:
            # No readable cache: fall through and compute the folds.
            pass

    # Misuse data collecting function to get X, y and filenames.
    # Since we are not interested in the actual features, we pretend
    # true_iob_dir is a feature dir.
    data = collect_crf_data(true_iob_dir, true_iob_dir)

    # Now create the folds.
    # Create folds from complete texts only (i.e. instances of the same text
    # are never in different folds), by using the filenames as groups.
    # Use same split for all three entities.
    # Note that there is no random seed, because the output of
    # group_k_fold.split is deterministic as long as the iob files are
    # globbed in exactly the same order.
    group_k_fold = GroupKFold(n_splits=5)
    splits = group_k_fold.split(data['feats'], data['Material'],
                                data['filenames'])

    # Array form is needed so the index arrays from split() can be used for
    # fancy indexing.
    fnames = np.array(data['filenames'])
    train_test_fold_fnames = [
        (np.unique(fnames[train_idx]), np.unique(fnames[test_idx]))
        for train_idx, test_idx in splits
    ]

    # Closing the file explicitly ensures the cache is flushed to disk
    # before anyone tries to read it back.
    with open(pickle_fname, 'wb') as pickle_file:
        pickle.dump(train_test_fold_fnames, pickle_file)
    return train_test_fold_fnames
评论列表
文章目录