def unpool(net, mask, stride=2):
    """Invert a 2-D max-pool by scattering pooled values to their argmax positions.

    Every element of ``net`` is written into a zero tensor at the location
    recorded for it in ``mask``; all other positions of the result are zero.

    Args:
        net: 4-D tensor of pooled values, treated as (batch, height, width,
            channels) by the index arithmetic below.
        mask: argmax indices with the same shape as ``net``. The decoding below
            assumes the per-example flattened index ``(y * W + x) * C + c``
            produced by ``tf.nn.max_pool_with_argmax`` with
            ``include_batch_in_index=False``, where ``W``/``C`` are the width and
            channel count of the unpooled output. NOTE(review): this is only
            consistent when the pre-pool height/width were exactly
            ``stride`` times the pooled ones -- confirm at call sites.
            Any integer dtype is accepted; it is cast to int64.
        stride: factor by which height and width are enlarged.

    Returns:
        Tensor of shape ``(N, H * stride, W * stride, C)`` with the dtype of
        ``net``.

    Raises:
        AssertionError: if ``mask`` is None.
    """
    assert mask is not None
    with tf.name_scope('UnPool2D'):
        # Use the dynamic shape (int64, to match the argmax indices) so the op
        # still builds when batch or spatial dims are unknown at graph time;
        # a static shape containing None breaks tf.range / tf.scatter_nd.
        in_shape = tf.shape(net, out_type=tf.int64)
        out_height = in_shape[1] * stride
        out_width = in_shape[2] * stride
        channels = in_shape[3]
        output_shape = tf.stack([in_shape[0], out_height, out_width, channels])

        # Normalise the index dtype so the int64 ranges below can be combined
        # with it (an int32 mask would otherwise fail with a dtype mismatch).
        mask = tf.cast(mask, tf.int64)
        ones = tf.ones_like(mask)

        # Batch index of every element: [N, 1, 1, 1] broadcast over the mask.
        batch_range = tf.reshape(tf.range(in_shape[0], dtype=tf.int64), [-1, 1, 1, 1])
        b = ones * batch_range
        # Decode the flattened (y * W + x) * C + c index into row and column.
        row_size = out_width * channels
        y = mask // row_size
        x = mask % row_size // channels
        # Channel index broadcasts along the last axis.
        f = ones * tf.range(channels, dtype=tf.int64)

        # One (b, y, x, f) row per pooled value, in the same order as the
        # flattened values, ready for scatter_nd.
        indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, -1]))
        values = tf.reshape(net, [-1])
        ret = tf.scatter_nd(indices, values, output_shape)

        # scatter_nd with a runtime shape drops static shape information;
        # restore whatever was known statically so downstream layers see it.
        static = net.get_shape().as_list()
        ret.set_shape([
            static[0],
            None if static[1] is None else static[1] * stride,
            None if static[2] is None else static[2] * stride,
            static[3],
        ])
        return ret
# Residue from the source web page (not code): "评论列表" ("Comment list"),
# "文章目录" ("Table of contents").