basic_test.py 文件源码

python
阅读 44 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testInputExampleIndex(self):
  """Checks how `input_example_index` selects the input used for the output.

  Two inputs are built with equal first dimensions but different second
  dimensions. The wrapped `build` returns the second input. Only the input
  whose leading dimensions match that output may be named as the example.
  """
  small = tf.random_normal((3, 5))
  large = tf.random_normal((3, 9))

  def build(inputs):
    first, second = inputs
    # With n_dims=2, BatchApply is expected to hand `build` each input with
    # its two leading dimensions merged into one.
    first.get_shape().assert_is_compatible_with([3 * 5])
    second.get_shape().assert_is_compatible_with([3 * 9])
    return second

  module = snt.Module(build)

  # Input 0 has leading dimensions (3, 5), which differ from those of the
  # output (it is derived from input 1), so building the graph must raise.
  with self.assertRaises(ValueError):
    snt.BatchApply(module, n_dims=2, input_example_index=0)((small, large))

  # Input 1 has the same leading dimensions as the output, so the call
  # succeeds and the result should reproduce that input exactly.
  batched = snt.BatchApply(module, n_dims=2, input_example_index=1)(
      (small, large))
  with self.test_session() as sess:
    # Fetch both tensors in one run so the random op is evaluated only once
    # and the two values are directly comparable.
    expected, actual = sess.run([large, batched])
    self.assertAllEqual(expected, actual)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号