def testFlatInnerTTTensbyTTTensBroadcasting(self):
    """Inner product between two batch TT-tensors with broadcasting.

    A batch of size 1 is combined with a batch of size 3 in both argument
    orders, and each result is compared against a dense reference computed
    with einsum. Batches of sizes 2 and 3 (different, and neither equal to
    1) must be rejected with a ValueError, also in both argument orders.
    """
    tt_1 = initializers.random_tensor_batch((2, 3, 4), batch_size=1)
    tt_2 = initializers.random_tensor_batch((2, 3, 4), batch_size=3)
    res_actual_1 = ops.flat_inner(tt_1, tt_2)
    res_actual_2 = ops.flat_inner(tt_2, tt_1)
    # Dense reference: tt_1[0] is the single element of the size-1 batch
    # (a rank-3 full tensor, 'ijk'), contracted over all modes against every
    # element of tt_2's batch ('oijk'), leaving one value per batch index 'o'.
    res_desired = tf.einsum('ijk,oijk->o', ops.full(tt_1[0]), ops.full(tt_2))
    with self.test_session() as sess:
        res = sess.run([res_actual_1, res_actual_2, res_desired])
        res_actual_1_val, res_actual_2_val, res_desired_val = res
        self.assertAllClose(res_actual_1_val, res_desired_val)
        self.assertAllClose(res_actual_2_val, res_desired_val)

    # The batch sizes are different (2 vs 3) and neither is 1, so no
    # broadcasting is possible; check both argument orders, mirroring the
    # two-order check of the valid case above.
    tt_batch_2 = initializers.random_tensor_batch((2, 3, 4), batch_size=2)
    with self.assertRaises(ValueError):
        ops.flat_inner(tt_batch_2, tt_2)
    with self.assertRaises(ValueError):
        ops.flat_inner(tt_2, tt_batch_2)
评论列表
文章目录