def __init__(self, input_, outdim=2, debug=False):
    """Build an op that merges the trailing dimensions of ``input_`` into one.

    The first ``outdim - 1`` dimensions of ``input_`` are kept unchanged.
    All remaining dimensions are collapsed into a single dimension whose
    size is their product. When ``input_`` has at least ``outdim - 1``
    dimensions, the result therefore has ``outdim`` dimensions.

    Args:
        input_: tensor to flatten. Its static rank must be known, because
            ``get_shape().as_list()`` is called on it.
        outdim: number of dimensions of the output; must be >= 1.
        debug: if true, report the resulting static shape via
            ``util.header``.

    Raises:
        ValueError: if ``outdim`` is less than 1.

    Sets:
        self._outdim: the ``outdim`` argument.
        self._output_shape: static output shape *without* dimension 0
            (presumably the batch dimension -- TODO confirm). The last
            entry is None if any flattened dimension is statically unknown.
            Note that for ``outdim == 1`` dimension 0 is itself part of the
            flattened extent.
        self._output: the reshaped tensor.
    """
    if outdim < 1:
        # Explicit check rather than `assert`, which `python -O` strips out.
        raise ValueError('outdim must be >= 1, got %r' % (outdim,))
    self._outdim = outdim
    keep = outdim - 1  # number of leading dimensions left untouched

    # Static shape bookkeeping: fold the trailing dims if all are known.
    input_shape = tuple(input_.get_shape().as_list())
    static_tail = input_shape[keep:]
    if any(s is None for s in static_tail):
        flat_size = None
    else:
        # np.prod(()) == 1, so an empty tail yields a trailing dim of size 1.
        flat_size = int(np.prod(static_tail))
    self._output_shape = input_shape[1:keep] + (flat_size,)
    if debug:
        util.header('Flatten(new_shape=%s)' % str(self._output_shape))

    # Dynamic reshape: tf.shape resolves statically-unknown (None) dims at
    # run time, so this works even when the static flat size is None.
    dyn_shape = tf.shape(input_)
    pre_shape = dyn_shape[:keep]
    tail_size = tf.reduce_prod(dyn_shape[keep:])
    # NOTE(review): `tf.concat(axis, values)` argument order and `tf.pack`
    # are the pre-1.0 TensorFlow API (1.0 uses `tf.concat(values, axis)` and
    # `tf.stack`). They are kept as-is; migrate together with the rest of the
    # file if it moves to TF >= 1.0.
    self._output = tf.reshape(
        input_, tf.concat(0, [pre_shape, tf.pack([tail_size])]))
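# NOTE(review): the two lines that stood here ("comment list", "table of
# contents") appear to be web-page navigation residue, not code.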