def __init__(self, classes=None, debug=False):
    """Build the R-FCN head (RPN, position-sensitive pooling, score convs).

    Args:
        classes: optional sequence of class names. When given, it is stored
            on ``self.classes`` and ``self.n_classes`` becomes its length.
            When ``None``, ``self.n_classes`` is still read below, so it
            must already exist on the instance/class.
            NOTE(review): presumably a class-level default defined outside
            this view — confirm.
        debug: flag stored on ``self.debug`` (used for logging).
    """
    super(RFCN, self).__init__()

    if classes is not None:
        self.classes = classes
        self.n_classes = len(classes)

    # Head geometry shared by every layer below: k is the 7 used for all
    # pooling sizes / group sizes, 1/16 the PSRoIPool spatial scale
    # (presumably the backbone feature stride — confirm).
    k = 7
    spatial_scale = 1.0 / 16
    # Output dim of the localisation branch; presumably class-agnostic
    # box deltas (2 x 4) — confirm against forward().
    loc_dim = 8

    self.rpn = RPN()

    # Position-sensitive RoI pooling, one per branch.
    # NOTE(review): argument order assumed from usage only
    # (pooled_h, pooled_w, spatial_scale, group_size, output_dim) —
    # confirm against the PSRoIPool definition.
    self.psroi_pool_cls = PSRoIPool(k, k, spatial_scale, k, self.n_classes)
    self.psroi_pool_loc = PSRoIPool(k, k, spatial_scale, k, loc_dim)

    # Convolutional head on top of 512-channel features.
    self.new_conv = Conv2d(512, 1024, 1, same_padding=False)
    # NOTE(review): these channel counts look swapped relative to the
    # names and to the pooling layers above: rfcn_score emits k*k*loc_dim
    # channels (what psroi_pool_loc would consume) and rfcn_bbox emits
    # k*k*n_classes (what psroi_pool_cls would consume). Check how
    # forward() wires them before changing anything — reshaping would
    # also break existing checkpoints.
    self.rfcn_score = Conv2d(1024, k * k * loc_dim, 1, 1, bn=False)
    self.rfcn_bbox = Conv2d(1024, k * k * self.n_classes, 1, 1, bn=False)

    # k x k average pooling with stride k; applied to a k x k pooled map
    # it averages the bins into one value per channel.
    self.bbox_pred = nn.AvgPool2d((k, k), stride=(k, k))
    self.cls_score = nn.AvgPool2d((k, k), stride=(k, k))

    # Loss terms, initialised to None here.
    self.cross_entropy = None
    self.loss_box = None

    # For logging.
    self.debug = debug
# Residue scraped from the source web page (not code):
# 评论列表 (comment list)
# 文章目录 (table of contents)