train.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:yaset 作者: jtourille 项目源码 文件源码
def _build_train_pipeline(tfrecords_file_path, feature_columns, buckets=None, batch_size=None,
                          nb_instances=None, num_threads=4):
    """
    Build the train pipeline. Sequences are grouped into buckets for faster training.

    Examples are read from a single TFRecords file, decoded with ``read_and_decode``, pushed
    through a random shuffle queue and then batched, either with
    ``tf.contrib.training.bucket_by_sequence_length`` (when ``buckets`` is non-empty) or with a
    ``tf.PaddingFIFOQueue`` (otherwise). All ops are placed on ``/cpu:0``.

    :param tfrecords_file_path: train TFRecords file path
    :param feature_columns: feature columns forwarded to ``read_and_decode``; each one adds an
        extra int32 component of shape ``[None]`` to every example
    :param buckets: train bucket boundaries (sorted before use); if empty or None, no bucketing
        is done and a padding FIFO queue is used instead
    :param batch_size: mini-batch size
    :param nb_instances: number of train instances. Used as the capacity of the shuffle queue
        (and of the padding queue when no buckets are given); the shuffle queue keeps at least
        ``nb_instances // 2`` elements after each dequeue. Must be an int: with the default
        ``None``, ``nb_instances // 2`` raises a TypeError.
    :param num_threads: number of enqueue threads per queue runner and of bucketing threads
        (default 4, the value previously hard-coded)
    :return: queue runner list, queues ``[filename_queue, shuffle_queue]``, symbolic link to
        mini-batch
    """

    with tf.device('/cpu:0'):

        # Filename queue holding a single filename (the train TFRecords file)
        filename_queue = tf.train.string_input_producer([tfrecords_file_path])

        # Decode one example
        tensor_list = read_and_decode(filename_queue, feature_columns)

        # Component dtypes and shapes: six fixed components followed by one int32 vector per
        # feature column. NOTE(review): these must stay in sync with what read_and_decode returns.
        dtypes = [tf.string, tf.int32, tf.int32, tf.int32, tf.int32, tf.int32]
        dtypes.extend(tf.int32 for _ in feature_columns)

        shapes = [[], [], [None], [None, None], [None], [None]]
        shapes.extend([None] for _ in feature_columns)

        # Random shuffle queue to randomize training instances: capacity is nb_instances and at
        # least half of that (min_after_dequeue) stays in the queue after each dequeue
        shuffle_queue = tf.RandomShuffleQueue(nb_instances, nb_instances // 2, dtypes=dtypes)

        # Enqueue and dequeue ops
        enqueue_op_shuffle_queue = shuffle_queue.enqueue(tensor_list)
        inputs = shuffle_queue.dequeue()

        # Queue runners, returned to the caller for thread creation
        queue_runner_list = [
            tf.train.QueueRunner(shuffle_queue, [enqueue_op_shuffle_queue] * num_threads)
        ]

        if buckets:
            # Bucketing according to the bucket boundaries passed as arguments; inputs[1] is the
            # bucketing key (presumably the sequence length - confirm against read_and_decode)
            _, batch = tf.contrib.training.bucket_by_sequence_length(
                inputs[1], inputs, batch_size, sorted(buckets),
                num_threads=num_threads,
                capacity=32,
                shapes=shapes,
                dynamic_pad=True)
        else:
            # No bucketing: a padding FIFO queue batches examples, padding the
            # variable-size ([None]) dimensions declared in `shapes`
            padding_queue = tf.PaddingFIFOQueue(nb_instances, dtypes=dtypes, shapes=shapes)
            enqueue_op_padding_queue = padding_queue.enqueue(inputs)
            batch = padding_queue.dequeue_many(batch_size)

            queue_runner_list.append(
                tf.train.QueueRunner(padding_queue, [enqueue_op_padding_queue] * num_threads))

        # NOTE(review): in the non-bucketed case padding_queue is not part of the returned
        # queue list; kept as-is for caller compatibility.
        return queue_runner_list, [filename_queue, shuffle_queue], batch
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号