pool.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:lemontree 作者: khshim 项目源码 文件源码
def get_output(self, input_):
        """
        This function overrides the parents' one.
        Creates symbolic function to compute output from an input.

        Parameters
        ----------
        input_: TensorVariable

        Returns
        -------
        TensorVariable
        """
        result = pool_2d(input_,
                         ws=self.input_shape[1:],
                         ignore_border=True,
                         stride=self.input_shape[1:],
                         pad=self.padding,
                         mode='average_exc_pad')  # result is 4D tensor yet, (batch size, output channel, 1, 1)
        return T.reshape(result, (input_.shape[0], input_.shape[1]))  # flatten to 2D matrix
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号