def save_output(self, net_op, batch_size, num_cols=8, net_recon_const=None):
    """Tile a batch of predicted Lab images into a single BGR mosaic.

    For each sample ``i`` in ``range(batch_size)``:

    * channel L comes from ``net_recon_const[i]``, reshaped to
      ``(self.outshape[0], self.outshape[1])``;
    * channel a comes from the first ``np.prod(self.shape)`` entries of
      ``net_op[i]``, and channel b from the remaining entries. Each is
      reshaped to ``(self.shape[0], self.shape[1])``.

    Each plane is passed through ``self.__get_decoded_img`` and stored in a
    uint8 Lab buffer, which is converted with ``cv2.COLOR_LAB2BGR``. The
    converted tiles are placed row-major on a grid ``num_cols`` tiles wide.

    NOTE(review): ``__get_decoded_img`` is not visible here. It must return
    an array assignable to an ``(outshape[0], outshape[1])`` uint8 plane,
    presumably rescaling and resizing the network output. Confirm against
    its definition.

    Args:
        net_op: Per-sample network outputs, indexed as ``net_op[i, :k]`` /
            ``net_op[i, k:]`` with ``k = np.prod(self.shape)``.
        batch_size: Number of samples (rows of ``net_op``) to tile.
        num_cols: Number of tiles per mosaic row.
        net_recon_const: Per-sample L channel, indexed as
            ``net_recon_const[i, ...]``. Required despite the ``None``
            default, which is kept only for signature compatibility.

    Returns:
        A uint8 array of shape
        ``(num_rows * outshape[0], num_cols * outshape[1], 3)``, where
        ``num_rows = ceil(batch_size / num_cols)``. Tiles with no sample
        (the tail of the last row) are left black.

    Raises:
        ValueError: If ``net_recon_const`` is None.
    """
    if net_recon_const is None:
        # The L channel cannot be recovered from net_op. Without this check,
        # the loop below fails with an opaque TypeError from indexing None.
        raise ValueError('net_recon_const (the L channel) is required')

    tile_h, tile_w = self.outshape[0], self.outshape[1]
    # Length of one flattened chroma plane. net_op[i] holds a then b.
    ab_len = np.prod(self.shape)

    # Exact integer ceil(batch_size / num_cols); avoids a float round-trip.
    num_rows = -(-batch_size // num_cols)
    out_img = np.zeros((num_rows * tile_h, num_cols * tile_w, 3), dtype='uint8')
    # Scratch Lab buffer reused across samples. cvtColor returns a new array,
    # so reuse does not alias earlier tiles.
    img_lab = np.zeros((tile_h, tile_w, 3), dtype='uint8')

    for i in range(batch_size):
        r, c = divmod(i, num_cols)  # row-major grid position
        img_lab[..., 0] = self.__get_decoded_img(net_recon_const[i, ...].reshape(tile_h, tile_w))
        img_lab[..., 1] = self.__get_decoded_img(net_op[i, :ab_len].reshape(self.shape[0], self.shape[1]))
        img_lab[..., 2] = self.__get_decoded_img(net_op[i, ab_len:].reshape(self.shape[0], self.shape[1]))
        img_bgr = cv2.cvtColor(img_lab, cv2.COLOR_LAB2BGR)
        out_img[r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w, ...] = img_bgr
    return out_img
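# 评论列表 (Comment list) — navigation text left over from the source web page
# 文章目录 (Table of contents) — navigation text left over from the source web page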