dcn.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:dcn.tf 作者: beopst 项目源码 文件源码
def replace_features(coarse_features, fine_features, replace_idxs):
    """Overwrite selected spatial locations of ``coarse_features`` with fine features.

    Every ``[row, col]`` location listed in ``replace_idxs`` has all of its
    channels in ``coarse_features`` replaced by the corresponding values from
    ``fine_features``; every other location keeps its coarse value.

    Trick:
        Both inputs are flattened to 1-D and merged with ``tf.dynamic_stitch``.
        The first stitch input covers every position with the coarse values,
        the second covers only the replaced positions with the fine values.
        When an index occurs more than once, ``tf.dynamic_stitch`` keeps the
        value from the later input, so the fine values win (and, for a
        location listed in several ``replace_idxs`` entries, the last one wins).

    Args:
        coarse_features: 4-D tensor whose static shape
            ``(batch, height, width, channels)`` must be fully known.
        fine_features: list of tensors, one per entry of ``replace_idxs``.
            Each is flattened and the results are concatenated in list order,
            so together they must hold exactly one value per replaced index.
            Presumably each holds ``batch * channels`` values laid out
            example-major, channel-minor (e.g. shape ``(batch, 1, 1, channels)``)
            -- TODO confirm against the caller.
        replace_idxs: list of integer tensors of shape ``(N, 2)`` holding
            ``[row, col]`` locations. Row ``i`` is offset into example ``i``
            of the batch, so ``N`` must broadcast against ``batch``
            (``N == batch`` for one location per example, or ``N == 1``).

    Returns:
        A tuple ``(merged, flat_coarse_replaced, flat_fine_features)``:
        ``merged`` has the shape of ``coarse_features`` with the selected
        locations replaced; ``flat_coarse_replaced`` is a 1-D tensor of the
        coarse values that were overwritten, in the same order as
        ``flat_fine_features``; ``flat_fine_features`` is the 1-D
        concatenation of ``fine_features``.
    """
    batch_size, map_height, map_width, map_channel = coarse_features.get_shape().as_list()

    # Strides (in elements) of the row-major flattened coarse tensor.
    example_stride = map_height * map_width * map_channel
    row_stride = map_width * map_channel
    example_offsets = [b * example_stride for b in range(batch_size)]

    def _convert_to_1d_idxs(src_idxs):
        """Convert ``[row, col]`` locations to flat indices of all their channels.

        Returns a 1-D tensor of ``batch * channels`` indices into the
        flattened ``coarse_features``, ordered example-major, channel-minor.
        """
        # Flat offset of channel 0 at (row, col) within a single example ...
        loc_base = row_stride * src_idxs[:, 0] + map_channel * src_idxs[:, 1]
        # ... shifted into example i's slab (N broadcasts against batch).
        loc_base = tf.add(loc_base, example_offsets)
        # (B, 1) + (C,) -> (B, C): one index per channel at each location.
        # Cast so the add also works when the indices are int64.
        channel_offsets = tf.cast(tf.range(map_channel), loc_base.dtype)
        flat_idxs = tf.expand_dims(loc_base, 1) + channel_offsets
        return tf.reshape(flat_idxs, [-1])

    # Flatten coarse features (row-major, hence example-major).
    flat_coarse_features = tf.reshape(coarse_features, [-1])

    # NOTE: tf.concat(axis, values) is the pre-1.0 TensorFlow argument order
    # used by this code; TF >= 1.0 expects tf.concat(values, axis).
    flat_fine_features = tf.concat(0, [tf.reshape(f, [-1]) for f in fine_features])
    flat_fine_idxs = tf.concat(0, [_convert_to_1d_idxs(i) for i in replace_idxs])

    # extract coarse features to be replaced
    # this is required for hint-based training
    flat_coarse_replaced = tf.gather(flat_coarse_features, flat_fine_idxs, validate_indices=False)

    num_elements = batch_size * example_stride
    merged = tf.dynamic_stitch([tf.range(num_elements), flat_fine_idxs],
                               [flat_coarse_features, flat_fine_features])

    merged = tf.reshape(merged, [batch_size, map_height, map_width, map_channel])

    return merged, flat_coarse_replaced, flat_fine_features
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号