basic_test.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testComputation(self):
    """Check that BatchApply matches applying the wrapped op to flat input.

    Two wrapped callables are exercised: an `snt.Linear` module and the
    plain `tf.tanh` op. In both cases the result for the 3-D input must
    equal the result for the same data reshaped to 2-D, once that flat
    result is reshaped back to the expected output shape.
    """
    np.random.seed(100)

    # Shapes: a rank-3 input and the same data with its first two
    # dimensions collapsed into one.
    hidden_size = 5
    shape_3d = [2, 3, 4]
    shape_2d = [6, 4]
    linear_expected_shape = shape_3d[:2] + [hidden_size]
    tanh_expected_shape = shape_3d

    batched_in = tf.random_uniform(shape=shape_3d)
    flat_in = tf.reshape(batched_in, shape=shape_2d)

    # Case 1: a Sonnet module wrapped in BatchApply vs. called directly.
    linear_inits = {"w": _test_initializer(), "b": _test_initializer()}
    linear = snt.Linear(hidden_size, initializers=linear_inits)
    batch_linear = snt.BatchApply(module_or_op=linear)
    linear_batched = batch_linear(batched_in)
    linear_flat = linear(flat_in)

    # Case 2: a bare TensorFlow op wrapped in BatchApply, applied to both
    # the rank-3 and the flattened input.
    batch_tanh = snt.BatchApply(module_or_op=tf.tanh)
    tanh_batched = batch_tanh(batched_in)
    tanh_flat = batch_tanh(flat_in)

    linear_fetches = [linear_batched, linear_flat]
    tanh_fetches = [tanh_batched, tanh_flat]
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        # Each pair is fetched in a single run so both tensors of the pair
        # are computed from the same sample of the random input.
        linear_batched_val, linear_flat_val = sess.run(linear_fetches)
        tanh_batched_val, tanh_flat_val = sess.run(tanh_fetches)
        self.assertAllClose(linear_batched_val,
                            linear_flat_val.reshape(linear_expected_shape))
        self.assertAllClose(tanh_batched_val,
                            tanh_flat_val.reshape(tanh_expected_shape))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号