def test_save_normal_graphs(self):
    """Build SegNet at every depth, check its children, and dump graphs.

    For each ``depth`` in ``1..self.n_encdec`` a fresh ``segnet.SegNet`` is
    built and run forward on a random float32 input. Its computational graph
    is then serialized with ``build_computational_graph(...).dump()``. The
    test asserts that the model has an ``encdec<k>`` child for every
    ``k`` in ``1..self.n_encdec``.

    Side effects: writes ``tests/SegNet_x_depth-<n_encdec>_<depth>.dot``.
    The write is skipped when that file already exists, so an existing graph
    is never regenerated. After writing, it tries to render the DOT file to a
    sibling ``.png`` with the Graphviz ``dot`` executable. That rendering is
    best-effort: its exit status is ignored, and a missing executable is
    tolerated.
    """
    # Uniform noise in [-1, 1). ``self.x_shape[1]`` is passed as the input
    # channel count, so x_shape is presumably (N, C, H, W) -- TODO confirm.
    x = np.random.uniform(-1, 1, self.x_shape)
    x = Variable(x.astype(np.float32))
    for depth in six.moves.range(1, self.n_encdec + 1):
        model = segnet.SegNet(
            n_encdec=self.n_encdec, in_channel=self.x_shape[1])
        y = model(x, depth)
        cg = build_computational_graph(
            [y],
            variable_style=_var_style,
            function_style=_func_style
        ).dump()
        # All n_encdec stages must be registered as children, independent
        # of how many of them this forward pass (``depth``) used.
        for e in six.moves.range(1, self.n_encdec + 1):
            self.assertIn('encdec{}'.format(e), model._children)

        fn = 'tests/SegNet_x_depth-{}_{}.dot'.format(self.n_encdec, depth)
        if os.path.exists(fn):
            continue
        with open(fn, 'w') as f:
            f.write(cg)
        # Swap only the extension; str.replace would also rewrite any
        # '.dot' substring elsewhere in the path.
        png = os.path.splitext(fn)[0] + '.png'
        try:
            # Pass an argument list (no shell) so the paths are never
            # parsed by a shell.
            subprocess.call(['dot', '-Tpng', fn, '-o', png])
        except OSError:
            # Graphviz is not installed. Rendering is optional: the original
            # shell invocation also ignored a failing ``dot``.
            pass
评论列表
文章目录