def draw_net(net_proto_file, out_img_file, style='TB'):
    """Draw a CNN network definition into an image file.

    Args:
        net_proto_file: Path to the text-format net definition file.
        out_img_file: Path of the output image to write.
        style: Layout direction, 'TB' for top -> bottom or 'LR' for
            left -> right.
    """
    net = caffe_pb2.NetParameter()
    # Read through a context manager so the file handle is closed
    # deterministically, even if parsing the definition raises.
    with open(net_proto_file) as proto_file:
        text_format.Merge(proto_file.read(), net)
    # Fall back to a default name when the definition leaves it empty.
    if not net.name:
        net.name = 'cnn_net'
    print('\nDrawing net to %s' % out_img_file)
    caffe.draw.draw_net_to_file(net, out_img_file, style)
# Comment list
# Article table of contents