data_layer_2branchs.py source code

python

Project: Triplet_Loss_SBIR  Author: TuBui
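def load_next_batch(self):
    """Load one batch of sketch/image pairs in parallel through self.pool."""
    res = {}
    # One lock shared by every job in this batch; load_next_pair uses it to
    # guard shared state. This works with a thread pool (e.g.
    # multiprocessing.pool.ThreadPool): a plain Lock cannot be pickled as a
    # task argument for a process-based multiprocessing.Pool.
    lock = Lock()
    jobs = [self.pool.apply_async(self.load_next_pair, (lock,))
            for _ in range(self.batch_size)]
    thread_res = [job.get() for job in jobs]  # block until every pair is loaded
    # Stack into (N, H, W) arrays, then insert a channel axis: (N, 1, H, W)
    res['data_s'] = np.asarray([tri['sketch'] for tri in thread_res])[:, None, :, :]
    res['data_i'] = np.asarray([tri['image'] for tri in thread_res])[:, None, :, :]
    # Labels become float32 column vectors of shape (N, 1)
    res['label_s'] = np.asarray([tri['label_s'] for tri in thread_res], dtype=np.float32)[:, None]
    res['label_i'] = np.asarray([tri['label_i'] for tri in thread_res], dtype=np.float32)[:, None]
    return res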
#==============================================================================
#     res['data_s'] = np.zeros((self.batch_size,1,self.outshape[0],\
#                             self.outshape[1]),dtype = np.float32)
#     res['data_i'] = np.zeros_like(res['data_s'],dtype=np.float32)
#     res['label'] = np.zeros((self.batch_size,1),dtype = np.float32)
#     for itt in range(self.batch_size):
#       trp = self.load_next_pair(1)
#       res['data_s'][itt,...] = trp['sketch']
#       res['data_i'][itt,...] = trp['image']
#       res['label'][itt,...] = trp['label']
#     return res
#==============================================================================
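The parallel batch-assembly pattern in load_next_batch can be exercised in isolation. The sketch below is not the project's data layer. It assumes a ThreadPool, a hypothetical ToyPairLoader class, a toy load_next_pair that claims the next sample index under the lock and returns constant 8x8 arrays, and integer labels. It only shows how the per-sample jobs are fanned out and how the results are stacked into (N, 1, H, W) data arrays and (N, 1) float32 label arrays.

```python
# Minimal sketch of the batch-assembly pattern; ToyPairLoader, H, W and the
# toy load_next_pair are hypothetical stand-ins, not the original project code.
from multiprocessing.pool import ThreadPool
from threading import Lock

import numpy as np

H, W = 8, 8  # hypothetical output shape (stands in for self.outshape)


class ToyPairLoader:
    def __init__(self, batch_size, num_workers=4):
        self.batch_size = batch_size
        self.pool = ThreadPool(num_workers)
        self._counter = 0  # shared cursor into an imaginary dataset

    def load_next_pair(self, lock):
        # Claim the next sample index under the lock so that two workers
        # never load the same sample.
        with lock:
            idx = self._counter
            self._counter += 1
        sketch = np.full((H, W), idx, dtype=np.float32)
        image = np.full((H, W), -idx, dtype=np.float32)
        return {"sketch": sketch, "image": image,
                "label_s": idx % 3, "label_i": idx % 3}

    def load_next_batch(self):
        lock = Lock()
        jobs = [self.pool.apply_async(self.load_next_pair, (lock,))
                for _ in range(self.batch_size)]
        pairs = [job.get() for job in jobs]  # blocks until every job finishes
        return {
            # (N, H, W) -> (N, 1, H, W)
            "data_s": np.asarray([p["sketch"] for p in pairs])[:, None, :, :],
            "data_i": np.asarray([p["image"] for p in pairs])[:, None, :, :],
            # (N,) -> (N, 1), cast to float32
            "label_s": np.asarray([p["label_s"] for p in pairs], dtype=np.float32)[:, None],
            "label_i": np.asarray([p["label_i"] for p in pairs], dtype=np.float32)[:, None],
        }


loader = ToyPairLoader(batch_size=5)
batch = loader.load_next_batch()
print(batch["data_s"].shape)   # (5, 1, 8, 8)
print(batch["label_s"].shape)  # (5, 1)
loader.pool.close()
loader.pool.join()
```

Because the workers race for indices, row k of the batch is not necessarily sample k. Each returned pair is still internally consistent, since a single job builds its sketch, image and labels from the same claimed index.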