def __init__(self, patchsize, source, binary_mask=None,
             random_order=False, mirrored=True, max_num=None):
    """Store patch-extraction settings and normalise the input arrays.

    Args:
        patchsize: Stored unchanged as ``self.patchsize``. Its expected
            form (int vs. tuple) is not visible here -- confirm with the
            code that consumes it.
        source: NumPy array holding the source data. A float32 copy is
            stored; a 2-D array gets a trailing singleton axis so that
            ``self.source`` is at least 3-D.
        binary_mask: Optional NumPy array. A 2-D mask gets a trailing
            singleton axis. When given, ``self.num_patches`` is the number
            of strictly positive entries.
        random_order: Stored unchanged. Presumably controls whether patches
            are visited in random order -- TODO confirm.
        mirrored: Stored unchanged. Presumably enables mirrored/flipped
            patches -- TODO confirm.
        max_num: Stored unchanged. Presumably caps the number of patches
            used -- TODO confirm.
    """
    self.patchsize = patchsize
    # astype always returns a new array, so later in-place edits of
    # self.source never touch the caller's array.
    self.source = source.astype(np.float32)
    self.mask = binary_mask
    self.random_order = random_order
    self.mirrored = mirrored
    self.max_num = max_num

    # Give single-channel inputs an explicit trailing channel axis so both
    # arrays can be indexed as (rows, cols, channels).
    if self.source.ndim == 2:
        self.source = self.source[:, :, np.newaxis]
    if self.mask is not None and self.mask.ndim == 2:
        self.mask = self.mask[:, :, np.newaxis]

    if self.mask is not None:
        self.num_patches = (self.mask > 0).sum()
    else:
        # np.product was removed in NumPy 2.0; np.prod is the replacement.
        # NOTE(review): this counts every element including the channel
        # axis (rows*cols*channels), whereas the masked branch counts mask
        # entries -- confirm this asymmetry is intended for multi-channel
        # sources.
        self.num_patches = np.prod(self.source.shape)
评论列表
文章目录