WhatWhereAutoencoder.py source code

python

Project: Tensorflow_WhatWhereAutoencoder, author: yselivonchyk
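import tensorflow as tf                 # TensorFlow 1.x API (tf.train, tf.app.flags)
import tensorflow.contrib.slim as slim  # TF-Slim, shipped in tf.contrib in TensorFlow 1.x

FLAGS = tf.app.flags.FLAGS  # assumed: pool_size and learning_rate are defined as command-line flags

# max_pool_with_argmax, unpool and upsample are helper functions defined elsewhere in the project.
# The function below is a method of the project's model class, hence the `self` parameter.


def build_mnist_model(self, input, use_unpooling):
    """
    Build the autoencoder model for the MNIST dataset as described in the Stacked What-Where Auto-encoders paper.

    :param input: 4D tensor of source data of shape [batch_size, height, width, channels]
    :param use_unpooling: whether to use an unpooling layer instead of naive upsampling
    :return: tuple of tensors:
      train - training operation
      encode - bottleneck tensor of the autoencoder network
      decode - reconstruction of the input
    """
    # Encoder: (16)5c-(32)3c-Xp, i.e. a 5x5 convolution with 16 filters, a 3x3 convolution
    # with 32 filters, then X x X max pooling with stride X (X = FLAGS.pool_size).
    # slim.conv2d defaults to stride 1, 'SAME' padding and ReLU, so both convolutions keep the spatial size.
    net = slim.conv2d(input, 16, [5, 5])
    net = slim.conv2d(net, 32, [3, 3])

    if use_unpooling:
        # Keep the argmax position ("where") of each pooling window alongside the pooled value ("what"),
        # and use it to put every value back at its original position when unpooling.
        encode, mask = max_pool_with_argmax(net, FLAGS.pool_size)
        net = unpool(encode, mask, stride=FLAGS.pool_size)
    else:
        # Plain max pooling discards the positions, so the decoder can only upsample naively.
        encode = slim.max_pool2d(net, kernel_size=[FLAGS.pool_size, FLAGS.pool_size], stride=FLAGS.pool_size)
        net = upsample(encode, stride=FLAGS.pool_size)

    # Decoder: transposed convolutions mirror the encoder back to a single-channel image
    # (slim.conv2d_transpose also defaults to stride 1, 'SAME' padding and ReLU).
    net = slim.conv2d_transpose(net, 16, [3, 3])
    net = slim.conv2d_transpose(net, 1, [5, 5])
    decode = net

    # Reconstruction loss: tf.nn.l2_loss(t) = sum(t ** 2) / 2, summed over the whole batch
    loss_l2 = tf.nn.l2_loss(slim.flatten(input) - slim.flatten(net))

    # Optimizer
    train = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate).minimize(loss_l2)
    return train, encode, decode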
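
Tracing tensor shapes through the encoder shows how large the `encode` tensor actually is. The sketch below assumes 28x28 single-channel MNIST input and the TF-Slim defaults (stride 1 and 'SAME' padding for the convolutions, 'VALID' padding for slim.max_pool2d); the helper name mnist_bottleneck_shape is hypothetical and not part of the project.

```python
from math import prod

# Hypothetical helper, assuming 28x28x1 MNIST inputs and TF-Slim defaults:
# convolutions use stride 1 with 'SAME' padding, slim.max_pool2d uses 'VALID' padding.


def mnist_bottleneck_shape(pool_size, height=28, width=28, filters=32):
    """Per-example (height, width, channels) of `encode` in build_mnist_model."""
    # The two stride-1 'SAME' convolutions keep height and width unchanged and
    # only change the channel count (1 -> 16 -> 32).
    # 'VALID' pooling with kernel = stride = pool_size gives floor((n - k) / k) + 1.
    pooled_h = (height - pool_size) // pool_size + 1
    pooled_w = (width - pool_size) // pool_size + 1
    return pooled_h, pooled_w, filters


input_values = 28 * 28 * 1
for pool_size in (2, 4, 7):
    shape = mnist_bottleneck_shape(pool_size)
    print(f"pool_size={pool_size}: encode {shape} holds {prod(shape)} values "
          f"per example (input: {input_values})")
```

Under these assumptions, pool_size = 2 yields a 14x14x32 code with 6272 values per example, eight times the 784 input pixels, and only pool_size = 7 (4x4x32 = 512 values) makes `encode` smaller than the input. For pool sizes that divide 28, 'SAME' pooling padding produces the same sizes, so the numbers also hold if the project's max_pool_with_argmax pads that way.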
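
The difference between the unpooling and upsampling branches is easiest to see on one small feature map. The NumPy sketch below is a simplified, single-channel stand-in for the project's max_pool_with_argmax, unpool and upsample helpers, not their actual implementation: pooling keeps each window's maximum (the "what") and its position inside the window (the "where"); unpooling writes each value back to that position and fills the rest with zeros. The upsampling shown is plain nearest-neighbour repetition, which is an assumption about what "naive upsampling" means here. All function names are hypothetical.

```python
import numpy as np


def max_pool_with_argmax_2d(x, k):
    """k x k max pooling with stride k on a 2D map; returns values and window-local argmax."""
    h, w = x.shape
    # Regroup into (h//k, w//k, k*k): one row of k*k values per pooling window
    windows = x.reshape(h // k, k, w // k, k).transpose(0, 2, 1, 3).reshape(h // k, w // k, k * k)
    pooled = windows.max(axis=-1)       # "what": the maximum of each window
    switches = windows.argmax(axis=-1)  # "where": its index inside the window
    return pooled, switches


def unpool_2d(pooled, switches, k):
    """Place each pooled value back at its recorded position; all other positions are zero."""
    hp, wp = pooled.shape
    out = np.zeros((hp, wp, k * k), dtype=pooled.dtype)
    np.put_along_axis(out, switches[..., None], pooled[..., None], axis=-1)
    # Undo the window grouping to get back an (hp*k, wp*k) map
    return out.reshape(hp, wp, k, k).transpose(0, 2, 1, 3).reshape(hp * k, wp * k)


def upsample_2d(pooled, k):
    """Naive nearest-neighbour upsampling: repeat each value over its k x k window."""
    return np.repeat(np.repeat(pooled, k, axis=0), k, axis=1)


x = np.array([[1, 5, 2, 0],
              [3, 2, 8, 1],
              [0, 4, 6, 7],
              [9, 1, 2, 3]])
pooled, switches = max_pool_with_argmax_2d(x, 2)
unpooled = unpool_2d(pooled, switches, 2)
upsampled = upsample_2d(pooled, 2)
print(unpooled)
print(upsampled)
```

The unpooled map is non-zero exactly where the window maxima of `x` were, so their positions survive the bottleneck, while the upsampled map spreads each maximum over its whole window. Note that TensorFlow's tf.nn.max_pool_with_argmax returns indices flattened over the whole input rather than the window-local indices used in this sketch.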
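
tf.nn.l2_loss(t) computes sum(t ** 2) / 2: half the squared Euclidean norm, with no square root and no averaging. The NumPy sketch below reproduces the reconstruction objective on a flattened batch (reshape stands in for slim.flatten); the function names l2_loss and reconstruction_loss are hypothetical.

```python
import numpy as np


def l2_loss(t):
    """Same formula as tf.nn.l2_loss: half the sum of squared entries."""
    return np.sum(np.square(t)) / 2.0


def reconstruction_loss(inputs, reconstructions):
    """L2 reconstruction loss summed over every pixel of every example in the batch."""
    batch_size = inputs.shape[0]
    # Flatten each example to a vector, as slim.flatten does for [batch, h, w, c] tensors
    diff = inputs.reshape(batch_size, -1) - reconstructions.reshape(batch_size, -1)
    return l2_loss(diff)


inputs = np.zeros((2, 28, 28, 1))
reconstructions = np.full((2, 28, 28, 1), 0.5)
print(reconstruction_loss(inputs, reconstructions))          # batch of 2
print(reconstruction_loss(inputs[:1], reconstructions[:1]))  # batch of 1
```

Because the loss is a sum rather than a mean, doubling the batch size doubles its value (196.0 versus 98.0 in this example), which matters when comparing loss curves across batch sizes. Adam's per-parameter normalisation makes the size of each update largely insensitive to such a constant scale.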