def get_seg_result(self,det_preds,imgs,imgname_list,stride=128):
    """Segment image patches with ``Seg_Net`` and save one mask per image.

    Builds a fresh TensorFlow graph holding ``Seg_Net`` under the
    ``'Seg_Net'`` variable scope, restores its variables from
    ``SEG_MODEL_PATH``, and runs the patches returned by
    ``self._get_seg_data`` through the network in batches of at most
    ``SEG_BATCH_SIZE``. Each patch prediction is pasted into a full-size
    ``uint8`` mask at its grid position, and the mask is written to
    ``SAVE_SEG_DIR`` under the matching entry of ``imgname_list``.

    Args:
        det_preds: Forwarded unchanged to ``self._get_seg_data``; this
            method does not inspect it.
        imgs: Array whose ``shape[1]`` and ``shape[2]`` are used as the mask
            height and width (presumably an (N, H, W, C) image batch --
            confirm against the caller).
        imgname_list: Output file names, indexed in the same order as the
            per-image patch lists returned by ``self._get_seg_data``.
        stride: Side length of the square network input, and the step used
            to turn a patch's grid coordinate into a pixel offset.

    Returns:
        None. The masks are written to disk as a side effect.
    """
    with tf.Graph().as_default():
        image_batch = tf.placeholder(
            dtype=tf.float32,
            shape=[None, stride, stride, 3],
            name='image_batch_seg')
        is_training = tf.placeholder(tf.bool, shape=[], name='is_training')

        with tf.variable_scope('Seg_Net'):
            seg_net = Seg_Net()
            logits = seg_net.inference(image_batch, is_training)
            seg_pred = seg_net.eval(logits=logits)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, SEG_MODEL_PATH)
            # Single-argument print() emits the same text on Python 2 and 3.
            print("Restored model parameters from {}".format(SEG_MODEL_PATH))

            all_patch_np_list, all_patch_cood_list = self._get_seg_data(
                det_preds, imgs, stride)

            # The output directory does not depend on the image being
            # processed, so make sure it exists once instead of per image.
            if not os.path.exists(SAVE_SEG_DIR):
                os.makedirs(SAVE_SEG_DIR)

            mask_shape = (imgs.shape[1], imgs.shape[2])
            per_image = zip(all_patch_np_list, all_patch_cood_list)
            for i, (patchs, coods) in enumerate(per_image):
                mask = np.zeros(shape=mask_shape, dtype=np.uint8)

                # Walk the patches in SEG_BATCH_SIZE chunks; slicing clamps
                # the final, possibly shorter, chunk automatically.
                for start in range(0, patchs.shape[0], SEG_BATCH_SIZE):
                    end = start + SEG_BATCH_SIZE
                    input_batch = patchs[start:end]
                    input_coods = coods[start:end]

                    seg_preds = sess.run(
                        seg_pred,
                        feed_dict={image_batch: input_batch,
                                   is_training: False})
                    print("seg_preds.shape {}".format(seg_preds.shape))
                    # Drop the trailing axis so each prediction is 2-D.
                    seg_preds = np.squeeze(seg_preds, axis=3)

                    for k, tile in enumerate(seg_preds):
                        # Coordinates are grid indices (row, column); scale
                        # by stride to get the pixel window in the mask.
                        y_start = input_coods[k][0]
                        x_start = input_coods[k][1]
                        mask[y_start * stride:(y_start + 1) * stride,
                             x_start * stride:(x_start + 1) * stride] = tile

                # Remap label 1 to 255 before saving (presumably so the
                # foreground is visible in the written image).
                mask[mask == 1] = 255
                cv2.imwrite(os.path.join(SAVE_SEG_DIR, imgname_list[i]), mask)
评论列表
文章目录