batchloader.py source code

python

Project: Triplet_Loss_SBIR  Author: TuBui
def __init__(self, params):
    # params: dict with keys 'batch_size', 'shape', 'source', 'mean_file', 'scale'
    self.batch_size = params['batch_size']
    self.outshape = params['shape']

    # data source and its label list
    self.lmdb = lmdbs(params['source'])
    self.labels = self.lmdb.get_label_list()
    # mean read from params['mean_file'], with singleton dimensions dropped
    self.img_mean = biproto2py(params['mean_file']).squeeze()

    self.NIMGS = len(self.labels)

    # round up so that a final partial batch is counted
    self.num_batches = int(np.ceil(self.NIMGS / float(self.batch_size)))
    self._cur = 0  # current batch

    # this class does some simple data manipulations
    self.img_augment = SimpleAugment(mean=self.img_mean, shape=params['shape'],
                                     scale=params['scale'])
    # create thread pools for parallel augmentation
    # self.pool = ThreadPool()  # 4
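
The batch count above rounds up, so the last batch may hold fewer than batch_size items. The snippet does not show how _cur is advanced, so the following is only a minimal sketch of one way a cursor like it could walk the batches and wrap around. The helpers count_batches, batch_indices, and next_cursor are hypothetical names, not part of the project, and math.ceil stands in for np.ceil to keep the example dependency-free.

```python
import math


def count_batches(num_items, batch_size):
    """Batches needed to cover num_items; the last one may be partial."""
    return int(math.ceil(num_items / float(batch_size)))


def batch_indices(num_items, batch_size, cur):
    """Item indices belonging to batch number `cur` (hypothetical helper)."""
    start = cur * batch_size
    stop = min(start + batch_size, num_items)  # clip the final partial batch
    return list(range(start, stop))


def next_cursor(cur, num_batches):
    """Advance the batch cursor, wrapping to 0 after the last batch (an assumption)."""
    return (cur + 1) % num_batches


if __name__ == "__main__":
    n_items, size = 10, 4
    n_batches = count_batches(n_items, size)
    cur = 0
    for _ in range(n_batches):
        print(cur, batch_indices(n_items, size, cur))
        cur = next_cursor(cur, n_batches)
```

With 10 items and a batch size of 4, this gives three batches, the last holding only two indices, and the cursor returns to 0 after the third.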