def predict_masks(args, hps, store, to_predict: List[str], threshold: float,
                  validation: str=None, no_edges: bool=False):
    """Predict masks for the given image ids and save them, thresholded, to *store*.

    The model is built from *hps* and restored from ``args.model_path`` when
    that is set, otherwise from the last snapshot found in ``args.logdir``.
    Each mask is written with ``np.save`` into a gzip-compressed file located
    by ``mask_path(store, im_id)``.

    Args:
        args: parsed options; ``model_path`` and ``logdir`` are read here.
        hps: hyperparameters passed to ``Model``; ``n_channels`` is read here.
        store: passed to ``mask_path`` to locate each output file.
        to_predict: image ids to process; they are fed to the loader in
            sorted order.
        threshold: masks are saved as ``mask >= threshold`` (a boolean array).
        validation: when ``'square'``, each preprocessed image is passed
            through ``square(data, hps)`` before prediction.
        no_edges: forwarded to ``model.predict_image_mask``.

    Raises:
        ValueError: if a predicted mask's dimensions after the first differ
            from those of its input image (previously an ``assert``, which
            is skipped under ``python -O``).
    """
    sorted_ids = sorted(to_predict)
    logger.info('Predicting %d masks: %s',
                len(sorted_ids), ', '.join(sorted_ids))

    model = Model(hps=hps)
    if args.model_path:
        model.restore_snapshot(args.model_path)
    else:
        model.restore_last_snapshot(args.logdir)

    def load_im(im_id):
        # Load one image, preprocess it with the model, and wrap it in Image.
        data = model.preprocess_image(utils.load_image(im_id))
        if hps.n_channels != data.shape[0]:
            # Keep only the first n_channels entries along axis 0 (presumably
            # a channel-first layout -- confirm). If the image has fewer
            # channels than n_channels this slice leaves it unchanged.
            data = data[:hps.n_channels]
        if validation == 'square':
            data = square(data, hps)
        return Image(id=im_id, data=data)

    def predict_mask(im):
        # Run the model on one loaded image; returns (image, raw mask).
        logger.info(im.id)
        return im, model.predict_image_mask(im.data, no_edges=no_edges)

    # Stage 1: load images via a buffered imap with 2 threads.
    # Stage 2: `map` lazily applies the model to each loaded image.
    im_masks = map(predict_mask, utils.imap_fixed_output_buffer(
        load_im, sorted_ids, threads=2))
    # Pull stage 2 through another buffered imap with one thread. Presumably
    # this lets prediction of the next image overlap with saving the current
    # one (confirm in utils). The iterable here only determines how many
    # times next() is called, so its element values are ignored.
    for im, mask in utils.imap_fixed_output_buffer(
            lambda _: next(im_masks), sorted_ids, threads=1):
        if mask.shape[1:] != im.data.shape[1:]:
            raise ValueError(
                'Mask shape {} does not match image shape {} for {}'
                .format(mask.shape, im.data.shape, im.id))
        with gzip.open(str(mask_path(store, im.id)), 'wb') as f:
            # TODO - maybe do (mask * 20).astype(np.uint8)
            np.save(f, mask >= threshold)
评论列表
文章目录