def __init__(self, prev_layers, axis=1):
    """Set up a layer that concatenates several previous layers.

    Args:
        prev_layers: sequence of at least two layers to concatenate. The
            first one is passed to the parent initializer; the remaining
            ones only contribute their ``_output_shape[axis]``.
        axis: dimension along which the layers are concatenated. For
            tensor5 the channel dimension is axis=2 (due to the theano
            conv3d convention); for images it is axis=1.

    The output shape starts as a copy of ``self._input_shape`` (presumably
    set by the parent initializer from the first layer -- confirm) and its
    ``axis`` entry grows by the size of every remaining layer along that
    axis. The other dimensions are not checked for agreement. The final
    shape is printed to stdout.
    """
    assert len(prev_layers) > 1
    first_layer, *other_layers = prev_layers
    super().__init__(first_layer)
    self._prev_layers = prev_layers
    self._axis = axis

    # Work on a copy so the first layer's own shape object is not mutated.
    concat_shape = self._input_shape.copy()
    self._output_shape = concat_shape
    for other_layer in other_layers:
        concat_shape[axis] += other_layer._output_shape[axis]

    shape_text = ','.join(map(str, self._output_shape))
    print('Concat the prev layer to [%s]' % shape_text)
# (Web-page residue, not part of the code: "comment list" / "article table of contents".)