def flatten(incoming, name="Flatten"):
    """ Flatten.

    Flatten the incoming Tensor into a 2-D Tensor, keeping the first
    (batch) dimension and collapsing all remaining dimensions into one.

    Input:
        (2+)-D `Tensor`.

    Output:
        2-D `Tensor` [batch, flatten_dims].

    Arguments:
        incoming: `Tensor`. The incoming tensor.
        name: `str`. Name passed to `reshape` and used as the suffix of the
            layer collection key. Default: 'Flatten'.

    Returns:
        The reshaped 2-D `Tensor`. It is also added to the graph collection
        `tf.GraphKeys.LAYER_TENSOR + '/' + name`.

    Raises:
        ValueError: If the incoming shape has fewer than 2 dimensions, or if
            any non-batch dimension is unknown (None).
    """
    input_shape = utils.get_incoming_shape(incoming)
    # Explicit check rather than `assert`, so validation still runs when
    # Python is started with -O (which strips assert statements).
    if len(input_shape) < 2:
        raise ValueError("Incoming Tensor shape must be at least 2-D")
    feature_dims = input_shape[1:]
    # np.prod cannot multiply unknown (None) sizes; fail with a clear message
    # instead of an opaque TypeError. Only the leading (batch) dimension is
    # left unconstrained, since it is passed to reshape as -1.
    # NOTE(review): assumes get_incoming_shape reports unknown dims as None
    # (e.g. TensorShape.as_list()) -- confirm against utils.
    if any(d is None for d in feature_dims):
        raise ValueError(
            "Flatten requires all non-batch dimensions to be known, "
            "got shape %s" % (input_shape,))
    dims = int(np.prod(feature_dims))
    x = reshape(incoming, [-1, dims], name)
    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, x)
    return x
评论列表
文章目录