yolov2.py source code

python

Project: yolov2, author: zhangkaij
# Module-level imports used by forward()
import torch
import torch.nn.functional as F


def forward(self, x, targets=None, num_iter=0):

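        # Main path: conv1s -> conv2 -> conv3 (deeper, lower-resolution features).
        conv1s = self.conv1s(x)
        conv2 = self.conv2(conv1s)
        conv3 = self.conv3(conv2)

        # Passthrough branch: apply conv_reorg to conv1s, then reorg (space-to-depth)
        # so its spatial size matches conv3 before concatenation.
        conv1s_reorg = self.conv_reorg(conv1s)
        conv1s_reorg = self.reorg(conv1s_reorg)

        # Concatenate along the channel dimension (dim 1) and run the detection head.
        cat_1_3 = torch.cat([conv1s_reorg, conv3], 1)
        conv4 = self.conv4(cat_1_3)
        output = self.conv5(conv4)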
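        # self.H and self.W record the size of the output grid.
        batchsize, _, self.H, self.W = output.size()
        # Flatten to (batchsize, H*W*num_anchor, num_class + num_loc): permute moves
        # channels last, then view splits them per anchor. The slicing below implies
        # num_loc == 5 (4 box deltas + 1 objectness score).
        output = output.permute(0, 2, 3, 1).contiguous().view(batchsize, -1, (self.num_class+self.num_loc))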

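        # Split the last dimension: 4 box deltas, 1 objectness (IoU) logit, class logits.
        bbox_delta = output[:, :, :4].contiguous()
        # torch.sigmoid replaces the deprecated F.sigmoid.
        iou_pred = torch.sigmoid(output[:, :, 4]).contiguous()
        class_pred = output[:, :, 5:].contiguous()
        # Softmax over the classes of each prediction; the dim must be given explicitly.
        prob_pred = F.softmax(class_pred.view(-1, self.num_class), dim=1).view_as(class_pred)
        pred = (bbox_delta, iou_pred, prob_pred)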

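        # Normalize anchor widths (even columns) by the grid width and heights
        # (odd columns) by the grid height.
        # NOTE: this rescales self.anchors_cfg in place, so unless it is reset
        # elsewhere in the class, every call to forward() divides the anchors again.
        self.anchors_cfg[:, 0::2] = self.anchors_cfg[:, 0::2] / self.W
        self.anchors_cfg[:, 1::2] = self.anchors_cfg[:, 1::2] / self.H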

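        if self.phase == 'train':
            # The loss is computed by self._calc_loss (defined elsewhere in the class);
            # forward() returns None in this branch.
            self._calc_loss(pred, targets, num_iter)
        else:
            assert batchsize == 1, "only batchsize=1 is supported for now"

            # Decode the box deltas of the single image against the anchors, then
            # pass the flattened predictions to self.detect (defined elsewhere).
            anchors = self._generate_anchors()
            bbox_pred = self._generate_pred_bbox(bbox_delta[0], anchors)
            output = self.detect(bbox_pred, iou_pred.view(-1), prob_pred.view(-1, self.num_class))
            return output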
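The passthrough branch depends on `self.reorg`, which is not shown in this excerpt. In YOLOv2 this layer is a space-to-depth rearrangement: each stride x stride block of pixels is stacked into channels, so a higher-resolution map ends up with the same H and W as `conv3`. Below is a minimal pure-Python sketch on nested lists shaped `[C][H][W]`. The function name `reorg` and the `(dy, dx, c)` channel order are assumptions chosen for illustration; the real layer (and Darknet's original reorg) may order channels differently.

```python
def reorg(x, stride=2):
    """Space-to-depth on one feature map stored as nested lists [C][H][W].

    Returns [C * stride * stride][H // stride][W // stride]. Output channels
    are ordered by (dy, dx, c), an illustrative convention only.
    """
    c_in, h_in, w_in = len(x), len(x[0]), len(x[0][0])
    assert h_in % stride == 0 and w_in % stride == 0, "H and W must be divisible by stride"
    h_out, w_out = h_in // stride, w_in // stride
    out = []
    for dy in range(stride):
        for dx in range(stride):
            for c in range(c_in):
                # Pick pixel (dy, dx) of every stride x stride block in channel c.
                out.append([[x[c][i * stride + dy][j * stride + dx]
                             for j in range(w_out)]
                            for i in range(h_out)])
    return out


# A 1-channel 4x4 map becomes a 4-channel 2x2 map.
x = [[[i * 4 + j for j in range(4)] for i in range(4)]]
print(reorg(x, 2)[0])  # [[0, 2], [8, 10]]
```

Once the spatial sizes match, the `torch.cat(..., 1)` step is just channel stacking; with nested lists it would be `reorg(a) + b`.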
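The `permute(0, 2, 3, 1).contiguous().view(batchsize, -1, num_class + num_loc)` line fixes how channels map to predictions. If `conv5` outputs `A * K` channels (A anchors, K = `num_class + num_loc`), then row `(h * W + w) * A + a` of the result holds anchor `a` at grid cell `(h, w)`, and its column `j` comes from input channel `a * K + j`. The pure-Python sketch below reproduces that index mapping on nested lists so the layout can be checked by hand. The function name `flatten_predictions` is hypothetical.

```python
def flatten_predictions(x, num_anchors, k):
    """Mimic output.permute(0, 2, 3, 1).contiguous().view(B, -1, k) on nested lists.

    x has shape [B][num_anchors * k][H][W]; the result has shape
    [B][H * W * num_anchors][k].
    """
    batch = []
    for sample in x:
        assert len(sample) == num_anchors * k, "channels must equal num_anchors * k"
        h_size, w_size = len(sample[0]), len(sample[0][0])
        rows = []
        # Row-major over (h, w, anchor), matching the memory order after permute.
        for h in range(h_size):
            for w in range(w_size):
                for a in range(num_anchors):
                    rows.append([sample[a * k + j][h][w] for j in range(k)])
        batch.append(rows)
    return batch


# 1 image, 2 anchors, k = 3 values per anchor, 1 x 2 grid; value = 10 * channel + w.
x = [[[[10 * c + w for w in range(2)]] for c in range(6)]]
print(flatten_predictions(x, 2, 3)[0])
# [[0, 10, 20], [30, 40, 50], [1, 11, 21], [31, 41, 51]]
```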
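Each flattened row is then split by the `[:4]`, `[4]` and `[5:]` slices: the box deltas are kept raw, the objectness logit goes through a sigmoid, and the class logits go through a softmax per row (which is what `F.softmax(..., dim=1)` on the `(-1, num_class)` view computes). A minimal pure-Python sketch of that step follows; `split_predictions` is a hypothetical name, and the max-subtraction in the softmax is a standard numerical-stability trick, not something taken from the source.

```python
import math


def _sigmoid(v):
    # Numerically stable logistic function (avoids overflow in math.exp).
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    z = math.exp(v)
    return z / (1.0 + z)


def split_predictions(rows, num_class):
    """Split rows of 4 + 1 + num_class values into box deltas, objectness, class probs."""
    bbox_delta, iou_pred, prob_pred = [], [], []
    for row in rows:
        assert len(row) == 5 + num_class, "each row must hold 4 + 1 + num_class values"
        bbox_delta.append(list(row[:4]))
        iou_pred.append(_sigmoid(row[4]))
        logits = row[5:]
        m = max(logits)  # subtracting the max keeps math.exp from overflowing
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        prob_pred.append([e / total for e in exps])
    return bbox_delta, iou_pred, prob_pred


b, s, p = split_predictions([[0.1, 0.2, 0.3, 0.4, 0.0, 1.0, 1.0]], num_class=2)
print(s, p)  # [0.5] [[0.5, 0.5]]
```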
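The two `self.anchors_cfg` lines divide anchor widths (columns `0::2`) by the grid width and heights (columns `1::2`) by the grid height. Because the assignment writes back into `self.anchors_cfg`, a second call divides the already-normalized values again, unless the class resets them somewhere not shown here. One way to avoid that is to compute a normalized copy and leave the configured anchors untouched, sketched below with plain lists. The function name `normalize_anchors` and the `[w0, h0, w1, h1, ...]` row layout are assumptions that mirror the slicing in `forward`.

```python
def normalize_anchors(anchors_cfg, grid_w, grid_h):
    """Return a new list with widths divided by grid_w and heights by grid_h.

    Each row is laid out as [w0, h0, w1, h1, ...], matching the 0::2 / 1::2
    slicing in forward(). The input is not modified, so repeated calls give
    the same result instead of compounding the scaling.
    """
    normalized = []
    for row in anchors_cfg:
        new_row = list(row)
        new_row[0::2] = [v / grid_w for v in row[0::2]]
        new_row[1::2] = [v / grid_h for v in row[1::2]]
        normalized.append(new_row)
    return normalized


anchors = [[13.0, 6.5], [26.0, 3.25]]
print(normalize_anchors(anchors, 13, 13))  # [[1.0, 0.5], [2.0, 0.25]]
print(anchors)                             # unchanged: [[13.0, 6.5], [26.0, 3.25]]
```

With the in-place version, a second forward pass on a 13 x 13 grid would shrink these anchors by another factor of 13.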
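`_generate_pred_bbox` is not part of this excerpt, so how it turns `bbox_delta` into boxes is not shown. Assuming it follows the parameterization from the YOLOv2 paper (`bx = sigmoid(tx) + cx`, `by = sigmoid(ty) + cy`, `bw = pw * exp(tw)`, `bh = ph * exp(th)`, in grid units), the sketch below decodes one delta and divides the center by the grid size, so that everything is expressed as a fraction of the image, consistent with the normalized anchors. `decode_box` and its parameters are hypothetical names.

```python
import math


def _sigmoid(v):
    # Numerically stable logistic function.
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    z = math.exp(v)
    return z / (1.0 + z)


def decode_box(delta, cell_x, cell_y, anchor_w, anchor_h, grid_w, grid_h):
    """Decode (tx, ty, tw, th) for grid cell (cell_x, cell_y).

    anchor_w / anchor_h are assumed to be already normalized by grid_w / grid_h.
    Returns (center_x, center_y, width, height) as fractions of the image.
    """
    tx, ty, tw, th = delta
    bx = (_sigmoid(tx) + cell_x) / grid_w  # sigmoid keeps the center inside the cell
    by = (_sigmoid(ty) + cell_y) / grid_h
    bw = anchor_w * math.exp(tw)           # exp scales the anchor size
    bh = anchor_h * math.exp(th)
    return bx, by, bw, bh


# Zero deltas put the box at the cell center with the anchor's own size.
print(decode_box((0.0, 0.0, 0.0, 0.0), 3, 5, 0.1, 0.2, 13, 13))
```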