def compare_io(X_val, model_raw, directory=None):
    """Render model predictions as images and save them next to the source images.

    Builds an image tensor from the model's raw output so it can be compared
    with the source images. Both are then handed to ``road.save_images``,
    which (per the original notes) plots them side by side and saves the
    plots. Drawing the predicted curves makes the model output easier for an
    audience to understand than the raw output vectors.

    Args:
        X_val: Source images, passed straight through to ``road.save_images``.
            NOTE(review): presumably laid out like ``model_view`` below
            (N, height, width, channels) — confirm against ``save_images``.
        model_raw: Raw model output. One image is rendered per entry along
            axis 0 (``model_raw.shape[0]`` images in total).
        directory: Output directory for the comparison plots. Defaults to
            ``cfg['dir']['validation'] + '/image_comparison'``, resolved when
            the function is called rather than when the module is imported.

    FIXME: function is still built to work like v1 generation; the plotter
    function may also have bugs.
    """
    # Resolve the default lazily. A default expression in the signature is
    # evaluated once at definition time, which would require ``cfg`` to exist
    # at import and would ignore any later change to the config.
    if directory is None:
        directory = '%s/image_comparison' % cfg['dir']['validation']

    road = Roadgen(cfg)
    curves_to_print = model_raw.shape[0]

    # Reshape (and, per the log line below, denormalize) the raw output vector
    # into per-curve values that road_generator can draw.
    model_out = road.model_interpret(model_raw)
    print('predictions denormalized')

    # road.input_size is given to cv2.resize as dsize = (width, height), so
    # each preallocated image is laid out (height, width, channels).
    image_shape = (road.input_size[1], road.input_size[0], road.n_channels)
    model_view = np.zeros((curves_to_print,) + image_shape, dtype=np.uint8)

    for prnt_i in range(curves_to_print):
        # NOTE(review): rand_gen=0 presumably disables random variation in the
        # generator so the drawing reflects the prediction only — confirm.
        patch = road.road_generator(model_out[prnt_i], road.line_width, rand_gen=0)
        resized = cv2.resize(patch, road.input_size, interpolation=cv2.INTER_AREA)
        # cv2.resize drops a trailing singleton channel axis, returning (h, w)
        # for single-channel input. Restore the (h, w, c) shape so the
        # assignment also works when n_channels == 1 (no-op otherwise).
        model_view[prnt_i] = resized.reshape(image_shape)

    road.save_images(X_val, model_view, directory)
评论列表
文章目录