def testFullMatrix3d(self):
    """Compare `ops.full` on a three-core `TensorTrainBatch` with NumPy.

    For ranks 1 and 2, three random float32 cores with a leading axis of
    size 3 are generated, the dense result is assembled by hand with
    `np.einsum`, and it is compared against `ops.full` applied to a
    `TensorTrainBatch` built from the same cores.
    """
    np.random.seed(1)
    for tt_rank in (1, 2):
        # The three draws must stay in this order so the seeded values
        # assigned to each core do not change.
        first = np.random.rand(3, 2, 3, tt_rank).astype(np.float32)
        middle = np.random.rand(3, tt_rank, 4, 5, tt_rank).astype(np.float32)
        last = np.random.rand(3, tt_rank, 2, 2).astype(np.float32)

        # Give the boundary cores an explicit size-1 rank axis so that all
        # three cores are 5-dimensional.
        cores = (
            first.reshape(3, 1, 2, 3, tt_rank),
            middle.reshape(3, tt_rank, 4, 5, tt_rank),
            last.reshape(3, tt_rank, 2, 2, 1),
        )

        # Reference result computed by hand: contract the rank indices
        # (a, b) while keeping the shared leading index o.
        dense = np.einsum('oija,oaklb,obpq->oijklpq', first, middle, last)
        dense = dense.reshape((3, 2, 3, 4, 5, 2, 2))
        # Bring the first mode of every core (2, 4, 2) ahead of the second
        # mode of every core (3, 5, 2), then merge each group into one axis.
        # NOTE(review): presumably these are the matrix row / column modes.
        regrouped = dense.transpose((0, 1, 3, 5, 2, 4, 6))
        expected = regrouped.reshape((3, 2 * 4 * 2, 3 * 5 * 2))

        with self.test_session():
            tt_batch = TensorTrainBatch(cores)
            result = ops.full(tt_batch)
            self.assertAllClose(expected, result.eval())
# (Web-page residue, not part of the test: "comment list" / "article table of contents".)