def max_pool_with_argmax(net, stride):
    """Max-pool `net` and also return the argmax mask of the pooling windows.

    Tensorflow's default implementation does not provide a gradient operation
    for `max_pool_with_argmax`. Therefore `tf.nn.max_pool_with_argmax` is used
    only to extract the argmax mask (with gradients stopped), while a plain
    max-pool produces the pooled activations that gradients flow through.

    Args:
        net: Input tensor. The 4-element `ksize`/`strides` lists below treat
            it as a 4-D tensor (presumably NHWC -- confirm against callers).
        stride: Integer used both as the square pooling window size and as
            the stride along the two middle (spatial) dimensions.

    Returns:
        A `(net, mask)` tuple: the max-pooled tensor and the gradient-stopped
        argmax indices produced by `tf.nn.max_pool_with_argmax`.
    """
    with tf.name_scope('MaxPoolArgMax'):
        _, mask = tf.nn.max_pool_with_argmax(
            net,
            ksize=[1, stride, stride, 1],
            strides=[1, stride, stride, 1],
            padding='SAME')
        # The mask is auxiliary index data; block gradients through it.
        mask = tf.stop_gradient(mask)
        # Pool with the same `stride` as the argmax op above (this used to read
        # the global FLAGS.pool_size), so the pooled output and the mask are
        # always computed over the same windows.
        # NOTE(review): slim.max_pool2d defaults to padding='VALID' while the
        # argmax op above uses 'SAME'; the two output shapes only agree when
        # the spatial dims are divisible by `stride` -- confirm with callers
        # before changing the padding here.
        net = slim.max_pool2d(net, kernel_size=[stride, stride], stride=stride)
        return net, mask
# Thank you, @https://github.com/Pepslee
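# Source file: WhatWhereAutoencoder.py (language: python)
# Web-page residue from the code-hosting site this snippet was copied from:
# views 26, favorites 0, likes 0, comments 0; "comment list"; "table of contents".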