graph.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:MMdnn 作者: Microsoft 项目源码 文件源码
def compute_output_shapes(self, model):
        """Assign an ``output_shape`` to every node in the graph.

        The model is serialized with ``text_format.MessageToString`` into a
        temporary ``.prototxt`` file whose path is stored on ``self.prototxt``.

        When ``has_pycaffe()`` is true, the network is loaded through
        ``caffe.Net`` and each blob's shape is copied onto the node of the
        same name (padded with trailing 1s up to four dimensions). Nodes that
        still have no shape afterwards fall back to
        ``NodeKind.compute_output_shape``. The temporary file is removed once
        this path finishes, whether or not it raised.

        When pycaffe is not available, every node's shape comes from
        ``NodeKind.compute_output_shape`` and the temporary file is left on
        disk, exactly as before, so ``self.prototxt`` keeps pointing at an
        existing file.

        Args:
            model: The message object passed to
                ``text_format.MessageToString`` for serialization.
        """
        sorted_nodes = self.topologically_sorted()
        (tmp_handle, tmp_prototxt) = tempfile.mkstemp(suffix=".prototxt")
        # Write through the descriptor mkstemp already opened instead of
        # opening the path a second time. The ``with`` block closes the
        # descriptor on every code path, so it no longer leaks when pycaffe
        # is missing or when loading the network fails.
        with os.fdopen(tmp_handle, 'w') as f:
            f.write(text_format.MessageToString(model))
        self.prototxt = tmp_prototxt

        if not has_pycaffe():
            for node in sorted_nodes:
                node.output_shape = TensorShape(*NodeKind.compute_output_shape(node))
            return

        try:
            caffe = get_caffe_resolver().caffe
            net = caffe.Net(tmp_prototxt, caffe.TEST)
            for key, value in net.blobs.items():
                try:
                    node = self.get_node(key)
                    dims = list(value.shape)
                    # Pad with trailing 1s so shapes of rank < 4 become 4-D.
                    dims = dims + [1] * (4 - len(dims))
                    node.output_shape = TensorShape(*dims)
                except Exception:
                    # Best effort: a blob that cannot be mapped to a node (or
                    # to a TensorShape) is skipped and handled by the
                    # fallback loop below. KeyboardInterrupt/SystemExit are
                    # no longer swallowed.
                    continue
            for node in sorted_nodes:
                if node.output_shape is None:
                    node.output_shape = TensorShape(*NodeKind.compute_output_shape(node))
        finally:
            # Always delete the temporary prototxt, even if caffe.Net or the
            # shape inference above raised.
            os.remove(tmp_prototxt)

    # TODO: consider moving this function to Network.py
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号