def testInitializers(self, use_bias):
  """Check that constant initializers end up in Conv1D's variables.

  A Conv1D module is built with initializers produced by
  `create_constant_initializers` from two random constants. After the
  module is connected and its variables are initialized, the kernel must
  equal a [3, 2, 1] float32 array filled with the weight constant. When
  `use_bias` is set, the bias must equal the single-element list holding
  the bias constant. Finally, passing a tensor (`tf.ones([])`) as the 'w'
  initializer must raise a TypeError whose message matches the expected
  text.
  """
  # The weight constant is drawn before the bias constant.
  weight_value = random.random()
  bias_value = random.random()

  # Constructor arguments common to both Conv1D modules built below.
  shared_kwargs = dict(
      kernel_shape=3,
      stride=1,
      padding=snt.SAME,
      use_bias=use_bias,
      name="conv1")

  # NOTE(review): presumably yields constant initializers for 'w' (and 'b'
  # when use_bias is true); confirm against the helper's definition.
  initializers = create_constant_initializers(weight_value, bias_value,
                                              use_bias)
  conv = snt.Conv1D(output_channels=1, initializers=initializers,
                    **shared_kwargs)
  # Connect to a [1, 10, 2] float32 input. The 2 input channels, together
  # with kernel_shape=3 and output_channels=1, give the [3, 2, 1] kernel
  # checked below.
  conv(tf.placeholder(tf.float32, [1, 10, 2]))

  expected_w = np.full([3, 2, 1], weight_value, dtype=np.float32)
  with self.test_session():
    # Only initialize variables that exist: the bias is present only when
    # use_bias is true.
    to_initialize = [conv.w]
    if use_bias:
      to_initialize.append(conv.b)
    tf.variables_initializer(to_initialize).run()

    self.assertAllClose(conv.w.eval(), expected_w)
    if use_bias:
      self.assertAllClose(conv.b.eval(), [bias_value])

  message = "Initializer for 'w' is not a callable function or dictionary"
  with self.assertRaisesRegexp(TypeError, message):
    snt.Conv1D(output_channels=10,
               initializers={"w": tf.ones([])},
               **shared_kwargs)
评论列表
文章目录