def __init__(self, incoming, shape, **kwargs):
    """Validate and store the target shape specification of the layer.

    Parameters
    ----------
    incoming
        Forwarded unchanged to the parent class initializer.
    shape
        Iterable describing the target shape. It is converted to a tuple
        and stored as ``self.shape``. Every entry must be one of:

        * an ``int`` that is positive or exactly ``-1`` (at most one
          ``-1`` may appear in the whole specification);
        * a single-element ``list`` holding an ``int >= 0``, called an
          "input reference" by the error message (presumably the index
          of an input dimension to reuse -- confirm against
          ``get_output_shape_for``).
    **kwargs
        Forwarded unchanged to the parent class initializer.

    Raises
    ------
    ValueError
        If an entry has an unsupported type or value, or if ``-1``
        occurs more than once.
    NotImplementedError
        If an entry is a ``tf.Tensor`` or ``tf.Variable``; symbolic
        entries are recognised but not supported here.

    Any exception raised by ``get_output_shape_for`` during the final
    sanity check propagates to the caller.
    """
    super(ReshapeLayer, self).__init__(incoming, **kwargs)
    shape = tuple(shape)

    for entry in shape:
        # Plain integer: a fixed size, or -1 as the single wildcard.
        if isinstance(entry, int):
            if entry == 0 or entry < -1:
                raise ValueError("`shape` integers must be positive or -1")
            continue

        # Single-element list: an input reference such as [0].
        if isinstance(entry, list):
            is_valid_reference = (
                len(entry) == 1
                and isinstance(entry[0], int)
                and entry[0] >= 0
            )
            if not is_valid_reference:
                raise ValueError("`shape` input references must be "
                                 "single-element lists of int >= 0")
            continue

        # Symbolic entries are detected so they get a distinct error
        # rather than the generic "wrong type" ValueError below.
        if isinstance(entry, (tf.Tensor, tf.Variable)):
            raise NotImplementedError

        raise ValueError("`shape` must be a tuple of int and/or [int]")

    # Only one dimension can be left as the -1 wildcard.
    if shape.count(-1) > 1:
        raise ValueError("`shape` cannot contain multiple -1")

    self.shape = shape
    # Compute the output shape once so an incompatible specification
    # fails at construction time; the result itself is discarded.
    self.get_output_shape_for(self.input_shape)
# [scraped page navigation, not code] Comment list
# [scraped page navigation, not code] Table of contents