def vis_graph(graph, name='net2net', show=False):
    """Dump *graph* as JSON and draw it to a PNG under ``<root_dir>/output``.

    The output directory is ``<root_dir>/output/<dirname(name)>``. When
    *name* has no directory part, ``basename(name)`` itself is used as the
    sub-directory. The directory is created with ``mkdir_p(..., delete=False)``.
    Two files are written there:

    * ``<basename(name)>_graph.json`` -- the string returned by
      ``graph.to_json()``;
    * ``graph.png`` -- a ``networkx`` drawing of *graph* with node labels.

    Args:
        graph: Graph object; must provide ``to_json()`` and be accepted by
            ``nx.draw``.
        name: Output path stem, optionally prefixed with directories.
        show: If true, also display the figure with ``plt.show()``.

    Errors raised while drawing, saving or showing the figure are logged as
    warnings and otherwise ignored, so visualisation is best-effort. Errors
    from creating the directory or writing the JSON file propagate.
    """
    path = osp.dirname(name)
    name = osp.basename(name)
    if path == '':
        path = name
    out_dir = osp.join(root_dir, "output", path)
    mkdir_p(out_dir, delete=False)

    # Use explicit paths instead of os.chdir(): chdir mutates process-wide
    # state and previously was never undone if an exception escaped.
    with open(osp.join(out_dir, name + "_graph.json"), "w") as f:
        f.write(graph.to_json())

    try:
        plt.close('all')
        nx.draw(graph, with_labels=True)
        # Save BEFORE showing: with a non-interactive backend plt.show()
        # blocks and the figure is discarded once the window closes, so a
        # savefig() issued afterwards writes a blank image.
        plt.savefig(osp.join(out_dir, 'graph.png'))
        if show:
            plt.show()
    except Exception as inst:
        # Drawing is best-effort; report the problem but do not abort.
        logger.warning(inst)
评论列表
文章目录