Source code of record.py

python

Project: XMUNMT    Author: XMUNLP

The input_pipeline function below builds a queue-based pipeline over TFRecord files using tf.contrib.slim: it reads serialized tf.Example protos in parallel, decodes the variable-length "inputs" and "targets" features, and casts them to int32.
import six
import tensorflow as tf

# TF 1.x contrib.slim reading utilities; this code predates the tf.data API.
from tensorflow.contrib.slim.python.slim.data import parallel_reader
from tensorflow.contrib.slim.python.slim.data import tfexample_decoder


def input_pipeline(file_pattern, mode, capacity=64):
    """Reads tf.Example protos matching file_pattern and returns a dict
    with int32 "inputs" and "targets" tensors for a single example."""
    keys_to_features = {
        "inputs": tf.VarLenFeature(tf.int64),
        "targets": tf.VarLenFeature(tf.int64)
    }

    items_to_handlers = {
        "inputs": tfexample_decoder.Tensor("inputs"),
        "targets": tfexample_decoder.Tensor("targets")
    }

    # Build the examples queue.
    with tf.name_scope("examples_queue"):
        training = (mode == "train")
        # Read serialized examples using slim parallel_reader.
        num_epochs = None if training else 1
        data_files = parallel_reader.get_data_files(file_pattern)
        num_readers = min(4 if training else 1, len(data_files))
        _, examples = parallel_reader.parallel_read([file_pattern],
                                                    tf.TFRecordReader,
                                                    num_epochs=num_epochs,
                                                    shuffle=training,
                                                    capacity=2 * capacity,
                                                    min_after_dequeue=capacity,
                                                    num_readers=num_readers)

        decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                     items_to_handlers)

        # Decode the requested items and re-pack them into a dict keyed by
        # feature name; using one key list keeps the name/tensor pairing stable.
        items = list(items_to_handlers)
        decoded = decoder.decode(examples, items=items)
        examples = dict(zip(items, decoded))

        # We do not want int64s as they are not supported on GPUs.
        return {k: tf.to_int32(v) for (k, v) in six.iteritems(examples)}
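
The decoded tensors are backed by TF 1.x queue runners, so they can only be evaluated in a session that has started those runners. Below is a minimal round-trip sketch, assuming a TF 1.x runtime: it writes two toy examples in the layout keys_to_features expects, then reads one back through input_pipeline. The /tmp path and the toy id lists are illustrative and not part of XMUNMT.

import tensorflow as tf

# Write two toy examples: variable-length int64 lists under the keys
# "inputs" and "targets", matching keys_to_features above.
path = "/tmp/toy.tfrecords"  # illustrative path, not from the project
with tf.python_io.TFRecordWriter(path) as writer:
    for inputs, targets in [([1, 2, 3], [4, 5]), ([6, 7], [8, 9, 10])]:
        example = tf.train.Example(features=tf.train.Features(feature={
            "inputs": tf.train.Feature(
                int64_list=tf.train.Int64List(value=inputs)),
            "targets": tf.train.Feature(
                int64_list=tf.train.Int64List(value=targets)),
        }))
        writer.write(example.SerializeToString())

features = input_pipeline(path, mode="train")

with tf.Session() as sess:
    # Needed if mode != "train" (num_epochs=1 creates a local counter).
    sess.run(tf.local_variables_initializer())
    # parallel_read is queue-based: start the queue runners first.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    one_example = sess.run(features)  # a single (unbatched) example
    print(one_example["inputs"], one_example["targets"])
    coord.request_stop()
    coord.join(threads)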