def reshape(incoming, new_shape, name="Reshape"):
    """ Reshape.

    A layer that reshapes the incoming layer tensor output to the desired
    shape.

    If `incoming` is a list, its elements are first concatenated along
    axis 0 and the result is cast to `tf.float32` before reshaping.

    Arguments:
        incoming: A `Tensor`, or a list of `Tensor`. The incoming tensor(s).
        new_shape: A list of `int`. The desired shape.
        name: A name for this layer (optional).

    Returns:
        The reshaped `Tensor`. Its `scope` attribute is set to the name
        scope created for this layer, and it is added to the graph
        collection `tf.GraphKeys.LAYER_TENSOR + '/' + name`.
    """
    with tf.name_scope(name) as scope:
        inference = incoming
        if isinstance(inference, list):
            # NOTE(review): `tf.concat(0, values)` is the pre-1.0 TensorFlow
            # argument order (axis first); TF >= 1.0 expects
            # `tf.concat(values, axis)`. Confirm against the TF version this
            # file targets.
            inference = tf.concat(0, inference)
            # NOTE(review): indentation was lost in the reviewed copy; the
            # cast is reconstructed as applying only to list inputs so that
            # plain tensors keep their dtype. Confirm against upstream.
            inference = tf.cast(inference, tf.float32)
        inference = tf.reshape(inference, shape=new_shape)

    # Expose the layer's name scope on the output tensor.
    inference.scope = scope

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, inference)
    return inference
评论列表
文章目录