def build_mnist_model(self, input, use_unpooling):
    """
    Build the autoencoder model for the MNIST dataset as described in the
    Stacked What-Where Autoencoders paper.

    :param input: 4D tensor of source data of shape [batch_size, w, h, channels]
    :param use_unpooling: if True, decode with argmax-based unpooling (the
        "where" information from pooling) instead of naive upsampling
    :return: tuple of tensors:
        train - train operation (Adam step minimizing the L2 reconstruction loss)
        encode - bottleneck tensor of the autoencoder network
        decode - reconstruction of the input
    """
    # NOTE(review): parameter `input` shadows the Python builtin; the name is
    # kept unchanged so existing keyword-argument callers are not broken.

    # Encoder. (16)5c-(32)3c-Xp
    net = slim.conv2d(input, 16, [5, 5])
    net = slim.conv2d(net, 32, [3, 3])
    if use_unpooling:
        # Pool while recording argmax positions ("where"), then place values
        # back at those exact positions when unpooling.
        encode, mask = max_pool_with_argmax(net, FLAGS.pool_size)
        net = unpool(encode, mask, stride=FLAGS.pool_size)
    else:
        # Plain max-pool; decoder upsamples naively without position info.
        encode = slim.max_pool2d(net, kernel_size=[FLAGS.pool_size, FLAGS.pool_size],
                                 stride=FLAGS.pool_size)
        net = upsample(encode, stride=FLAGS.pool_size)
    # Decoder mirrors the encoder: 3x3 then 5x5 transposed convolutions,
    # ending with a single channel to match the MNIST input.
    net = slim.conv2d_transpose(net, 16, [3, 3])
    net = slim.conv2d_transpose(net, 1, [5, 5])
    decode = net
    # L2 reconstruction loss between the flattened input and reconstruction.
    loss_l2 = tf.nn.l2_loss(slim.flatten(input) - slim.flatten(net))
    # Optimizer
    train = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate).minimize(loss_l2)
    return train, encode, decode
# Source file: WhatWhereAutoencoder.py (Python)