ops_test.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:t3f 作者: Bihaqo 项目源码 文件源码
def testFullMatrix3d(self):
    np.random.seed(1)
    for rank in [1, 2]:
      a = np.random.rand(3, 2, 3, rank).astype(np.float32)
      b = np.random.rand(3, rank, 4, 5, rank).astype(np.float32)
      c = np.random.rand(3, rank, 2, 2).astype(np.float32)
      tt_cores = (a.reshape(3, 1, 2, 3, rank), b.reshape(3, rank, 4, 5, rank),
                  c.reshape(3, rank, 2, 2, 1))
      # Basically do full by hand.
      desired = np.einsum('oija,oaklb,obpq->oijklpq', a, b, c)
      desired = desired.reshape((3, 2, 3, 4, 5, 2, 2))
      desired = desired.transpose((0, 1, 3, 5, 2, 4, 6))
      desired = desired.reshape((3, 2 * 4 * 2, 3 * 5 * 2))
      with self.test_session():
        tf_mat = TensorTrainBatch(tt_cores)
        actual = ops.full(tf_mat)
        self.assertAllClose(desired, actual.eval())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号