def train_step(self, blobs, train_op):
    """Run one optimization step and return the loss values as Python floats.

    Runs ``self.forward`` on the batch, snapshots five scalar losses from
    ``self._losses``, back-propagates ``total_loss``, steps the optimizer and
    finally calls ``self.delete_intermediate_states()``.

    Args:
        blobs: mapping providing the ``'data'``, ``'im_info'`` and
            ``'gt_boxes'`` entries that are passed positionally to
            ``self.forward``.
        train_op: optimizer-like object exposing ``zero_grad()`` and
            ``step()`` (presumably a ``torch.optim.Optimizer`` -- confirm
            against callers).

    Returns:
        Tuple ``(rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss)`` of
        Python floats, read from the ``'rpn_cross_entropy'``,
        ``'rpn_loss_box'``, ``'cross_entropy'``, ``'loss_box'`` and
        ``'total_loss'`` entries of ``self._losses`` respectively.
    """
    self.forward(blobs['data'], blobs['im_info'], blobs['gt_boxes'])

    # ``forward`` is expected to have populated ``self._losses``; each entry
    # must be a single-element tensor. ``.item()`` replaces the legacy
    # ``.data[0]``, which raises IndexError on the 0-dim tensors produced by
    # PyTorch >= 0.4.
    losses = self._losses
    rpn_loss_cls = losses['rpn_cross_entropy'].item()
    rpn_loss_box = losses['rpn_loss_box'].item()
    loss_cls = losses['cross_entropy'].item()
    loss_box = losses['loss_box'].item()
    loss = losses['total_loss'].item()

    # Clear stale gradients before back-propagating the combined loss.
    train_op.zero_grad()
    losses['total_loss'].backward()
    train_op.step()

    # The float snapshots above are taken before this cleanup call runs.
    self.delete_intermediate_states()
    return rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, loss
评论列表
文章目录