def train_visualization_seg(self, model, epoch):
    """Overlay the model's predicted mask on one training image and save it.

    Picks the last image (in sorted path order) matching
    ``<flag.data_path>/train/IMAGE/*/*.png``, runs ``model.predict`` on it,
    prints the prediction time, colour-maps the prediction with
    ``COLORMAP_JET``, blends it over the input image and writes the result
    to ``<flag.output_dir>/<epoch as %04d>_<image basename>``.

    Args:
        model: object exposing ``predict(x, batch_size)``; presumably a
            Keras model (the second positional argument of Keras
            ``Model.predict`` is ``batch_size``) -- confirm against caller.
        epoch: integer used as the zero-padded output-file prefix.

    Raises:
        IOError: if no training image is found or it cannot be read.
    """
    flag = self.flag
    image_size = flag.image_size

    image_name_list = sorted(
        glob(os.path.join(flag.data_path, 'train/IMAGE/*/*.png')))
    if not image_name_list:
        raise IOError('No training images found under %s'
                      % os.path.join(flag.data_path, 'train/IMAGE'))
    image_name = image_name_list[-1]

    img_input = cv2.imread(image_name, flag.color_mode)
    if img_input is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise IOError('Could not read image: %s' % image_name)

    # With OpenCV's IMREAD_GRAYSCALE (0) / IMREAD_COLOR (1) flags this gives
    # 1 or 3 channels. The image must already be image_size x image_size,
    # otherwise reshape raises ValueError.
    channels = flag.color_mode * 2 + 1
    input_data = img_input.reshape((1, image_size, image_size, channels))

    t_start = cv2.getTickCount()
    result = model.predict(input_data, 1)  # 1 = batch size (positional)
    t_total = (cv2.getTickCount() - t_start) / cv2.getTickFrequency() * 1000
    print("[*] Predict Time: %.3f ms" % t_total)

    # Assumes predictions lie in [0, 1]; clip so out-of-range values
    # saturate instead of wrapping around when cast to uint8.
    img_mask = np.clip(result[0] * 255, 0, 255).astype(np.uint8)
    img_mask_color = cv2.applyColorMap(img_mask, cv2.COLORMAP_JET)

    # applyColorMap yields a 3-channel BGR image, so the input must be BGR
    # too. Only single-channel input needs converting; 3-channel input is
    # already BGR.
    if img_input.ndim == 2 or img_input.shape[2] == 1:
        img_show = cv2.cvtColor(img_input, cv2.COLOR_GRAY2BGR)
    else:
        img_show = img_input
    img_show = cv2.addWeighted(img_show, 0.9, img_mask_color, 0.4, 0.0)

    output_path = os.path.join(
        flag.output_dir, '%04d_' % epoch + os.path.basename(image_name))
    # cv2.imwrite returns False on failure instead of raising. Warn rather
    # than abort, so a visualisation problem does not stop training.
    if not cv2.imwrite(output_path, img_show):
        print("[!] Failed to write visualization: %s" % output_path)
callbacks.py 文件源码
python
阅读 29
收藏 0
点赞 0
评论 0
评论列表
文章目录