def load_next_batch(self):
    """Assemble one batch by fetching ``self.batch_size`` pairs via ``self.pool``.

    Every pair is produced by ``self.load_next_pair``; all tasks submitted
    for this batch receive the same freshly created lock as their only
    argument.

    Returns:
        A dict with four arrays:

        * ``data_s`` / ``data_i`` -- the ``'sketch'`` / ``'image'`` entries
          of all pairs stacked along axis 0, with a new axis inserted at
          position 1.
        * ``label_s`` / ``label_i`` -- the corresponding labels as float32,
          with a new axis inserted at position 1.
    """
    shared_lock = Lock()

    # Submit everything first, then collect, so the pool can overlap the
    # loads (presumably the pool runs them concurrently -- confirm against
    # how self.pool is constructed).
    pending = []
    for _ in range(self.batch_size):
        pending.append(
            self.pool.apply_async(self.load_next_pair, (shared_lock,))
        )
    # Results are gathered in submission order.
    pairs = [task.get() for task in pending]

    def stack(key, **asarray_kwargs):
        # Gather one field from every pair into a single array.
        return np.asarray([pair[key] for pair in pairs], **asarray_kwargs)

    # The 4-index slice requires the stacked data to be 3-D, i.e. each
    # sketch/image entry must be a 2-D array.
    return {
        'data_s': stack('sketch')[:, None, :, :],
        'data_i': stack('image')[:, None, :, :],
        'label_s': stack('label_s', dtype=np.float32)[:, None],
        'label_i': stack('label_i', dtype=np.float32)[:, None],
    }
#==============================================================================
# res['data_s'] = np.zeros((self.batch_size,1,self.outshape[0],\
# self.outshape[1]),dtype = np.float32)
# res['data_i'] = np.zeros_like(res['data_a'],dtype=np.float32)
# res['label'] = np.zeros((self.batch_size,1),dtype = np.float32)
# for itt in range(self.batch_size):
# trp = self.load_next_pair(1)
# res['data_s'][itt,...] = trp['sketch']
# res['data_i'][itt,...] = trp['image']
# res['label'][itt,...] = trp['label']
# return res
#==============================================================================
# Comment list
# Table of contents