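import tensorflow as tf


def Flatten(layer):
    """
    Handy function for flattening the result of a conv2D or
    maxpool2D to be used for a fully-connected (affine) layer.
    """
    # Static shape of the input, e.g. (batch, height, width, channels).
    layer_shape = layer.get_shape()
    # Dynamic-shape alternative, kept from the original for reference:
    # num_features = tf.reduce_prod(tf.shape(layer)[1:])
    # Product of every dimension except the leading batch dimension.
    num_features = layer_shape[1:].num_elements()
    # -1 lets TensorFlow infer the batch dimension.
    layer_flat = tf.reshape(layer, [-1, num_features])
    return layer_flat, num_features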
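
The shape arithmetic behind Flatten can be checked without TensorFlow. The sketch below is only an illustration: the helper name `flattened_shape` and the example shape `(None, 7, 7, 64)` are hypothetical and do not come from layer_utils.py. It mirrors the static-shape logic, where the feature count is the product of all non-batch dimensions and is unknown (None) if any of those dimensions is unknown.

```python
from math import prod
from typing import Optional, Sequence, Tuple


def flattened_shape(shape: Sequence[Optional[int]]) -> Tuple[int, Optional[int]]:
    """Hypothetical helper: return the (-1, num_features) target shape.

    `shape` is a static shape such as (batch, height, width, channels),
    where None marks a dimension that is not known ahead of time.
    """
    feature_dims = shape[1:]  # drop the leading batch dimension
    if any(d is None for d in feature_dims):
        # The product cannot be computed from an incomplete static shape.
        return -1, None
    return -1, prod(feature_dims)


# Illustrative shape (an assumption): a 7x7 feature map with 64 channels.
print(flattened_shape((None, 7, 7, 64)))  # (-1, 3136)
```

If the feature count comes back as None, the static shape is incomplete, which is the case the commented-out `tf.reduce_prod(tf.shape(layer)[1:])` line in the original appears intended to cover.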
Source file: layer_utils.py