def testShapesNotKnown(self, use_bias):
    """Checks Conv1D's output shape when the input shape is only partly known.

    The input placeholder leaves the batch and length dimensions as `None`
    and fixes only the channel count. The output shape is therefore checked
    on the evaluated result after feeding a concrete zero-filled array.

    Args:
      use_bias: Whether the `snt.Conv1D` module is constructed with a bias.
    """
    batch_size = 5
    in_length = 32
    num_channels = 5  # Input and output channel counts are equal in this test.
    kernel_width = 3

    input_ph = tf.placeholder(
        tf.float32, shape=[None, None, num_channels], name="inputs")
    conv_module = snt.Conv1D(
        output_channels=num_channels,
        kernel_shape=kernel_width,
        stride=1,
        padding=snt.SAME,
        use_bias=use_bias,
        name="conv1")
    conv_out = conv_module(input_ph)

    # Only request the bias variable `b` when the module was built with one.
    params = [conv_module.w]
    if use_bias:
        params.append(conv_module.b)

    # The test expects SAME padding with stride 1 to keep the length unchanged.
    expected_shape = (batch_size, in_length, num_channels)
    feed = {input_ph: np.zeros([batch_size, in_length, num_channels])}

    with self.test_session() as sess:
        sess.run(tf.variables_initializer(params))
        result = sess.run(conv_out, feed_dict=feed)
        self.assertEqual(result.shape, expected_shape)
评论列表
文章目录