nn.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:rltools 作者: sisl 项目源码 文件源码
def __init__(self, input_, outdim=2, debug=False):
        assert outdim >= 1
        self._outdim = outdim
        input_shape = tuple(input_.get_shape().as_list())
        to_flatten = input_shape[self._outdim - 1:]
        if any(s is None for s in to_flatten):
            flattened = None
        else:
            flattened = int(np.prod(to_flatten))

        self._output_shape = input_shape[1:self._outdim - 1] + (flattened,)
        if debug:
            util.header('Flatten(new_shape=%s)' % str(self._output_shape))
        pre_shape = tf.shape(input_)[:self._outdim - 1:]
        to_flatten = tf.reduce_prod(tf.shape(input_)[self._outdim - 1:])
        self._output = tf.reshape(input_, tf.concat(0, [pre_shape, tf.pack([to_flatten])]))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号