def forward(self, x, targets=None, num_iter=0):
    """Run the network head on ``x``, then compute the loss or decode detections.

    ``conv1s -> conv2 -> conv3`` form the main path. A side branch
    ``conv_reorg -> reorg`` taken from the ``conv1s`` output is concatenated
    with ``conv3`` along dim 1 before ``conv4`` and ``conv5``. The ``conv5``
    output is flattened to (batchsize, -1, num_class + num_loc) and split
    per row into box deltas, a sigmoid confidence and softmax class
    probabilities.

    Args:
        x: Input passed to ``self.conv1s`` (presumably an image batch of shape
            (batch, channels, height, width) -- TODO confirm).
        targets: Passed unchanged to ``self._calc_loss`` in the ``'train'``
            phase; unused otherwise.
        num_iter: Passed unchanged to ``self._calc_loss`` in the ``'train'``
            phase; unused otherwise.

    Returns:
        In the ``'train'`` phase, the reshaped head output of shape
        (batchsize, -1, num_class + num_loc). The value returned by
        ``_calc_loss`` is discarded. In every other phase, whatever
        ``self.detect`` returns for the first (and only) image in the batch.

    Side effects:
        Sets ``self.H`` / ``self.W`` to the spatial size of the ``conv5``
        output. Rescales ``self.anchors_cfg`` in place so that its even
        columns are divided by ``W`` and its odd columns by ``H``.

    Raises:
        AssertionError: outside the ``'train'`` phase when the batch size is
            not 1.
    """
    conv1s = self.conv1s(x)
    conv2 = self.conv2(conv1s)
    conv3 = self.conv3(conv2)
    # Side branch from conv1s, concatenated with conv3 along the channel dim.
    conv1s_reorg = self.reorg(self.conv_reorg(conv1s))
    cat_1_3 = torch.cat([conv1s_reorg, conv3], 1)
    conv4 = self.conv4(cat_1_3)
    output = self.conv5(conv4)
    batchsize, _, self.H, self.W = output.size()

    # output shape: (batchsize, H*W*num_anchor, (num_class+num_loc))
    output = output.permute(0, 2, 3, 1).contiguous().view(
        batchsize, -1, (self.num_class + self.num_loc))
    # Column layout: [0:4] box deltas, [4] confidence logit, [5:] class
    # scores. NOTE(review): the hard-coded 4/5 split assumes num_loc == 5.
    bbox_delta = output[:, :, :4].contiguous()
    iou_pred = torch.sigmoid(output[:, :, 4]).contiguous()
    class_pred = output[:, :, 5:].contiguous()
    # Softmax over the class axis. The old implicit dim on this 2-D view was
    # already dim 1, so this is unchanged, but now explicit.
    prob_pred = F.softmax(class_pred.view(-1, self.num_class), dim=-1).view_as(class_pred)
    pred = (bbox_delta, iou_pred, prob_pred)

    # Normalize the anchors against the current feature-map size. Dividing
    # the stored anchors unconditionally on every call (as before) compounds
    # across forward passes and shrinks them towards zero. Instead, record
    # the (H, W) the stored values were last normalized against and rescale
    # only when it changes. On the first call the previous scale is (1, 1),
    # which reproduces the original single division exactly.
    # NOTE(review): if other code resets self.anchors_cfg to raw values,
    # it must also delete self._anchors_cfg_hw.
    prev_h, prev_w = getattr(self, '_anchors_cfg_hw', (1, 1))
    if (prev_h, prev_w) != (self.H, self.W):
        self.anchors_cfg[:, 0::2] = self.anchors_cfg[:, 0::2] * prev_w / self.W
        self.anchors_cfg[:, 1::2] = self.anchors_cfg[:, 1::2] * prev_h / self.H
        self._anchors_cfg_hw = (self.H, self.W)

    if self.phase == 'train':
        self._calc_loss(pred, targets, num_iter)
    else:
        assert batchsize == 1, "now only support batchsize=1"
        anchors = self._generate_anchors()
        bbox_pred = self._generate_pred_bbox(bbox_delta[0], anchors)
        output = self.detect(bbox_pred, iou_pred.view(-1), prob_pred.view(-1, self.num_class))
    return output
评论列表
文章目录