def get_input_fn(batch_size, num_epochs, context_filename, answer_filename, max_sequence_len):
    """Build an Estimator-style ``input_fn`` streaming padded (context, answer) batches.

    Each line of the two input files is expected to be a whitespace-separated
    sequence of integer token ids (e.g. ``"12 7 99"``).  Sequences are clipped
    to ``max_sequence_len`` and padded up to it when batched.

    Args:
        batch_size: Number of examples per batch.
        num_epochs: How many passes over the data (``None`` repeats forever).
        context_filename: Path to the source/context text file.
        answer_filename: Path to the target/answer text file.
        max_sequence_len: Hard cap (and pad length) for every sequence.

    Returns:
        A zero-argument ``input_fn`` yielding ``(features, None)`` where
        ``features`` is ``((context_ids, context_len), (answer_ids, answer_len))``.
    """

    def _tokenize_and_clip(line):
        # "12 7 99" -> int64 id vector, truncated to max_sequence_len,
        # paired with its (clipped) original length.
        pieces = tf.string_split([line]).values
        ids = tf.string_to_number(pieces, tf.int64)
        length = tf.size(ids)
        return ids[:max_sequence_len], tf.minimum(length, max_sequence_len)

    def input_fn():
        # All graph ops are created here (inside input_fn), as Estimator requires.
        contexts = tf.contrib.data.TextLineDataset(context_filename).map(_tokenize_and_clip)
        answers = tf.contrib.data.TextLineDataset(answer_filename).map(_tokenize_and_clip)

        paired = tf.contrib.data.Dataset.zip((contexts, answers))
        paired = paired.repeat(num_epochs)

        # Pad every sequence out to max_sequence_len; lengths are scalars.
        per_side_shape = (tf.TensorShape([max_sequence_len]), tf.TensorShape([]))
        paired = paired.padded_batch(batch_size,
                                     padded_shapes=(per_side_shape, per_side_shape))

        # No labels: the (context, answer) pair itself is the feature tuple.
        return paired.make_one_shot_iterator().get_next(), None

    return input_fn
# (page-scrape residue, not code) 评论列表 — "comment list"
# (page-scrape residue, not code) 文章目录 — "article table of contents"