colordata.py source code

python

Project: pytorch_divcolor  Author: aditya12agd5
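import cv2
import numpy as np

def __tiledoutput__(self, net_op, batch_size, num_cols=8, net_recon_const=None):
    # Method (takes self): tiles a batch of Lab predictions into one BGR image grid.
    # net_recon_const supplies the L channel and is required despite its None default.
    if net_recon_const is None:
        raise ValueError('net_recon_const (L channel) must be provided')

    num_rows = int(np.ceil(batch_size / float(num_cols)))
    out_img = np.zeros((num_rows*self.outshape[0], num_cols*self.outshape[1], 3), dtype='uint8')
    img_lab = np.zeros((self.outshape[0], self.outshape[1], 3), dtype='uint8')
    c = 0
    r = 0

    for i in range(batch_size):
        # Start a new grid row once the current one holds num_cols tiles.
        if i % num_cols == 0 and i > 0:
            r = r + 1
            c = 0
        # L comes from net_recon_const; a and b come from the two channels of net_op.
        # The a/b maps are reshaped with self.shape, so __decodeimg__ must return an
        # array that fits img_lab (outshape), presumably by resizing.
        img_lab[..., 0] = self.__decodeimg__(net_recon_const[i, 0, :, :].reshape(
            self.outshape[0], self.outshape[1]))
        img_lab[..., 1] = self.__decodeimg__(net_op[i, 0, :, :].reshape(
            self.shape[0], self.shape[1]))
        img_lab[..., 2] = self.__decodeimg__(net_op[i, 1, :, :].reshape(
            self.shape[0], self.shape[1]))
        # COLOR_LAB2BGR yields BGR channel order (OpenCV's default), not RGB.
        img_bgr = cv2.cvtColor(img_lab, cv2.COLOR_LAB2BGR)
        out_img[r*self.outshape[0]:(r+1)*self.outshape[0],
                c*self.outshape[1]:(c+1)*self.outshape[1], ...] = img_bgr
        c = c + 1

    return out_img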
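
The tiling step of the method above can be tried in isolation. The following is a minimal sketch, not code from pytorch_divcolor: `tile_images` is a hypothetical helper name, it assumes the batch is already an array of equally sized H x W x 3 uint8 images, and it leaves out the Lab decoding and the OpenCV color conversion so that it depends only on NumPy.

```python
import numpy as np

def tile_images(images, num_cols=8):
    """Place a batch of equally sized HxWx3 images into a row-major grid.

    Hypothetical helper for illustration; unused grid slots stay zero (black).
    """
    batch_size, height, width, channels = images.shape
    num_rows = -(-batch_size // num_cols)  # integer ceiling of batch_size / num_cols
    grid = np.zeros((num_rows * height, num_cols * width, channels), dtype=images.dtype)
    for i in range(batch_size):
        r, c = divmod(i, num_cols)  # grid row and column of image i
        grid[r * height:(r + 1) * height, c * width:(c + 1) * width] = images[i]
    return grid

# Ten 4x6 images, each filled with its own index, in a grid 8 columns wide.
batch = np.stack([np.full((4, 6, 3), i, dtype=np.uint8) for i in range(10)])
grid = tile_images(batch, num_cols=8)
print(grid.shape)  # (8, 48, 3): two rows of eight 4x6 tiles
```

Using `divmod` replaces the manual `r`/`c` counters of the original loop; both approaches fill the grid in the same row-major order.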