def decode(y, relu_max):
    """Build the decoder stack that turns a flat latent tensor into an image.

    The latent tensor is reshaped to a 1x1 feature map and passed through
    three transposed convolutions whose requested output shapes are
    7x7x40, 14x14x20 and 28x28x1 (channels last). The first two are
    followed by ``BN(mode=2, axis=3)``; the last one uses a sigmoid
    activation and is returned as-is.

    Args:
        y: Keras tensor whose ``_keras_shape`` has exactly two entries; the
            last entry is used as the latent dimension.
        relu_max: When truthy, ``Activation(utils.scale_up(relu_max))`` is
            applied to the reshaped latent before the deconvolutions.
            NOTE(review): what ``utils.scale_up`` computes is not visible
            here; confirm against its definition.

    Returns:
        The output tensor of the final ``Deconv2D`` layer.
    """
    assert len(y._keras_shape) == 2
    code_size = y._keras_shape[-1]

    # View the latent vector as a 1 x 1 map with `code_size` channels.
    h = Reshape((1, 1, code_size))(y)
    if relu_max:
        h = Activation(utils.scale_up(relu_max))(h)
    # A BN(mode=2, axis=3) layer at this point is intentionally left out;
    # the original author had it disabled with the remark "not good?".

    # Deconv2D needs the full output shape, so read the batch size from
    # the tensor at graph-run time.
    n_samples = tf.shape(h)[0]

    # One row per deconvolution:
    # (filters, kernel size, stride, output height/width, activation)
    stages = (
        (40, 7, 7, 7, 'relu'),
        (20, 3, 2, 14, 'relu'),
        (1, 3, 2, 28, 'sigmoid'),
    )
    final_index = len(stages) - 1
    for index, (filters, kernel, stride, side, act) in enumerate(stages):
        h = Deconv2D(filters, kernel, kernel,
                     output_shape=[n_samples, side, side, filters],
                     activation=act,
                     border_mode='same',
                     subsample=(stride, stride))(h)
        # Every stage except the output stage is followed by BN.
        if index != final_index:
            h = BN(mode=2, axis=3)(h)
    return h
# Comment list
# Table of contents