def compute_output_shapes(self, model):
    """Assign an ``output_shape`` to every node in the graph.

    ``model`` is serialized with ``text_format.MessageToString`` into a
    temporary ``.prototxt`` file whose path is stored on ``self.prototxt``.

    If pycaffe is available:

    * A ``caffe.Net`` is built from that file in TEST phase.
    * Each blob shape it reports is copied onto the node returned by
      ``self.get_node(blob_name)``. Shapes shorter than 4 dims are padded
      with trailing 1s.
    * Any node still lacking a shape falls back to
      ``NodeKind.compute_output_shape``.
    * The temporary file is then deleted, even if an error occurs.

    Without pycaffe, every node's shape comes from
    ``NodeKind.compute_output_shape`` and the temporary file is kept.

    Args:
        model: Protobuf message to serialize. Presumably a Caffe
            ``NetParameter``; confirm against callers.
    """
    sorted_nodes = self.topologically_sorted()
    tmp_handle, tmp_prototxt = tempfile.mkstemp(suffix=".prototxt")
    # Write through the descriptor mkstemp already opened; the ``with`` closes
    # it, so it no longer leaks on the no-pycaffe path or on errors.
    with os.fdopen(tmp_handle, 'w') as f:
        f.write(text_format.MessageToString(model))
    # NOTE(review): on the pycaffe path this file is deleted below, so
    # self.prototxt then points at a missing file -- confirm no caller reads it.
    self.prototxt = tmp_prototxt

    if not has_pycaffe():
        for node in sorted_nodes:
            node.output_shape = TensorShape(*NodeKind.compute_output_shape(node))
        return

    try:
        caffe = get_caffe_resolver().caffe
        net = caffe.Net(tmp_prototxt, caffe.TEST)
        for key, value in net.blobs.items():
            try:
                # Presumably get_node raises for blob names with no matching
                # graph node; such blobs are skipped (best-effort).
                node = self.get_node(key)
                dims = list(value.shape)
                # Pad to 4 dims; longer shapes are left as-is, and if building
                # the TensorShape fails the node falls back below.
                dims = dims + [1] * (4 - len(dims))
                node.output_shape = TensorShape(*dims)
            except Exception:
                continue
        for node in sorted_nodes:
            if node.output_shape is None:
                node.output_shape = TensorShape(*NodeKind.compute_output_shape(node))
    finally:
        # Always clean up the temp file on the pycaffe path, even on failure.
        os.remove(tmp_prototxt)
# TODO: consider moving (and rewriting) this function into Network.py
# (web-page residue) Comment list
# (web-page residue) Table of contents