keras_utils.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:AutoSleepScorerDev 作者: skjerns 项目源码 文件源码
def _predict_classes(model, x, batch_size=1024):
    """Return predicted class indices of a Keras ``model`` for inputs ``x``.

    Uses ``model.predict_classes`` when the installed Keras still provides it
    (old ``Sequential`` models). Otherwise it reproduces that method's behaviour
    from ``model.predict``: argmax over the last axis for multi-unit outputs,
    and a 0.5 threshold for a single-unit output.
    """
    if hasattr(model, 'predict_classes'):
        return model.predict_classes(x, batch_size, verbose=0)
    proba = model.predict(x, batch_size=batch_size, verbose=0)
    if proba.shape[-1] > 1:
        return proba.argmax(axis=-1)
    return (proba > 0.5).astype('int32')


def test_data_cnn_rnn(data, target, groups, cnn, rnn, layername='fc1', cropsize=2800, verbose=1, only_lstm=False):
    """
    Evaluate two already-trained models: a CNN, and an RNN on CNN features.

    The CNN is scored directly on ``data``. Activations of its layer
    ``layername`` are grouped into sequences by ``tools.to_sequences``, and
    the RNN is scored on those sequences.

    :param data:      input array whose axis 1 is center-cropped to ``cropsize``
                      (presumably (n_samples, n_timesteps, n_channels) —
                      TODO confirm against callers)
    :param target:    labels, either class indices (1D) or one-hot (2D)
    :param groups:    group ids forwarded to ``tools.to_sequences``
    :param cnn:       trained Keras model used for predictions and features
    :param rnn:       trained Keras model applied to sequences of CNN features;
                      its ``input_shape[1]`` sets the sequence length
    :param layername: name of the CNN layer whose activations feed the RNN
    :param cropsize:  length to center-crop axis 1 to; 0 disables cropping
    :param verbose:   verbosity forwarded to ``get_activations``
    :param only_lstm: if true, skip CNN prediction. The CNN metrics are then
                      computed against the targets themselves, so they are
                      trivially perfect.
    :returns: ``[cnn_acc, cnn_f1, rnn_acc, rnn_f1, confmat,
              (rnn_pred, target_seq, groups_seq)]``. F1 scores are
              macro-averaged, and ``confmat`` is the RNN confusion matrix.
    :raises ValueError: if ``cropsize`` exceeds ``data.shape[1]``.
    """
    if target.ndim == 2:
        # One-hot -> class indices.
        target = np.argmax(target, 1)

    if cropsize != 0:
        length = data.shape[1]
        if cropsize > length:
            raise ValueError('cropsize ({}) is larger than data.shape[1] ({})'
                             .format(cropsize, length))
        # Explicit end index: the former `diff:-diff` slice was empty when
        # diff == 0 and one sample too long when (length - cropsize) was odd.
        start = (length - cropsize) // 2
        data = data[:, start:start + cropsize, :]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if only_lstm:
            cnn_pred = target
        else:
            cnn_pred = _predict_classes(cnn, data)
        features = get_activations(cnn, data, layername, verbose=verbose)

        cnn_acc = accuracy_score(target, cnn_pred)
        cnn_f1 = f1_score(target, cnn_pred, average='macro')

        seqlen = rnn.input_shape[1]
        features_seq, target_seq, groups_seq = tools.to_sequences(
            features, target, seqlen=seqlen, groups=groups)
        # NOTE(review): the RNN is scored against targets rolled by 4
        # positions, but the unrolled target_seq is what gets returned.
        # Neither the meaning of 4 nor this asymmetry is explained here —
        # confirm against tools.to_sequences and the RNN's training targets.
        new_targ_seq = np.roll(target_seq, 4)
        rnn_pred = _predict_classes(rnn, features_seq)
        rnn_acc = accuracy_score(new_targ_seq, rnn_pred)
        rnn_f1 = f1_score(new_targ_seq, rnn_pred, average='macro')
        confmat = confusion_matrix(new_targ_seq, rnn_pred)

    return [cnn_acc, cnn_f1, rnn_acc, rnn_f1, confmat, (rnn_pred, target_seq, groups_seq)]




#%%
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号