data_layer.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:Triplet_Loss_SBIR 作者: TuBui 项目源码 文件源码
def load_next_batch(self):
    """Load one batch of samples by fanning work out to ``self.pool``.

    Submits ``self.batch_size`` calls to ``self.load_next_image`` via
    ``self.pool.apply_async``. Every call receives the same freshly created
    lock, and the results are gathered in submission order.

    Returns:
        dict with two entries:
          'data'  -- the images stacked with ``np.asarray`` and given an extra
                     singleton axis at position 1. Each image is presumably a
                     2-D array of a common shape (the commented-out loop below
                     filled a (batch, 1, H, W) buffer) -- confirm.
          'label' -- the labels as a float32 array.

    NOTE(review): ``self.pool`` is presumably a thread pool, since a plain
    ``Lock`` object is handed to the workers -- confirm. With a batch size of
    0 the axis-insertion indexing raises ``IndexError``.
    """
    shared_lock = Lock()

    # Fan out: one pending task per sample in the batch.
    pending = [
        self.pool.apply_async(self.load_next_image, (shared_lock,))
        for _ in range(self.batch_size)
    ]

    # Block on each task in turn; the collected list keeps submission order.
    samples = [task.get() for task in pending]

    images = [sample[0] for sample in samples]
    labels = [sample[1] for sample in samples]

    batch = {}
    batch['data'] = np.asarray(images)[:, None, :, :]
    batch['label'] = np.asarray(labels, dtype=np.float32)
    return batch

#==============================================================================
#     res['data'] = np.zeros((self.batch_size,1,self.outshape[0],self.outshape[1]),dtype = np.float32)
#     res['label'] = np.zeros(self.batch_size,dtype = np.float32)
#     for itt in range(self.batch_size):
#       img, label = self.load_next_image(1)
#       res['data'][itt,...] = img
#       res['label'][itt] = label
#     return res
#==============================================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号