def testGetInferIterator(self):
    """Checks batching, padding, truncation and OOV handling of the infer iterator.

    Four source sentences are fed through ``iterator_utils.get_infer_iterator``
    with a batch size of 2 and ``src_max_len`` of 3. Exactly two batches are
    expected, after which the next fetch must raise an "End of sequence" error.
    """
    # Vocabulary ids follow list position: a=0, b=1, c=2, eos=3, sos=4.
    # Words outside this list are expected to come back as -1.
    vocab = lookup_ops.index_table_from_tensor(
        tf.constant(["a", "b", "c", "eos", "sos"]))
    sentences = tf.data.Dataset.from_tensor_slices(
        tf.constant(["c c a", "c a", "d", "f e a g"]))
    hparams = tf.contrib.training.HParams(
        random_seed=3,
        eos="eos",
        sos="sos")
    batch_size = 2
    src_max_len = 3

    iterator = iterator_utils.get_infer_iterator(
        src_dataset=sentences,
        src_vocab_table=vocab,
        batch_size=batch_size,
        eos=hparams.eos,
        src_max_len=src_max_len)
    init_tables = tf.tables_initializer()
    ids = iterator.source
    lengths = iterator.source_sequence_length

    # Static shapes: the batch and time dimensions are left unknown.
    self.assertEqual([None, None], ids.shape.as_list())
    self.assertEqual([None], lengths.shape.as_list())

    # Each entry is (expected token ids, expected sequence lengths) for one
    # batch, in the order the iterator should produce them.
    expected_batches = [
        ([[2, 2, 0],      # c c a
          [2, 0, 3]],     # c a eos  (short row padded with eos)
         [3, 2]),
        ([[-1, 3, 3],     # "d" == unknown, eos eos
          [-1, -1, 0]],   # "f" == unknown, "e" == unknown, a  ("g" cut at src_max_len)
         [1, 3]),
    ]

    with self.test_session() as sess:
        sess.run(init_tables)
        sess.run(iterator.initializer)

        for want_ids, want_lengths in expected_batches:
            got_ids, got_lengths = sess.run((ids, lengths))
            self.assertAllEqual(want_ids, got_ids)
            self.assertAllEqual(want_lengths, got_lengths)

        # Both batches have been consumed, so one more fetch must fail.
        with self.assertRaisesOpError("End of sequence"):
            sess.run((ids, lengths))
评论列表
文章目录