union.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:Seg 作者: gxd1994 项目源码 文件源码
def get_seg_result(self, det_preds, imgs, imgname_list, stride=128):
        """Segment per-image patches with Seg_Net and save one stitched mask per image.

        A fresh TF graph is built around ``Seg_Net``, its weights are restored
        from ``SEG_MODEL_PATH``, and the patches returned by
        ``self._get_seg_data`` are run through the network in chunks of
        ``SEG_BATCH_SIZE``. Each patch prediction is pasted into a full-size
        ``uint8`` mask at its grid position. Pixels with predicted label 1 are
        scaled to 255, and the mask is written to ``SAVE_SEG_DIR`` under the
        matching name from ``imgname_list``.

        Args:
            det_preds: detection predictions, forwarded unchanged to
                ``self._get_seg_data`` (format defined there).
            imgs: image stack. Every mask is sized
                ``(imgs.shape[1], imgs.shape[2])``, i.e. rows x columns.
            imgname_list: output file names, aligned index-by-index with the
                per-image patch lists returned by ``_get_seg_data``.
            stride: patch side length in pixels. This is also the spatial size
                of the network input placeholder, so it must match the size
                the model was trained on.

        Returns:
            None. Masks are written to disk as a side effect.
        """
        with tf.Graph().as_default():
            image_batch = tf.placeholder(dtype=tf.float32, shape=[None, stride, stride, 3], name='image_batch_seg')
            is_training = tf.placeholder(tf.bool, shape=[], name='is_training')

            with tf.variable_scope('Seg_Net'):
                seg_net = Seg_Net()
                logits = seg_net.inference(image_batch, is_training)
                seg_pred = seg_net.eval(logits=logits)

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                # restore() overwrites the freshly initialised variables with checkpoint weights.
                saver = tf.train.Saver()
                saver.restore(sess, SEG_MODEL_PATH)
                # Single-argument call form: identical output on Python 2, valid syntax on Python 3.
                print("Restored model parameters from {}".format(SEG_MODEL_PATH))

                all_patch_np_list, all_patch_cood_list = self._get_seg_data(det_preds, imgs, stride)

                # Loop-invariant: make sure the output directory exists once, not per image.
                if not os.path.exists(SAVE_SEG_DIR):
                    os.makedirs(SAVE_SEG_DIR)

                height, width = imgs.shape[1], imgs.shape[2]
                for i, (patchs, coods) in enumerate(zip(all_patch_np_list, all_patch_cood_list)):
                    mask = np.zeros(shape=(height, width), dtype=np.uint8)
                    n_patches = patchs.shape[0]

                    # Feed patches in chunks of SEG_BATCH_SIZE; the final chunk may be shorter.
                    for start in range(0, n_patches, SEG_BATCH_SIZE):
                        end = min(start + SEG_BATCH_SIZE, n_patches)
                        input_coods = coods[start:end]
                        seg_preds = sess.run(seg_pred, feed_dict={image_batch: patchs[start:end], is_training: False})
                        # NOTE(review): assumes output is (batch, stride, stride, 1); drop the channel axis.
                        seg_preds = np.squeeze(seg_preds, axis=3)

                        for k, patch_pred in enumerate(seg_preds):
                            # Coordinates are grid indices; multiply by stride for pixel offsets.
                            y0 = input_coods[k][0] * stride
                            x0 = input_coods[k][1] * stride
                            # NOTE(review): assumes _get_seg_data only yields patches fully inside
                            # the image; a patch crossing the border would fail to broadcast here.
                            mask[y0:y0 + stride, x0:x0 + stride] = patch_pred

                    # Stretch label 1 to 255 so the saved mask is visible as an image.
                    mask[mask == 1] = 255

                    out_path = os.path.join(SAVE_SEG_DIR, imgname_list[i])
                    # cv2.imwrite signals failure only through its return value; don't drop it silently.
                    if not cv2.imwrite(out_path, mask):
                        print("Warning: failed to write segmentation mask to {}".format(out_path))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号