test_base.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:tensorly 作者: tensorly 项目源码 文件源码
def test_cp_tensor():
    """test for random.cp_tensor"""
    shape = (10, 11, 12)
    rank = 4

    tensor = cp_tensor(shape, rank, full=True)
    for i in range(T.ndim(tensor)):
        T.assert_equal(matrix_rank(T.to_numpy(unfold(tensor, i))), rank)

    factors = cp_tensor(shape, rank, full=False)
    for i, factor in enumerate(factors):
        T.assert_equal(factor.shape, (shape[i], rank),
                err_msg=('{}-th factor has shape {}, expected {}'.format(
                     i, factor.shape, (shape[i], rank))))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号