def maxout2(x):
    """Apply maxout with a pool size of 2 along the last axis of ``x``.

    The last axis (length ``n``) is split into ``n // 2`` consecutive pairs
    ``(0, 1), (2, 3), ...`` and only the maximum of each pair is kept.  The
    result keeps every leading dimension of ``x`` and has a last dimension
    of ``n // 2``.

    Only ``.ndim``, ``.shape``, ``.reshape`` and ``.max`` are used, so the
    same code should serve symbolic tensors (e.g. Theano variables, whose
    shapes are symbolic) as well as NumPy arrays, for any rank >= 1.

    Args:
        x: Array or tensor with at least one dimension; its last dimension
            must be even.

    Returns:
        The pair-wise maximum along the last axis.

    Raises:
        ValueError: If ``x`` is 0-dimensional, or if its last dimension is a
            concrete (non-symbolic) odd integer.
    """
    ndim = x.ndim
    if ndim < 1:
        raise ValueError("maxout2 needs an input with at least one dimension")

    shape = x.shape
    last = shape[ndim - 1]
    # Concrete shapes (plain ints, e.g. from NumPy) can be validated up
    # front; symbolic shapes are only known when the graph is evaluated.
    if isinstance(last, int) and last % 2 != 0:
        raise ValueError(
            "maxout2 needs an even last dimension, got %d" % last)

    # Split the last axis into (n // 2, 2). Floor division keeps the size
    # integral, so no float -> int64 cast is required.
    new_shape = [shape[i] for i in range(ndim - 1)] + [last // 2, 2]
    # After the reshape the pair axis is the new last axis, at index `ndim`.
    return x.reshape(new_shape).max(ndim)
评论列表
文章目录