replay.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:nelpy 作者: nelpy 项目源码 文件源码
def score_hmm_events(bst, k_folds=None, num_states=30, n_shuffles=5000, shuffle='row-wise', verbose=False):
    """Score every event in ``bst`` with cross-validated HMMs against a shuffle null.

    The events (epochs) of ``bst`` are split into ``k_folds`` folds by
    ``k_fold_cross_validation``. For each fold:

    1. A PoissonHMM is trained on the training events.
    2. Its states are reordered by transition-matrix order.
    3. The held-out events are scored (log likelihood).
    4. The held-out events are re-scored ``n_shuffles`` times, each time with
       a copy of the model whose transition matrix is a fresh shuffle of the
       fitted one. This gives a per-event null distribution.

    Parameters
    ----------
    bst : object
        Event container exposing ``n_epochs`` and supporting indexing by a
        list of epoch indices (presumably a nelpy BinnedSpikeTrainArray --
        confirm against callers).
    k_folds : int, optional
        Number of cross-validation folds. Defaults to 5 when None.
    num_states : int, optional
        Number of hidden states of each PoissonHMM. Default 30.
    n_shuffles : int, optional
        Number of shuffled transition matrices scored per fold. Default 5000.
    shuffle : {'row-wise', 'col-wise'}, optional
        'row-wise' shuffles with ``shuffle_transmat``. 'col-wise' shuffles
        with ``shuffle_transmat_Kourosh_breaks_stochasticity`` and then
        rescales every row to sum to one. Default 'row-wise'.
    verbose : bool, optional
        If True, print the index of each fold as it is processed.

    Returns
    -------
    scores_hmm : numpy.ndarray, shape (n_epochs,)
        Score of each event under the HMM fitted without that event's fold.
    scores_hmm_shuffled : numpy.ndarray, shape (n_epochs, n_shuffles)
        Scores of each event under the shuffled-transition-matrix HMMs.
    scores_hmm_percentile : numpy.ndarray, shape (n_epochs,)
        Percentile (``scipy.stats.percentileofscore`` with ``kind='mean'``)
        of each event's score within its own row of shuffled scores.

    Raises
    ------
    ValueError
        If ``shuffle`` is neither 'row-wise' nor 'col-wise'.
    """
    if k_folds is None:
        k_folds = 5

    if shuffle == 'row-wise':
        rowwise = True
    elif shuffle == 'col-wise':
        rowwise = False
    else:
        raise ValueError(
            "shuffle must be either 'row-wise' or 'col-wise', got {!r}".format(shuffle))

    X = list(range(bst.n_epochs))

    scores_hmm = np.zeros(bst.n_epochs)
    scores_hmm_shuffled = np.zeros((bst.n_epochs, n_shuffles))

    for kk, (training, validation) in enumerate(k_fold_cross_validation(X, k=k_folds)):
        if verbose:
            print('  fold {}/{}'.format(kk+1, k_folds))

        PBEs_train = bst[training]
        PBEs_test = bst[validation]

        # train HMM on all training PBEs
        hmm = PoissonHMM(n_components=num_states, random_state=0, verbose=False)
        hmm.fit(PBEs_train)

        # reorder states according to transmat ordering
        transmat_order = hmm.get_state_order('transmat')
        hmm.reorder_states(transmat_order)

        # compute scores_hmm (log likelihoods) of validation set:
        scores_hmm[validation] = hmm.score(PBEs_test)

        # Snapshot the fitted (reordered) transition matrix. Every shuffle
        # below starts from this snapshot, not from the previous iteration's
        # output. Otherwise the shuffled models form a cumulative chain, and
        # in the col-wise branch the repeated renormalization makes the
        # entries drift away from the fitted values.
        fitted_transmat = np.array(hmm.transmat)
        hmm_shuffled = copy.deepcopy(hmm)
        for nn in range(n_shuffles):
            # Hand over a fresh copy so an in-place shuffle helper can never
            # corrupt the snapshot.
            candidate = fitted_transmat.copy()
            if rowwise:
                hmm_shuffled.transmat_ = shuffle_transmat(candidate)
            else:
                shuffled = shuffle_transmat_Kourosh_breaks_stochasticity(candidate)
                # This shuffle may break row-stochasticity, so rescale each row
                # to sum to one (broadcast divide by the per-row sums).
                hmm_shuffled.transmat_ = shuffled / shuffled.sum(axis=1, keepdims=True)

            # score validation set with shuffled HMM
            scores_hmm_shuffled[validation, nn] = hmm_shuffled.score(PBEs_test)

    # Rank each observed score within its own shuffle distribution.
    scores_hmm_percentile = np.array([
        stats.percentileofscore(shuffled_scores, observed, kind='mean')
        for observed, shuffled_scores in zip(scores_hmm, scores_hmm_shuffled)
    ])

    return scores_hmm, scores_hmm_shuffled, scores_hmm_percentile
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号