make_submission.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:kaggle-dstl 作者: lopuhin 项目源码 文件源码
def predict_masks(args, hps, store, to_predict: List[str], threshold: float,
                  validation: str=None, no_edges: bool=False):
    """Predict a mask for each image id and save the thresholded result.

    A ``Model`` is built from ``hps`` and restored from ``args.model_path``
    when that is set, otherwise from the last snapshot in ``args.logdir``.
    Images are processed in sorted id order. Each predicted mask is compared
    against ``threshold`` and the resulting boolean array is written with
    ``np.save`` into a gzip file at ``mask_path(store, im.id)``.

    Args:
        args: Object providing ``model_path`` and ``logdir``.
        hps: Hyper-parameters; ``n_channels`` is read here, and the object
            is also passed to ``Model`` and ``square``.
        store: Passed to ``mask_path`` to locate the output file.
        to_predict: Ids of the images to predict.
        threshold: Mask values ``>= threshold`` are saved as True.
        validation: When equal to ``'square'``, image data is passed through
            ``square(data, hps)`` before prediction.
        no_edges: Forwarded to ``model.predict_image_mask``.
    """
    image_ids = sorted(to_predict)
    logger.info('Predicting {} masks: {}'.format(
        len(to_predict), ', '.join(image_ids)))

    model = Model(hps=hps)
    if not args.model_path:
        model.restore_last_snapshot(args.logdir)
    else:
        model.restore_snapshot(args.model_path)

    def _load(image_id):
        # Preprocess the raw image and keep only the channels the model uses.
        pixels = model.preprocess_image(utils.load_image(image_id))
        if pixels.shape[0] != hps.n_channels:
            pixels = pixels[:hps.n_channels]
        if validation == 'square':
            pixels = square(pixels, hps)
        return Image(id=image_id, data=pixels)

    def _predict(image):
        logger.info(image.id)
        predicted = model.predict_image_mask(image.data, no_edges=no_edges)
        return image, predicted

    # Stage 1: images are loaded through the buffered helper with threads=2
    # (presumably in background workers -- confirm against utils), and
    # prediction is applied lazily as each loaded image is consumed.
    loaded_images = utils.imap_fixed_output_buffer(
        _load, image_ids, threads=2)
    predictions = map(_predict, loaded_images)

    def _next_prediction(_unused):
        # The argument is ignored: ``to_predict`` below is only used so that
        # one prediction is pulled per image id.
        return next(predictions)

    # Stage 2: predictions are pulled through the same helper with threads=1
    # while this loop writes the results to disk.
    results = utils.imap_fixed_output_buffer(
        _next_prediction, to_predict, threads=1)
    for image, predicted in results:
        assert predicted.shape[1:] == image.data.shape[1:]
        out_path = str(mask_path(store, image.id))
        with gzip.open(out_path, 'wb') as out_file:
            # TODO - maybe do (mask * 20).astype(np.uint8)
            np.save(out_file, predicted >= threshold)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号