trainer.py source code

python
Project: ternarynet  Author: czhu95
    def __init__(self, config, input_queue=None, predict_tower=None):
        """
        :param config: a `TrainConfig` instance.
        :param input_queue: a `tf.QueueBase` instance used to buffer datapoints.
            Defaults to a FIFO queue of size 50.
        :param predict_tower: list of relative GPU indices on which to run prediction.
            Defaults to [0]. Use -1 for CPU.
        """
        super(QueueInputTrainer, self).__init__(config)
        self.input_vars = self.model.get_input_vars()
        # use a smaller queue size for now, to avoid https://github.com/tensorflow/tensorflow/issues/2942
        if input_queue is None:
            self.input_queue = tf.FIFOQueue(
                    50, [x.dtype for x in self.input_vars], name='input_queue')
        else:
            self.input_queue = input_queue

        # by default, use the first training gpu for prediction
        self.predict_tower = predict_tower or [0]
        self.dequed_inputs = None
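
The default handling in this constructor can be illustrated without TensorFlow. What follows is a minimal sketch, not the project's code: `QueueInputSketch` is a hypothetical stand-in class, and the standard library's `queue.Queue` is swapped in for `tf.FIFOQueue` purely to show the pattern. It also makes one subtlety visible: because the fallback is written as `predict_tower or [0]`, an empty list is treated the same as `None`.

```python
import queue


class QueueInputSketch:
    """Hypothetical stand-in that mirrors only the argument defaults."""

    def __init__(self, input_queue=None, predict_tower=None):
        if input_queue is None:
            # Bounded FIFO buffer; 50 matches the size used in the snippet.
            self.input_queue = queue.Queue(maxsize=50)
        else:
            self.input_queue = input_queue
        # `or` falls back on any falsy value, so [] also becomes [0].
        self.predict_tower = predict_tower or [0]


default = QueueInputSketch()
custom = QueueInputSketch(queue.Queue(maxsize=8), predict_tower=[-1])
empty = QueueInputSketch(predict_tower=[])
```

If an empty list ever needs to be distinguishable from "not given", an explicit `if predict_tower is None` check would be the usual alternative.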