test_lower_triangular_matrix.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:chainerrl 作者: chainer 项目源码 文件源码
def check_forward(self, diag_data, non_diag_data):
        diag = chainer.Variable(diag_data)
        non_diag = chainer.Variable(non_diag_data)
        y = lower_triangular_matrix(diag, non_diag)

        correct_y = numpy.zeros(
            (self.batch_size, self.n, self.n), dtype=numpy.float32)

        tril_rows, tril_cols = numpy.tril_indices(self.n, -1)
        correct_y[:, tril_rows, tril_cols] = cuda.to_cpu(non_diag_data)

        diag_rows, diag_cols = numpy.diag_indices(self.n)
        correct_y[:, diag_rows, diag_cols] = cuda.to_cpu(diag_data)

        gradient_check.assert_allclose(correct_y, cuda.to_cpu(y.data))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号