error.py source code

Project: chainer-glu    Author: musyoku

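```python
import sys

import numpy as np

# `stdout` (which provides the CLEAR escape sequence) and compute_perplexity_batch()
# are assumed to be defined elsewhere in the chainer-glu project; they are not shown here.


def compute_perplexity(model, buckets, batchsize=100):
    result = []
    for bucket_index, dataset in enumerate(buckets):
        ppl = []
        # split into minibatches of at most `batchsize` rows
        if len(dataset) > batchsize:
            # number of split points = ceil(len(dataset) / batchsize) - 1
            num_sections = len(dataset) // batchsize - 1
            if len(dataset) % batchsize > 0:
                num_sections += 1
            indices = [(i + 1) * batchsize for i in range(num_sections)]
            sections = np.split(dataset, indices, axis=0)
        else:
            sections = [dataset]
        # compute perplexity for each minibatch
        for batch_index, batch in enumerate(sections):
            sys.stdout.write("\rcomputing perplexity ... bucket {}/{} (batch {}/{})".format(
                bucket_index + 1, len(buckets), batch_index + 1, len(sections)))
            sys.stdout.flush()
            ppl.append(compute_perplexity_batch(model, batch))

        # unweighted mean of the per-minibatch perplexities for this bucket
        result.append(sum(ppl) / len(ppl))
        # erase the progress line
        sys.stdout.write("\r" + stdout.CLEAR)
        sys.stdout.flush()
    return result
```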
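The branching that builds `indices` in `compute_perplexity` places a split point every `batchsize` rows, strictly below `len(dataset)`, so it reduces to a single `range` call (and `np.split` with an empty index list returns the whole array as one section). The sketch below is self-contained and runnable, but it is an illustration under stated assumptions, not the project's code: `split_into_minibatches`, `toy_perplexity_batch`, and `average_batch_perplexity` are hypothetical names, and because the real `compute_perplexity_batch(model, batch)` is not shown, the stand-in assumes each batch is already a vector of natural-log probabilities of the target tokens, with perplexity taken as `exp(-mean log p)`.

```python
import math

import numpy as np


def split_into_minibatches(dataset, batchsize):
    """Split `dataset` along axis 0 into chunks of at most `batchsize` rows.

    Split points are batchsize, 2*batchsize, ... strictly below len(dataset),
    which yields the same sections as the if/else logic in compute_perplexity.
    """
    indices = list(range(batchsize, len(dataset), batchsize))
    return np.split(dataset, indices, axis=0)


def toy_perplexity_batch(log_probs):
    """Hypothetical stand-in for compute_perplexity_batch.

    Assumes `log_probs` holds the natural-log probability assigned to each
    target token in the batch; perplexity = exp(-mean(log p)).
    """
    return math.exp(-float(np.mean(log_probs)))


def average_batch_perplexity(dataset, batchsize=100):
    # Unweighted mean over minibatches, mirroring sum(ppl) / len(ppl) above.
    ppl = [toy_perplexity_batch(batch)
           for batch in split_into_minibatches(dataset, batchsize)]
    return sum(ppl) / len(ppl)


if __name__ == "__main__":
    data = np.arange(250).reshape(250, 1)
    print([len(s) for s in split_into_minibatches(data, 100)])
```

Note that averaging per-minibatch perplexities weights a short final minibatch (here 50 rows) the same as a full one, so the result can differ from the corpus-level perplexity `exp` of the mean log-likelihood over all tokens; the two agree only when every minibatch has the same size or the same mean log-likelihood.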