def testComputation(self):
    """Checks BatchApply matches applying the wrapped callable to flat input.

    Two cases are covered:
      1. An ``snt.Linear`` module wrapped in ``snt.BatchApply``, compared
         against the same module applied directly to the manually flattened
         input.
      2. The plain ``tf.tanh`` op wrapped in ``snt.BatchApply``, compared
         against ``tf.tanh`` applied directly to the flattened input (an
         independent reference), and against BatchApply applied to the
         already-flat rank-2 input.
    """
    # NOTE(review): this seeds NumPy's global RNG only; `inputs` below comes
    # from tf.random_uniform, which it does not affect. Kept in case
    # _test_initializer draws from NumPy -- confirm.
    np.random.seed(100)
    in_shape = [2, 3, 4]
    # The test expects BatchApply to merge the two leading dims (see
    # out_shape1); derive the flat shape so it cannot drift from in_shape.
    in_shape_flat = [in_shape[0] * in_shape[1]] + in_shape[2:]
    hidden_size = 5
    out_shape1 = in_shape[:2] + [hidden_size]
    out_shape2 = in_shape
    inputs = tf.random_uniform(shape=in_shape)
    inputs_flat = tf.reshape(inputs, shape=in_shape_flat)

    # Case 1: a stateful Sonnet module. The same `linear` instance is used
    # for both paths so they share weights.
    linear = snt.Linear(hidden_size,
                        initializers={"w": _test_initializer(),
                                      "b": _test_initializer()})
    merge_linear = snt.BatchApply(module_or_op=linear)
    outputs1 = merge_linear(inputs)
    outputs1_flat = linear(inputs_flat)

    # Case 2: a plain TF op. `tanh_ref_flat` does not go through BatchApply,
    # so it is an independent reference for the wrapped result.
    merge_tanh = snt.BatchApply(module_or_op=tf.tanh)
    outputs2 = merge_tanh(inputs)
    outputs2_flat = merge_tanh(inputs_flat)
    tanh_ref_flat = tf.tanh(inputs_flat)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        # Within a single sess.run the random `inputs` op is evaluated once,
        # so all tensors fetched together see the same sample.
        out1, out_flat1 = sess.run([outputs1, outputs1_flat])
        out2, out_flat2, ref_flat2 = sess.run(
            [outputs2, outputs2_flat, tanh_ref_flat])
        self.assertAllClose(out1, out_flat1.reshape(out_shape1))
        self.assertAllClose(out2, out_flat2.reshape(out_shape2))
        self.assertAllClose(out2, ref_flat2.reshape(out_shape2))
评论列表
文章目录