inputs_fn.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:attention 作者: louishenrifranc 项目源码 文件源码
def get_input_fn(batch_size, num_epochs, context_filename, answer_filename, max_sequence_len):
    """Create an ``input_fn`` that streams paired token-id sequences from two text files.

    Each line of ``context_filename`` and ``answer_filename`` is split with
    ``tf.string_split``, parsed to ``tf.int64`` and truncated to at most
    ``max_sequence_len`` entries. The two files are zipped line by line,
    repeated ``num_epochs`` times and batched with every sequence padded to
    ``max_sequence_len``.

    Args:
        batch_size: Number of zipped examples per padded batch.
        num_epochs: Passed straight to ``Dataset.repeat``.
        context_filename: Path of the text file holding the source sequences.
        answer_filename: Path of the text file holding the target sequences.
        max_sequence_len: Maximum (and padded) length of every sequence.

    Returns:
        A zero-argument callable. When invoked it builds the pipeline and
        returns ``(next_element, None)``, where ``next_element`` is
        ``((source_ids, source_len), (target_ids, target_len))``.
        NOTE(review): the trailing ``None`` is presumably the labels slot of
        an Estimator-style ``(features, labels)`` pair -- confirm with caller.
    """
    def input_fn():
        def split_line(line):
            # string_split expects a batch, hence the one-element list.
            return tf.string_split([line]).values

        def to_ids(tokens):
            return tf.string_to_number(tokens, tf.int64)

        def with_length(ids):
            return ids, tf.size(ids)

        def clip(ids, length):
            # Truncate the ids and cap the reported length to match.
            return ids[:max_sequence_len], tf.minimum(length, max_sequence_len)

        def load_sequences(filename):
            lines = tf.contrib.data.TextLineDataset(filename)
            return lines.map(split_line).map(to_ids).map(with_length).map(clip)

        contexts = load_sequences(context_filename)
        answers = load_sequences(answer_filename)

        # One (ids, length) pair: ids padded to the maximum length, length a scalar.
        example_shape = (tf.TensorShape([max_sequence_len]), tf.TensorShape([]))

        paired = tf.contrib.data.Dataset.zip((contexts, answers))
        paired = paired.repeat(num_epochs)
        paired = paired.padded_batch(batch_size,
                                     padded_shapes=(example_shape, example_shape))

        iterator = paired.make_one_shot_iterator()
        return iterator.get_next(), None

    return input_fn
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号