def saga_decision_function(dataset, k, link_alpha, prop_alpha, l1_ratio):
    """Fit the baseline on fold ``k`` and score that fold's validation docs.

    The result is cached on disk under the name returned by
    ``cache_fname("linear_val_df", ...)``, keyed on all five arguments.
    When the cache file already exists it is loaded and returned without
    refitting.

    Parameters
    ----------
    dataset : str
        Dataset name. ``'cdcp'`` selects the ``erule`` feature directory;
        any other value selects ``ukp-essays``. ``'ukp'`` uses 5 folds,
        everything else 3.
    k : int
        Zero-based index of the cross-validation fold to use.
    link_alpha, prop_alpha, l1_ratio
        Passed through unchanged to ``BaselineStruct``.
        NOTE(review): presumably regularization settings for the link and
        proposition models -- confirm against ``BaselineStruct``.

    Returns
    -------
    tuple
        ``(Y_marg, baseline)``: the output of
        ``baseline.decision_function`` on the validation fold, and the
        fitted ``BaselineStruct`` instance.

    Raises
    ------
    ValueError
        If ``k`` does not identify one of the folds (only checked when the
        result is not already cached).
    """
    fn = cache_fname("linear_val_df", (dataset, k, link_alpha, prop_alpha,
                                       l1_ratio))

    if os.path.exists(fn):
        logging.info("Loading {}".format(fn))
        # NOTE: dill, like pickle, can execute arbitrary code while loading;
        # this is only safe as long as the cache directory is trusted.
        with open(fn, "rb") as f:
            return dill.load(f)

    # The on-disk feature directories use different names than the dataset
    # identifiers accepted here.
    # NOTE(review): a dataset other than 'cdcp' or 'ukp' ends up with the
    # 'ukp-essays' directory but 3 folds -- confirm this cannot happen.
    ds = 'erule' if dataset == 'cdcp' else 'ukp-essays'
    path = os.path.join("data", "process", ds, "folds", "{}", "{}")

    # Recover the validation documents of fold k by replaying the same
    # KFold split over the training ids.
    n_folds = 5 if dataset == 'ukp' else 3
    load, ids = get_dataset_loader(dataset, "train")
    for fold_idx, (_, val) in enumerate(KFold(n_folds).split(ids)):
        if fold_idx == k:
            break
    else:
        # Without this guard an invalid k would silently fall through and
        # reuse the indices of the last fold.
        raise ValueError(
            "fold index k={} is out of range for {} folds".format(k, n_folds))
    val_docs = list(load(ids[val]))

    X_tr_link, y_tr_link = load_csr(path.format(k, 'train.npz'),
                                    return_y=True)
    # The validation labels are loaded alongside the features but are not
    # needed to compute the decision function.
    X_te_link, _ = load_csr(path.format(k, 'val.npz'),
                            return_y=True)

    X_tr_prop, y_tr_prop = load_csr(path.format(k, 'prop-train.npz'),
                                    return_y=True)
    X_te_prop, _ = load_csr(path.format(k, 'prop-val.npz'),
                            return_y=True)

    baseline = BaselineStruct(link_alpha, prop_alpha, l1_ratio)
    baseline.fit(X_tr_link, y_tr_link, X_tr_prop, y_tr_prop)

    Y_marg = baseline.decision_function(X_te_link, X_te_prop, val_docs)

    with open(fn, "wb") as f:
        logging.info("Saving {}".format(fn))
        dill.dump((Y_marg, baseline), f)

    return Y_marg, baseline
# (Web-page residue, not part of the code: "comment list" / "table of contents".)