multiuser.py source code

Project: musm-adt17, author: stefanoteso
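import numpy as np
from sklearn.model_selection import KFold

# _NUM_FOLDS, _ALPHAS, _I_TO_ALPHA, _LOG and compute_transform are
# module-level names defined elsewhere in multiuser.py.


def crossvalidate(problem, dataset, set_size, uid, w, var, cov,
                  transform, old_alpha, lmbda=0.5):
    """Finds the best hyperparameter alpha using k-fold cross-validation.

    Parameters
    ----------
    problem : object
        Provides ``select_query(dataset, set_size, alpha, transform=...)``,
        which is fit on the training folds.
    dataset : numpy.ndarray
        Examples, one per row; rows are indexed to build the folds.
    set_size : int
        Forwarded to ``problem.select_query``.
    uid, w, var, cov, transform, lmbda
        Forwarded to ``compute_transform`` to build the transform used
        while fitting.
    old_alpha : tuple
        Previous alpha, returned unchanged when ``len(dataset)`` is not a
        multiple of ``_NUM_FOLDS``.

    Returns
    -------
    alpha : tuple
        The alpha with the highest mean held-out accuracy, or
        ``old_alpha`` if cross-validation was skipped.
    """

    # Only cross-validate when the data splits into equally sized folds.
    if len(dataset) % _NUM_FOLDS != 0:
        return old_alpha

    # scikit-learn >= 0.18 API: the number of samples is taken from split().
    kfold = KFold(n_splits=_NUM_FOLDS)
    f = compute_transform(uid, w, var, cov, transform, lmbda=lmbda)

    avg_accuracy = np.zeros(len(_ALPHAS))
    for i, alpha in enumerate(_ALPHAS):
        accuracies = []
        for tr_indices, ts_indices in kfold.split(dataset):
            # Fit on the training folds; a new name avoids shadowing the
            # w argument.
            w_hat, _ = problem.select_query(dataset[tr_indices], set_size,
                                            alpha, transform=f)
            # A held-out example counts as correct when its utility is > 0.
            utilities = np.dot(w_hat, dataset[ts_indices].T)
            accuracies.append((utilities > 0).mean())
        avg_accuracy[i] = np.mean(accuracies)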

    alpha = _I_TO_ALPHA[np.argmax(avg_accuracy)]

    # Format eagerly: a standard logging.Logger rejects arbitrary keyword
    # arguments such as **locals().
    _LOG.debug('alpha accuracies = {}, best alpha = {}'.format(
        avg_accuracy, alpha))

    return alpha
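The function is a standard k-fold grid search. For each candidate alpha it fits on k-1 folds, scores on the held-out fold, averages the scores, and keeps the best alpha. Below is a minimal, self-contained NumPy sketch of that loop. It assumes, without confirmation from the source, that each dataset row is a pairwise difference vector and that the learner returns a single weight vector. `kfold_indices`, `pairwise_accuracy`, `select_alpha` and the toy `fit` are hypothetical names that stand in for `KFold`, the accuracy line, and `problem.select_query`.

```python
import numpy as np


def kfold_indices(n, n_folds):
    """Yield (train_idx, test_idx) for contiguous, unshuffled folds.

    Assumes n is divisible by n_folds, as crossvalidate() requires.
    """
    fold_size = n // n_folds
    idx = np.arange(n)
    for k in range(n_folds):
        lo, hi = k * fold_size, (k + 1) * fold_size
        yield np.concatenate([idx[:lo], idx[hi:]]), idx[lo:hi]


def pairwise_accuracy(w, deltas):
    """Fraction of difference vectors with positive utility w . delta."""
    return float((deltas @ w > 0).mean())


def select_alpha(deltas, alphas, fit, n_folds=5, default=None):
    """Grid search over alphas by k-fold cross-validation.

    fit(train_deltas, alpha) -> w is a hypothetical stand-in for
    problem.select_query(); here it must return a single weight vector.
    """
    if len(deltas) % n_folds != 0:
        return default  # mirrors the early return in crossvalidate()
    avg_accuracy = np.zeros(len(alphas))
    for i, alpha in enumerate(alphas):
        scores = [pairwise_accuracy(fit(deltas[tr], alpha), deltas[ts])
                  for tr, ts in kfold_indices(len(deltas), n_folds)]
        avg_accuracy[i] = np.mean(scores)
    # np.argmax breaks ties in favour of the first alpha.
    return alphas[int(np.argmax(avg_accuracy))]


def fit(train, alpha):
    # Hypothetical toy learner: mean training delta plus an alpha-scaled
    # bias along the second axis (which is irrelevant to the toy labels).
    return train.mean(axis=0) + alpha * np.array([0.0, 1.0])


# Toy data: every row has a positive first component, so w = [1, 0] ranks
# all of them correctly; rows come in symmetric pairs, one pair per fold.
deltas = np.array([[1, .5], [1, -.5], [2, 1], [2, -1], [1, .2],
                   [1, -.2], [3, 1], [3, -1], [1, .9], [1, -.9]])
best = select_alpha(deltas, [0.0, 10.0], fit, n_folds=5)
```

Here `best` is `0.0`, because a large alpha tilts `w` enough to misrank some held-out rows. The sketch splits folds by hand to avoid a scikit-learn dependency. `sklearn.model_selection.KFold(n_splits=...).split(X)` produces the same contiguous, unshuffled folds by default.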
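The accuracy line `(np.dot(w, X.T) > 0).mean()` also works when `w` is a matrix. Suppose `select_query` returns several weight vectors stacked as rows (an assumption, since the snippet does not show its return type). Then the product holds one utility per (weight vector, test example) pair, and `.mean()` averages correctness over all of those pairs. A small worked example with made-up numbers:

```python
import numpy as np

# Hypothetical shapes: k = 2 weight vectors and n_test = 3 difference
# vectors, all in d = 2 dimensions.
w = np.array([[1.0, 0.0],
              [0.0, 1.0]])             # shape (k, d)
test_deltas = np.array([[1.0, 1.0],
                        [2.0, -1.0],
                        [-1.0, 3.0]])   # shape (n_test, d)

utilities = np.dot(w, test_deltas.T)   # shape (k, n_test)
# utilities == [[1, 2, -1], [1, -1, 3]]
accuracy = (utilities > 0).mean()      # 4 positive entries out of 6
# accuracy == 4/6, about 0.667
```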