forward.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:decoding_challenge_cortana_2016_3rd 作者: kingjr 项目源码 文件源码
def compute_depth_prior(G, gain_info, is_fixed_ori, exp=0.8, limit=10.0,
                        patch_areas=None, limit_depth_chs=False):
    """Compute weighting for depth prior

    Parameters
    ----------
    G : array, 2-D
        Gain matrix with one column per source component (rows are
        presumably channels -- confirm against the caller). For free
        orientation the columns are read as consecutive triplets, one
        triplet per source location.
    gain_info : object
        Forwarded untouched to ``_restrict_gain_matrix`` together with
        ``G``. Only used when ``limit_depth_chs`` is true.
    is_fixed_ori : bool
        If true, every column of ``G`` is one source and its power is the
        sum of squares of the column. Otherwise the power of a location is
        the largest singular value of ``Gk.T @ Gk`` for its three columns.
    exp : float
        Exponent applied to the clipped, normalized weights.
    limit : float
        Its square bounds the ratio between a weight and the smallest
        weight before the weights are clipped.
    patch_areas : array | None
        If not None, the per-location power is divided by
        ``patch_areas ** 2``.
    limit_depth_chs : bool
        If true, restrict ``G`` with ``_restrict_gain_matrix`` and derive
        the limit the way the C code does. If false, keep ``G`` as is and
        use the old mne-python limit. Any truthy / falsy value (e.g. a
        ``numpy.bool_``) is accepted and treated like the matching bool.

    Returns
    -------
    depth_prior : array
        One weight in ``(0, 1] ** exp`` per source for fixed orientation;
        for free orientation each location weight is repeated 3 times.
    """
    logger.info('Creating the depth weighting matrix...')

    # Normalize the flag once. The identity tests ``is True`` / ``is False``
    # used previously let values such as ``np.bool_(False)`` or ``1`` skip
    # the restriction below and still take the "C code" branch further down.
    limit_depth_chs = bool(limit_depth_chs)

    # If possible, pick best depth-weighting channels
    if limit_depth_chs:
        G = _restrict_gain_matrix(G, gain_info)

    # Compute the power of the gain matrix for each source location
    if is_fixed_ori:
        d = np.sum(G ** 2, axis=0)
    else:
        n_pos = G.shape[1] // 3
        d = np.zeros(n_pos)
        for k in range(n_pos):
            Gk = G[:, 3 * k:3 * (k + 1)]
            d[k] = linalg.svdvals(np.dot(Gk.T, Gk))[0]

    # XXX Currently the fwd solns never have "patch_areas" defined
    if patch_areas is not None:
        # Out-of-place division: ``d`` may have an integer dtype (integer
        # gain matrix), for which an in-place true division would raise.
        d = d / patch_areas ** 2
        logger.info('    Patch areas taken into account in the depth '
                    'weighting')

    w = 1.0 / d
    ws = np.sort(w)
    weight_limit = limit ** 2
    if not limit_depth_chs:
        # match old mne-python behavior
        # ``ws`` is sorted ascending, so its minimum sits at index 0
        n_limit = 0
        limit = ws[0] * weight_limit
    else:
        # match C code behavior
        limit = ws[-1]
        n_limit = len(d)
        if ws[-1] > weight_limit * ws[0]:
            # first sorted weight exceeding the allowed ratio becomes the cap
            ind = np.where(ws > weight_limit * ws[0])[0][0]
            limit = ws[ind]
            n_limit = ind

    logger.info('    limit = %d/%d = %f',
                n_limit + 1, len(d), np.sqrt(limit / ws[0]))
    scale = 1.0 / limit
    logger.info('    scale = %g exp = %g', scale, exp)

    # Clip the normalized weights at 1, then apply the depth exponent
    wpp = np.minimum(w / limit, 1) ** exp

    depth_prior = wpp if is_fixed_ori else np.repeat(wpp, 3)

    return depth_prior
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号