def build(self, input_shape):
    """Check that the input shapes agree everywhere except the concat axis.

    Used purely for shape validation.

    Args:
        input_shape: A list with one shape per input. Each entry is
            converted with `tensor_shape.TensorShape(...).as_list()`.

    Raises:
        ValueError: If `input_shape` is not a list, or if the shapes still
            differ after the entry at `self.axis` has been removed from
            each of them.
    """
    if not isinstance(input_shape, list):
        raise ValueError('`Concatenate` layer should be called '
                         'on a list of inputs')
    # No shape information at all: nothing to validate. This returns
    # before `self.built` is set.
    if all(shape is None for shape in input_shape):
        return
    # Convert every shape first so that a conversion failure surfaces
    # before any axis is stripped.
    reduced_inputs_shapes = [
        tensor_shape.TensorShape(shape).as_list() for shape in input_shape
    ]
    shape_set = set()
    for reduced_shape in reduced_inputs_shapes:
        # Drop the concat axis; only the remaining dimensions must match.
        del reduced_shape[self.axis]
        shape_set.add(tuple(reduced_shape))
    if len(shape_set) > 1:
        raise ValueError('`Concatenate` layer requires '
                         'inputs with matching shapes '
                         'except for the concat axis. '
                         'Got inputs shapes: %s' % (input_shape,))
    self.built = True
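# Web-page navigation residue (not part of the code): "Comment list"
# Web-page navigation residue (not part of the code): "Table of contents"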