def __init__(self, preserve_dims=1, name="batch_flatten"):
    """Builds a BatchFlatten module.

    Construction is delegated to the parent class, asking it to reshape
    to the target shape `(-1,)` while keeping the first `preserve_dims`
    dimensions as they are.

    Args:
      preserve_dims: How many leading dimensions are left out of the
        reshape. Taking an input Tensor of shape `[B, H, W, C]` as an
        example:
          * `preserve_dims=1` yields a Tensor of shape `[B, H*W*C]`.
          * `preserve_dims=2` yields a Tensor of shape `[B, H, W*C]`.
          * `preserve_dims=3` yields the input itself, shape
            `[B, H, W, C]`.
          * `preserve_dims=4` yields a Tensor of shape `[B, H, W, C, 1]`.
          * `preserve_dims>=5` raises an error on build.
        The sizes of the preserved dimensions may be unknown at
        building time.
      name: Name of the module.
    """
    # Flattening is expressed as a reshape whose target is one single
    # `-1` dimension; the parent constructor receives it as `shape`.
    flattened_shape = (-1,)
    parent = super(BatchFlatten, self)
    parent.__init__(
        shape=flattened_shape,
        preserve_dims=preserve_dims,
        name=name,
    )
评论列表
文章目录