retain.py source code

python

Project: retain  Author: mp2893
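```python
import numpy as np
from theano import config  # config.floatX: floating-point dtype used by the Theano model

# padMatrixWithTime and padMatrixWithoutTime are padding helpers from the
# same project (not shown here).


def calculate_cost(test_model, dataset, options):
    # Average cost of test_model over the whole dataset, evaluated in mini-batches.
    # dataset[0]: input sequences, dataset[1]: labels,
    # dataset[2]: time information (used only when options['useTime'] is set).
    batchSize = options['batchSize']
    useTime = options['useTime']

    costSum = 0.0
    dataCount = 0

    n_batches = int(np.ceil(float(len(dataset[0])) / float(batchSize)))
    for index in range(n_batches):  # range (not xrange) also runs on Python 3
        batchX = dataset[0][index*batchSize:(index+1)*batchSize]
        y = np.array(dataset[1][index*batchSize:(index+1)*batchSize]).astype(config.floatX)
        if useTime:
            batchT = dataset[2][index*batchSize:(index+1)*batchSize]
            x, t, lengths = padMatrixWithTime(batchX, batchT, options)
            scores = test_model(x, y, t, lengths)
        else:
            x, lengths = padMatrixWithoutTime(batchX, options)
            scores = test_model(x, y, lengths)
        # Weight each batch's score by its size so that a per-batch mean cost
        # still averages correctly when the last batch is shorter.
        costSum += scores * len(batchX)
        dataCount += len(batchX)
    return costSum / dataCount
```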
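The reweighting by `len(batchX)` is the key step: it turns per-batch averages back into a dataset-wide average. The sketch below isolates that logic under one assumption, suggested by the reweighting itself: `test_model` returns the *mean* cost of a batch. A plain list of per-sample costs stands in for the Theano model. `batched_mean_cost` is a hypothetical name used only for illustration.

```python
import numpy as np


def batched_mean_cost(per_sample_costs, batch_size):
    """Average per-sample costs while visiting them in mini-batches.

    Mirrors the loop in calculate_cost. Each batch contributes its mean cost
    (a stand-in for test_model), multiplied by the batch size, so the final
    division yields the mean over all samples.
    """
    costs = np.asarray(per_sample_costs, dtype=float)
    n_batches = int(np.ceil(len(costs) / float(batch_size)))
    cost_sum = 0.0
    data_count = 0
    for index in range(n_batches):
        batch = costs[index * batch_size:(index + 1) * batch_size]
        batch_mean = batch.mean()  # hypothetical stand-in for test_model(...)
        cost_sum += batch_mean * len(batch)
        data_count += len(batch)
    return cost_sum / data_count


costs = [0.2, 0.4, 0.6, 0.8, 1.0]
print(batched_mean_cost(costs, batch_size=2))  # matches np.mean(costs) up to rounding
```

With five samples and `batch_size=2`, the last batch holds a single sample. A plain average of the three batch means, (0.3 + 0.7 + 1.0) / 3, would overweight that sample. The size-weighted sum gives 3.0 / 5 = 0.6, which is the true per-sample mean.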