def test_with_fixed_inputs(self):
    """Decode random full-length inputs in TRAIN mode and check output shapes.

    A batch of random inputs with shape
    ``[batch_size, sequence_length, input_depth]`` is built. Every example is
    given the full ``sequence_length``. The batch is fed through a
    ``TrainingHelper`` into the decoder returned by ``self.create_decoder``,
    starting from the cell's zero state. The decoder output is evaluated in a
    test session.

    The test asserts that ``logits`` has shape
    ``[sequence_length, batch_size, vocab_size]`` and that ``predicted_ids``
    has shape ``[sequence_length, batch_size]``; both are time-major.

    Returns:
      The evaluated decoder output, returned so callers can inspect it further.
    """
    batch, steps = self.batch_size, self.sequence_length

    # NOTE: TF ops are created in the same order as before on purpose; with a
    # graph-level seed, op creation order can influence the random values.
    random_inputs = tf.random_normal(
        shape=[batch, steps, self.input_depth])
    # Every sequence in the batch is treated as running the full length.
    full_lengths = tf.ones(batch, dtype=tf.int32) * steps

    train_helper = decode_helper.TrainingHelper(
        inputs=random_inputs, sequence_length=full_lengths)
    decoder = self.create_decoder(
        helper=train_helper, mode=tf.contrib.learn.ModeKeys.TRAIN)
    start_state = decoder.cell.zero_state(batch, dtype=tf.float32)
    output_tensors, _ = decoder(start_state, train_helper)

    # pylint: disable=no-member
    with self.test_session() as session:
        session.run(tf.global_variables_initializer())
        evaluated = session.run(output_tensors)

    # Outputs are expected to be time-major: [time, batch, ...].
    expected_logits_shape = [steps, batch, self.vocab_size]
    expected_ids_shape = [steps, batch]
    np.testing.assert_array_equal(
        evaluated.logits.shape, expected_logits_shape)
    np.testing.assert_array_equal(
        evaluated.predicted_ids.shape, expected_ids_shape)

    return evaluated
评论列表
文章目录