retain.py source code

python

Project: retain  Author: mp2893
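import numpy as np
from sklearn.metrics import roc_auc_score


def calculate_auc(test_model, dataset, options):
    """Return the ROC AUC of test_model on dataset, scoring it in mini-batches.

    dataset[0] holds the input sequences, dataset[1] the binary labels and,
    when options['useTime'] is set, dataset[2] the matching time information.
    """
    batchSize = options['batchSize']
    useTime = options['useTime']

    # Number of mini-batches, rounded up so a final partial batch is included.
    n_batches = int(np.ceil(float(len(dataset[0])) / float(batchSize)))
    scoreVec = []
    for index in range(n_batches):  # range instead of xrange: works on Python 2 and 3
        batchX = dataset[0][index*batchSize:(index+1)*batchSize]
        # padMatrixWithTime / padMatrixWithoutTime are retain.py padding helpers (not shown).
        if useTime:
            batchT = dataset[2][index*batchSize:(index+1)*batchSize]
            x, t, lengths = padMatrixWithTime(batchX, batchT, options)
            scores = test_model(x, t, lengths)
        else:
            x, lengths = padMatrixWithoutTime(batchX, options)
            scores = test_model(x, lengths)
        scoreVec.extend(list(scores))
    labels = dataset[1]
    # AUC is computed once over all collected scores, not averaged per batch.
    auc = roc_auc_score(list(labels), list(scoreVec))
    return auc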
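The loop above scores the data in mini-batches and then computes one AUC over every score. Below is a minimal, dependency-free sketch of that same pattern, assuming a `model` callable that maps a batch to one score per sample. It stands in for `test_model` together with the padding helpers, which are omitted. The names `batched_scores`, `binary_auc` and `toy_model` are hypothetical illustrations and do not exist in retain.py. `binary_auc` uses the rank (Mann-Whitney) formulation of ROC AUC, in which ties count as half. For 0/1 labels this should agree with sklearn's `roc_auc_score`.

```python
import math
from typing import Callable, List, Sequence


def batched_scores(model: Callable[[Sequence], List[float]],
                   inputs: Sequence, batch_size: int) -> List[float]:
    # Same batch count as calculate_auc: ceil(len / batch_size).
    n_batches = math.ceil(len(inputs) / batch_size)
    scores: List[float] = []
    for index in range(n_batches):
        batch = inputs[index * batch_size:(index + 1) * batch_size]
        scores.extend(model(batch))  # one score per sample in the batch
    return scores


def binary_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    # ROC AUC = probability that a random positive outscores a random negative,
    # with ties counted as 0.5 (Mann-Whitney U / (n_pos * n_neg)).
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical stand-in for test_model: score each sequence by its length.
def toy_model(batch):
    return [float(len(seq)) for seq in batch]


inputs = [[1], [1, 2, 3], [4, 5], [6, 7, 8, 9], [2]]
labels = [0, 1, 0, 1, 0]
scores = batched_scores(toy_model, inputs, batch_size=2)
print(scores)                      # [1.0, 3.0, 2.0, 4.0, 1.0]
print(binary_auc(labels, scores))  # 1.0
```

The pairwise loop is O(positives × negatives). That is fine for illustration, but for large evaluation sets a sort-based rank computation or sklearn's `roc_auc_score` is the better choice. As in `calculate_auc`, the AUC is taken over the concatenated scores rather than averaged per batch, because AUC does not decompose over batches.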