RPN.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:MSDN 作者: yikang-li 项目源码 文件源码
def __init__(self, use_kmeans_anchors=False):
        """Set up anchor configuration, the VGG-16 feature extractor and both proposal heads.

        Two parallel heads are created: the un-suffixed one and the
        ``*_region`` one.  Each consists of a 3x3 conv followed by a 1x1
        score conv (2 outputs per anchor) and a 1x1 bbox conv (4 outputs per
        anchor).

        Args:
            use_kmeans_anchors: if true, ``anchor_scales`` / ``anchor_ratios``
                (and their ``*_region`` counterparts) are taken as-is from the
                ``*_kmeans`` attributes.  Otherwise every ``*_normal`` scale is
                paired with every ``*_normal`` ratio and the pairs are stored
                as two flat arrays, scale-major.
        """
        super(RPN, self).__init__()

        if use_kmeans_anchors:
            # Parenthesised single-argument print is valid on both Python 2
            # and Python 3 and emits the same text.
            print('using k-means anchors')
            self.anchor_scales = self.anchor_scales_kmeans
            self.anchor_ratios = self.anchor_ratios_kmeans
            self.anchor_scales_region = self.anchor_scales_kmeans_region
            self.anchor_ratios_region = self.anchor_ratios_kmeans_region
        else:
            print('using normal anchors')
            # indexing='ij' keeps the scale axis first, so after flattening
            # the ratios cycle fastest within each scale.
            scales, ratios = np.meshgrid(
                self.anchor_scales_normal, self.anchor_ratios_normal, indexing='ij')
            self.anchor_scales = scales.reshape(-1)
            self.anchor_ratios = ratios.reshape(-1)

            scales_region, ratios_region = np.meshgrid(
                self.anchor_scales_normal_region, self.anchor_ratios_normal_region, indexing='ij')
            self.anchor_scales_region = scales_region.reshape(-1)
            self.anchor_ratios_region = ratios_region.reshape(-1)

        # One anchor per (scale, ratio) pair.
        self.anchor_num = len(self.anchor_scales)
        self.anchor_num_region = len(self.anchor_scales_region)

        # Pretrained VGG-16 convolutional stack used as the shared backbone.
        self.features = models.vgg16(pretrained=True).features
        delattr(self.features, '30')  # to delete the max pooling
        # by default, fix the first four layers
        # NOTE(review): this freezes the first 8 parameter tensors; presumably
        # weight + bias of four conv layers -- confirm against the backbone.
        network.set_trainable_param(list(self.features.parameters())[:8], requires_grad=False)

        # Head working on the un-suffixed anchors.
        self.conv1 = Conv2d(512, 512, 3, same_padding=True)
        self.score_conv = Conv2d(512, self.anchor_num * 2, 1, relu=False, same_padding=False)
        self.bbox_conv = Conv2d(512, self.anchor_num * 4, 1, relu=False, same_padding=False)

        # Head working on the region anchors.
        self.conv1_region = Conv2d(512, 512, 3, same_padding=True)
        self.score_conv_region = Conv2d(512, self.anchor_num_region * 2, 1, relu=False, same_padding=False)
        self.bbox_conv_region = Conv2d(512, self.anchor_num_region * 4, 1, relu=False, same_padding=False)

        # loss placeholders; they start out as None here
        self.cross_entropy = None
        self.loss_box = None
        self.cross_entropy_region = None
        self.loss_box_region = None

        # initialize the parameters
        self.initialize_parameters()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号