ops_test.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目: t3f 作者: Bihaqo 项目源码 文件源码
def testFlatInnerTTTensbyTTTensBroadcasting(self):
  """Checks ops.flat_inner on batch TT-tensors with broadcastable batches.

  A batch of size 1 is paired with a batch of size 3. Both argument orders
  are compared against a dense einsum reference. Then batches of sizes 2
  and 3 (different, and neither equal to 1) must raise ValueError, again
  in both argument orders.
  """
  # Inner product between two batch TT-tensors with broadcasting.
  tt_1 = initializers.random_tensor_batch((2, 3, 4), batch_size=1)
  tt_2 = initializers.random_tensor_batch((2, 3, 4), batch_size=3)
  res_actual_1 = ops.flat_inner(tt_1, tt_2)
  res_actual_2 = ops.flat_inner(tt_2, tt_1)
  # Dense reference: contract the single tensor taken from tt_1 with every
  # tensor of tt_2 over all three modes, leaving one value per batch
  # element 'o'.
  res_desired = tf.einsum('ijk,oijk->o', ops.full(tt_1[0]), ops.full(tt_2))
  with self.test_session() as sess:
    res = sess.run([res_actual_1, res_actual_2, res_desired])
    res_actual_1_val, res_actual_2_val, res_desired_val = res
    # Both argument orders must broadcast the size-1 batch and agree with
    # the reference.
    self.assertAllClose(res_actual_1_val, res_desired_val)
    self.assertAllClose(res_actual_2_val, res_desired_val)

  tt_1 = initializers.random_tensor_batch((2, 3, 4), batch_size=2)
  # Batch sizes 2 and 3 differ and neither is 1, so they cannot be
  # broadcast. This must be rejected regardless of argument order.
  with self.assertRaises(ValueError):
    ops.flat_inner(tt_1, tt_2)
  with self.assertRaises(ValueError):
    ops.flat_inner(tt_2, tt_1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号