def get_infer_iterator(src_dataset,
                       src_vocab_table,
                       batch_size,
                       eos,
                       src_max_len=None):
    """Build an initializable, padded-batch iterator over source lines.

    Each element of `src_dataset` is split into tokens with
    `tf.string_split`, optionally truncated to `src_max_len` tokens,
    mapped to int32 ids through `src_vocab_table`, and paired with its
    token count. Elements are then batched, padding the id vectors with
    the id of `eos`.

    Args:
      src_dataset: dataset whose elements are source lines
        (presumably a `tf.data.Dataset` of scalar strings -- confirm
        against the caller).
      src_vocab_table: lookup table exposing `lookup()`, used to turn
        token strings into ids.
      batch_size: number of elements per padded batch.
      eos: end-of-sentence token string; its id is the padding value.
      src_max_len: if truthy, keep only the first `src_max_len` tokens
        of every line; `None` or `0` disables truncation.

    Returns:
      A `BatchedInput` carrying the iterator initializer, the batched
      source ids and the source sequence lengths. All target-related
      fields are `None`.
    """
    # Resolve the eos id first; it doubles as the padding value below.
    eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)

    def _tokenize(line):
        # Split one line into a 1-D tensor of token strings.
        return tf.string_split([line]).values

    def _truncate(tokens):
        # Keep at most `src_max_len` leading tokens.
        return tokens[:src_max_len]

    def _tokens_to_ids(tokens):
        # Convert the word strings to int32 vocabulary ids.
        return tf.cast(src_vocab_table.lookup(tokens), tf.int32)

    def _attach_length(ids):
        # Pair the id vector with its word count.
        return ids, tf.size(ids)

    dataset = src_dataset.map(_tokenize)
    if src_max_len:
        dataset = dataset.map(_truncate)
    dataset = dataset.map(_tokens_to_ids)
    dataset = dataset.map(_attach_length)

    # First component: source ids, a vector of unknown length.
    # Second component: source length, a scalar.
    shapes = (tf.TensorShape([None]), tf.TensorShape([]))
    # Source rows are padded with the eos id. (This is not strictly
    # needed, since calculations past the true sequence length are
    # masked out later on.) The length component is a scalar, so its
    # padding value is never actually used.
    pad_values = (eos_id, 0)

    batched = dataset.padded_batch(
        batch_size,
        padded_shapes=shapes,
        padding_values=pad_values)

    iterator = batched.make_initializable_iterator()
    ids_batch, lengths_batch = iterator.get_next()

    return BatchedInput(
        initializer=iterator.initializer,
        source=ids_batch,
        target_input=None,
        target_output=None,
        source_sequence_length=lengths_batch,
        target_sequence_length=None)
评论列表
文章目录