WhatWhereAutoencoder.py source file

python

Project: Tensorflow_WhatWhereAutoencoder  Author: yselivonchyk
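import tensorflow as tf  # TensorFlow 1.x API
import tensorflow.contrib.slim as slim


def max_pool_with_argmax(net, stride):
  """
  TensorFlow's default implementation does not provide a gradient for
  max_pool_with_argmax. Therefore, max_pool_with_argmax is used only to
  extract the argmax mask, and a plain max_pool does the actual pooling.
  """
  with tf.name_scope('MaxPoolArgMax'):
    # Only the argmax indices are kept; the pooled values are discarded.
    _, mask = tf.nn.max_pool_with_argmax(
      net,
      ksize=[1, stride, stride, 1],
      strides=[1, stride, stride, 1],
      padding='SAME')
    # The mask is a set of indices, so no gradient should flow through it.
    mask = tf.stop_gradient(mask)
    # Use the same stride and padding as above so that net and mask have matching shapes.
    net = slim.max_pool2d(net, kernel_size=[stride, stride], stride=stride, padding='SAME')
    return net, mask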


# Thank you, @https://github.com/Pepslee
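
The mask returned above is easier to picture on a small input. The following is a minimal pure-Python sketch, not the repository's code: the helper name `max_pool_with_argmax_2d` is hypothetical, it handles a single 2D channel with no batch dimension, and it assumes the height and width are divisible by the stride. Under those assumptions the flattened index of a maximum at row `y`, column `x` is `y * width + x`, which is the single-channel case of the index layout documented for `tf.nn.max_pool_with_argmax`.

```python
def max_pool_with_argmax_2d(image, stride):
    """Non-overlapping max pooling over a 2D list of lists.

    Returns (pooled, mask), where mask holds the flattened index
    y * width + x of each window's maximum (single channel, no batch).
    Assumes height and width are divisible by stride.
    """
    height, width = len(image), len(image[0])
    pooled, mask = [], []
    for top in range(0, height, stride):
        pooled_row, mask_row = [], []
        for left in range(0, width, stride):
            # Scan the window, remembering where the first largest value sits.
            best_y, best_x = top, left
            for y in range(top, top + stride):
                for x in range(left, left + stride):
                    if image[y][x] > image[best_y][best_x]:
                        best_y, best_x = y, x
            pooled_row.append(image[best_y][best_x])
            mask_row.append(best_y * width + best_x)
        pooled.append(pooled_row)
        mask.append(mask_row)
    return pooled, mask


image = [
    [1, 2, 0, 1],
    [3, 4, 5, 0],
    [0, 1, 2, 9],
    [7, 0, 3, 4],
]
pooled, mask = max_pool_with_argmax_2d(image, stride=2)
print(pooled)  # [[4, 5], [7, 9]]
print(mask)    # [[5, 6], [12, 11]]
```

The pooled values are the "what" and the mask is the "where": an unpooling step can write each pooled value back to the position stored in the mask and leave every other position at zero.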