def Unpool_layer(self, pool, ind, k_size=[1, 2, 2, 1]):
    # https://github.com/tensorflow/tensorflow/issues/2169
    """
    Unpooling layer after max_pool_with_argmax.

    Scatters every pooled value back to the position recorded in ``ind``
    inside a zero-filled tensor whose height and width are the pooled
    height and width multiplied by ``k_size[1]`` and ``k_size[2]``.

    Args:
        pool: max pooled output tensor (rank 4; axis 0 is treated as the
            batch axis). Axes 1-3 must have statically known sizes because
            they are multiplied as Python ints; the batch size may be
            dynamic.
        ind: argmax indices, same shape as ``pool``. Each value is used as
            an index into the flattened (height * width * channels) output
            of its own batch element; the batch index is prepended here.
            NOTE(review): this presumes the indices do NOT already include
            the batch offset -- confirm against how ``ind`` is produced.
        k_size: k_size is the same as for the pool. Only elements 1 and 2
            are read; the default list is never mutated.
    Return:
        unpool: unpooling tensor of shape
            [batch, height * k_size[1], width * k_size[2], channels]
    """
    with tf.name_scope('Unpool'):
        dynamic_shape = tf.shape(pool)
        static_shape = pool.get_shape().as_list()
        batch_size = dynamic_shape[0]

        out_height = static_shape[1] * k_size[1]
        out_width = static_shape[2] * k_size[2]
        channels = static_shape[3]
        flat_output_size = out_height * out_width * channels

        # Flatten with -1 so the batch dimension is included. Reshaping to
        # the static H*W*C length only works when the batch size is 1.
        flat_pool = tf.reshape(pool, [-1])

        # Build, for every element of `ind`, the index of the batch entry it
        # belongs to (broadcast a [batch, 1, 1, 1] range over ind's shape).
        batch_range = tf.reshape(
            tf.range(tf.cast(batch_size, tf.int64), dtype=ind.dtype),
            shape=tf.stack([batch_size, 1, 1, 1]))
        batch_index = tf.reshape(tf.ones_like(ind) * batch_range, [-1, 1])
        flat_ind = tf.reshape(ind, [-1, 1])

        # Each row is (batch index, flat position within that batch element).
        scatter_index = tf.concat([batch_index, flat_ind], 1)

        # scatter_nd requires `shape` to have the same dtype as the indices.
        flat_output_shape = tf.cast([batch_size, flat_output_size], ind.dtype)
        ret = tf.scatter_nd(scatter_index, flat_pool, shape=flat_output_shape)

        # Static sizes are used for the non-batch axes so the result keeps
        # its static shape information.
        ret = tf.reshape(ret, [-1, out_height, out_width, channels])
        return ret
# Comment list
# Table of contents