test_segnet.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:chainer-segnet 作者: pfnet-research 项目源码 文件源码
def test_save_normal_graphs(self):
        """Build SegNet at every depth and dump each computational graph.

        For each ``depth`` in ``1..self.n_encdec``:

        1. Construct a fresh ``segnet.SegNet``.
        2. Run it forward on a random float32 input of shape
           ``self.x_shape`` (values drawn from ``[-1, 1)``).
        3. Assert that every ``encdec1`` .. ``encdec{n_encdec}`` child is
           registered on the model.
        4. Dump the graph (styled with ``_var_style`` / ``_func_style``) to
           ``tests/SegNet_x_depth-<n_encdec>_<depth>.dot``.
        5. Render it to PNG with the external ``dot`` command.

        Existing ``.dot`` files are left untouched. PNG rendering is best
        effort: if ``dot`` cannot be executed the test does not fail.
        """
        # NOTE(review): x_shape[1] is passed as in_channel, so x_shape is
        # presumably (batch, channels, height, width) — confirm in setUp.
        x = np.random.uniform(-1, 1, self.x_shape)
        x = Variable(x.astype(np.float32))

        for depth in six.moves.range(1, self.n_encdec + 1):
            model = segnet.SegNet(
                n_encdec=self.n_encdec, in_channel=self.x_shape[1])
            y = model(x, depth)
            cg = build_computational_graph(
                [y],
                variable_style=_var_style,
                function_style=_func_style
            ).dump()
            # The model is always built with n_encdec blocks, so every
            # encdec child must exist regardless of the depth used above.
            for e in six.moves.range(1, self.n_encdec + 1):
                self.assertIn('encdec{}'.format(e), model._children)

            fn = 'tests/SegNet_x_depth-{}_{}.dot'.format(self.n_encdec, depth)
            # Keep graphs written by an earlier run instead of overwriting.
            if os.path.exists(fn):
                continue
            with open(fn, 'w') as f:
                f.write(cg)
            # Swap only the extension (str.replace would hit every '.dot').
            png = os.path.splitext(fn)[0] + '.png'
            try:
                # Argument list, no shell: nothing is parsed by /bin/sh.
                subprocess.call(['dot', '-Tpng', fn, '-o', png])
            except OSError:
                # `dot` is not installed or not executable. Rendering is
                # optional and the .dot file has already been written.
                pass
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号