input_pipeline.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:nmt_v2 作者: rpryzant 项目源码 文件源码
def get_test_iterator(src_dataset, src_vocab_table, batch_size, config):
    src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(config.eos)), tf.int32)
    src_dataset = src_dataset.map(lambda src: tf.string_split([src]).values)

    src_dataset = src_dataset.map(lambda src: src[:config.src_max_len])

    src_dataset = src_dataset.map(
        lambda src: tf.cast(src_vocab_table.lookup(src), tf.int32))

    if config.reverse_src:
        src_dataset = src_dataset.map(lambda src: tf.reverse(src, axis=[0]))

    src_dataset = src_dataset.map(lambda src: (src, tf.size(src)))

    def batching_func(x):
        return x.padded_batch(
            config.batch_size,
            padded_shapes=(tf.TensorShape([None]),
                           tf.TensorShape([])),
            padding_values=(src_eos_id,
                            0))

    batched_dataset = batching_func(src_dataset)
    batched_iter = batched_dataset.make_initializable_iterator()
    src_ids, src_seq_len = batched_iter.get_next()
    return BatchedInput(
        initializer=batched_iter.initializer,
        source=src_ids,
        target_input=None,
        target_output=None,
        source_sequence_length=src_seq_len,
        target_sequence_length=None)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号