def __unpool(self, updates, mask, ksize=(1, 2, 2, 1), output_shape=None, feature_count=None, name=''):
    """Undo a max pooling by scattering pooled values back to their argmax positions.

    Every element of ``updates`` is written into a zero tensor of shape
    ``output_shape`` at the spatial position encoded by the matching element
    of ``mask``; the channel index is the element's own channel.

    Args:
        updates: 4-D tensor of pooled values, laid out as (batch, height,
            width, channels) — e.g. the first output of
            ``tf.nn.max_pool_with_argmax`` (assumed; confirm with caller).
        mask: tensor of the same shape as ``updates`` holding flattened argmax
            indices of the form ``(y * W + x) * C + c``, as produced by
            ``tf.nn.max_pool_with_argmax``. Indices that additionally carry
            the batch offset (``include_batch_in_index=True``) also decode
            correctly.
        ksize: pooling window ``[1, kh, kw, 1]``; only ``ksize[1]`` and
            ``ksize[2]`` are read, and only when ``output_shape`` is None.
        output_shape: shape ``(N, H, W, C)`` of the result. Defaults to
            ``(N, h * kh, w * kw, feature_count)`` derived from ``updates``.
        feature_count: channel count used for the default ``output_shape``;
            defaults to the channel count of ``updates``.
        name: variable scope under which the ops are created.

    Returns:
        Tensor of shape ``output_shape`` with the dtype of ``updates``.
    """
    with tf.variable_scope(name):
        mask = tf.cast(mask, tf.int32)
        input_shape = tf.shape(updates, out_type=tf.int32)

        # Work out the unpooled shape.
        if feature_count is None:
            feature_count = input_shape[3]
        if output_shape is None:
            # The batch dimension must follow the input: a hard-coded 1 here
            # broke the batch-index broadcast below for any batch size > 1.
            output_shape = (input_shape[0],
                            input_shape[1] * ksize[1],
                            input_shape[2] * ksize[2],
                            feature_count)
        output_shape = tf.cast(output_shape, tf.int32)
        out_height = output_shape[1]
        out_width = output_shape[2]
        out_channels = output_shape[3]

        # Build a (b, y, x, f) destination index for every pooled element.
        ones = tf.ones_like(mask, dtype=tf.int32)
        batch_range = tf.reshape(tf.range(input_shape[0], dtype=tf.int32), [-1, 1, 1, 1])
        b = ones * batch_range
        # mask = (y * W + x) * C + c, optionally plus b * H * W * C; the final
        # mod by H discards any batch offset so both index layouts decode.
        y = tf.mod(tf.floordiv(mask, out_width * out_channels), out_height)
        x = tf.mod(tf.floordiv(mask, out_channels), out_width)
        # Max pooling is per channel, so each value stays in its own channel.
        f = ones * tf.range(out_channels, dtype=tf.int32)

        # Flatten to an [updates_size, 4] index matrix and matching values.
        updates_size = tf.size(updates)
        indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]))
        values = tf.reshape(updates, [updates_size])
        return tf.scatter_nd(indices, values, output_shape)
评论列表
文章目录