segnet_model.py source code


Project: woipv   Author: Panaetius
import tensorflow as tf  # TF 1.x API: tf.variable_scope, tf.mod


def __unpool(self, updates, mask, ksize=(1, 2, 2, 1), output_shape=None, feature_count=None, name=''):
    """Max-unpooling: scatter `updates` back to the positions recorded in `mask`.

    `mask` is assumed to hold argmax indices from tf.nn.max_pool_with_argmax,
    flattened per image as (y * width + x) * channels + c (batch not included).
    """
    with tf.variable_scope(name):
        mask = tf.cast(mask, tf.int32)
        input_shape = tf.shape(updates, out_type=tf.int32)

        # compute the output shape (N, H * ksize[1], W * ksize[2], C)
        if feature_count is None:
            feature_count = input_shape[3]

        if output_shape is None:
            # keep the input batch size so batch_range below matches batch_shape
            output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], feature_count)

        output_shape = tf.cast(output_shape, tf.int32)

        # compute (batch, y, x, feature) indices for every element of updates
        one_like_mask = tf.ones_like(mask, dtype=tf.int32)
        batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], 0)
        batch_range = tf.reshape(tf.range(output_shape[0], dtype=tf.int32), shape=batch_shape)
        b = one_like_mask * batch_range
        y = tf.floordiv(mask, output_shape[2] * output_shape[3])               # mask // (W * C)
        x = tf.mod(tf.floordiv(mask, output_shape[3]), output_shape[2])        # (mask // C) % W
        feature_range = tf.range(output_shape[3], dtype=tf.int32)
        f = one_like_mask * feature_range

        # stack into an [updates_size, 4] index matrix and flatten the update values
        updates_size = tf.size(updates)
        indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]))
        values = tf.reshape(updates, [updates_size])
        # scatter into a zero tensor of output_shape (duplicate indices are summed)
        ret = tf.scatter_nd(indices, values, output_shape)
        return ret
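The index arithmetic in `__unpool` is easier to check outside a TF graph. Below is a NumPy sketch, not the project's code. `max_pool_2x2_with_argmax` and `unpool` are hypothetical helpers. The sketch assumes the argmax indices are flattened per image as `(y * width + x) * channels + c`, which is the layout `tf.nn.max_pool_with_argmax` uses with `include_batch_in_index=False` and the one the `y`/`x` decoding in `__unpool` expects. It uses `np.add.at` so that duplicate indices are summed, as `tf.scatter_nd` does.

```python
import numpy as np


def max_pool_2x2_with_argmax(inputs):
    """2x2, stride-2 max pooling on an NHWC array (hypothetical helper).

    Returns the pooled values and, for each one, the argmax index flattened
    per image as (y * width + x) * channels + c. The batch index is not
    folded in (assumed layout, matching what __unpool decodes).
    """
    n, h, w, c = inputs.shape
    pooled = np.empty((n, h // 2, w // 2, c), dtype=inputs.dtype)
    argmax = np.empty((n, h // 2, w // 2, c), dtype=np.int64)
    for b in range(n):
        for i in range(h // 2):
            for j in range(w // 2):
                for ch in range(c):
                    window = inputs[b, 2 * i:2 * i + 2, 2 * j:2 * j + 2, ch]
                    dy, dx = np.unravel_index(np.argmax(window), window.shape)
                    yy, xx = 2 * i + dy, 2 * j + dx
                    pooled[b, i, j, ch] = inputs[b, yy, xx, ch]
                    argmax[b, i, j, ch] = (yy * w + xx) * c + ch
    return pooled, argmax


def unpool(updates, mask, output_shape):
    """Scatter `updates` into a zero array at the positions encoded in `mask`."""
    n, h, w, c = output_shape
    # batch and feature indices come from each element's own position,
    # like batch_range / feature_range in the TF code
    b = np.broadcast_to(np.arange(n).reshape(n, 1, 1, 1), mask.shape)
    f = np.broadcast_to(np.arange(c), mask.shape)
    # row and column are decoded from the flattened index
    y = mask // (w * c)
    x = (mask // c) % w
    out = np.zeros(output_shape, dtype=updates.dtype)
    # np.add.at accumulates duplicates, like tf.scatter_nd
    np.add.at(out, (b, y, x, f), updates)
    return out


# A batch of 2 so the batch index actually matters.
images = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)
pooled, mask = max_pool_2x2_with_argmax(images)
restored = unpool(pooled, mask, images.shape)
```

Because the values increase along every axis, each 2x2 maximum sits in the bottom-right corner of its window. `restored` therefore equals `images` at rows and columns 1 and 3 and is zero everywhere else.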