def __init__(self, pool_side=2, **kwargs):
    """Initialize the conv layer and, for a freshly built layer, its pool shape.

    Args:
        pool_side: Do max pooling on ``pool_side x pool_side``
            non-overlapping patches of the input.
        **kwargs: Forwarded unchanged to ``Conv.__init__``. When
            ``kwargs['fromfile']`` is truthy, ``pool_side`` and the pool
            shape are not computed here (presumably ``Conv.__init__``
            restores them from the file -- TODO confirm).
    """
    Conv.__init__(self, **kwargs)
    if not kwargs.get('fromfile'):
        self.pool_side = pool_side
        # Reserve index 2 of self.shapes for the pool shape. Assumes
        # Conv.__init__ left exactly two entries in self.shapes -- TODO confirm.
        self.shapes.append([])
        # Pool shape.
        # With 'SAME' padding the first two dims of shapes[0] are used as-is;
        # otherwise each shrinks by shapes[1][i] - 1.
        if self.padding == 'SAME':
            input_size = self.shapes[0]
        else:
            input_size = [self.shapes[0][i] - self.shapes[1][i] + 1
                          for i in range(2)]
        # Floor division keeps the dims integral: '/' yields floats under
        # Python 3, which tf.zeros rejects as a shape. Rows/cols that do not
        # fill a whole pool patch are dropped by the floor.
        pooled = [input_size[i] // self.strides[i + 1] // self.pool_side
                  for i in range(2)]
        self.shapes[2] = ([self.batch_size] + pooled
                          + [self.pool_side ** 2, self.n_hidden])
    # NOTE(review): the source's indentation was lost; these two assignments
    # are assumed to run on both construction paths (a tensor and a runtime
    # dict are not file-restorable state) -- confirm against the original.
    # Zero tensor with the pool shape stored in self.shapes[2].
    self.zeros = tf.zeros(self.shapes[2], dtype=self.dtype)
    # Starts empty; how it is filled is outside this view.
    self.state = {}
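# 评论列表 ("comment list") -- scraped web-page navigation text, not code.
# 文章目录 ("table of contents") -- scraped web-page navigation text, not code.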