DataFeeder.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:TFCommon 作者: MU94W 项目源码 文件源码
def __init__(self, coordinator, placeholders, meta, batch_size=32, split_nums=None, is_validation=False):
    """Set up the input queue and bookkeeping state for a data feeder.

    Builds a FIFO queue whose components mirror ``placeholders``. Data is
    pushed into it through ``self.enqueue_op`` and consumed through the
    tensors in ``self.fed_holders``, one per placeholder.

    :param coordinator: stored as ``self.coord``; presumably a
        ``tf.train.Coordinator`` used to stop feeder threads (TODO confirm).
    :param placeholders: sequence of placeholder tensors. Their dtypes define
        the queue components, and their static shapes are re-applied to the
        dequeued tensors.
    :param meta: mapping that must contain ``'key_lst'``, a list or tuple of
        sample keys. Its length is taken as the total number of samples.
    :param batch_size: samples per batch. The queue capacity is set to
        ``ceil(batch_size / 4)``.
    :param split_nums: optional value stored as-is; ``self.split_bool``
        records whether it was supplied.
    :param is_validation: whether this feeder serves validation data; must be
        a ``bool``.
    :raises AssertionError: if ``meta['key_lst']`` is missing or is not a list
        or tuple, or if ``is_validation`` is not a ``bool``.
    """
    super(BaseFeeder, self).__init__()
    # Float division plus int() gives an integer ceiling on both Python 2
    # (where int / int truncates) and Python 3.
    queue = tf.FIFOQueue(capacity=int(math.ceil(batch_size / 4.0)),
                         dtypes=[item.dtype for item in placeholders])
    self.queue = queue  # kept for buffer inspection
    self.enqueue_op = queue.enqueue(placeholders)
    fed_holders = queue.dequeue()
    # For a single-component queue, dequeue() returns a bare Tensor rather
    # than a list; normalise so there is always one entry per placeholder.
    if not isinstance(fed_holders, (list, tuple)):
        fed_holders = [fed_holders]
    # The queue is built without `shapes`, so dequeued tensors carry no
    # static shape; restore it from the matching placeholders.
    for holder, placeholder in zip(fed_holders, placeholders):
        holder.set_shape(placeholder.shape)
    self.fed_holders = fed_holders
    self._placeholders = placeholders
    self.coord = coordinator
    self.sess = None  # assigned later, outside this constructor
    self.meta = meta
    key_lst = meta.get('key_lst')
    assert isinstance(key_lst, (list, tuple)), \
        "meta['key_lst'] must be a list or tuple"
    self.key_lst = key_lst
    self.batch_size = batch_size
    self.split_bool = split_nums is not None
    self.split_nums = split_nums
    assert isinstance(is_validation, bool), "is_validation must be a bool"
    self.is_validation = is_validation
    self._total_samples = len(key_lst)
    # Progress / statistics counters, all starting from zero.
    self._iter = 0
    self._record_index = 0
    self._loss = 0.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号