networks.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:comprehend 作者: Fenugreek 项目源码 文件源码
def __init__(self, pool_side=2, **kwargs):
        """
        pool_side:
        Do max pooling on pool_side x pool_side non-overlapping
        patches of input.
        """

        Conv.__init__(self, **kwargs)

        if not kwargs.get('fromfile'):
            self.pool_side = pool_side
            self.shapes.append([])

        # Pool shape
        input_size = self.shapes[0] if self.padding == 'SAME' else \
                     [self.shapes[0][i] - self.shapes[1][i] + 1 for i in range(2)]
        self.shapes[2] = [self.batch_size] + \
                         [input_size[i] / self.strides[i+1] /
                          self.pool_side for i in range(2)] + \
                          [self.pool_side**2, self.n_hidden]
        self.zeros = tf.zeros(self.shapes[2], dtype=self.dtype)
        self.state = {}
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号