def __init__(self, classes=None, debug=False, arch='vgg16'):
    """Build the Faster R-CNN network: RPN backbone, RoI pooling and detection heads.

    Args:
        classes: Optional sequence of class names. When given it replaces
            ``self.classes``, sets ``self.n_classes`` to its length, and both
            are printed. When omitted, ``self.n_classes`` must already be
            resolvable on the instance. NOTE(review): presumably a
            class-level default defined outside this view -- confirm.
        debug: Stored as ``self.debug``; only the flag is recorded here.
        arch: Backbone architecture name. Only ``'vgg16'`` is supported.

    Raises:
        ValueError: If ``arch`` is not a supported backbone.
    """
    # Validate before doing any work. The RoI-pool scale and the 512 * 7 * 7
    # head input below are VGG16-specific, and any other value used to leave
    # the model without ``rpn``/``fcs`` attributes, failing only much later.
    if arch != 'vgg16':
        raise ValueError(
            "unsupported arch {!r}; only 'vgg16' is supported".format(arch))

    super(FasterRCNN, self).__init__()

    if classes is not None:
        self.classes = classes
        self.n_classes = len(classes)
        print('n_classes: {}\n{}'.format(self.n_classes, self.classes))

    # Plain VGG16 (no batch norm), randomly initialised (pretrained=False).
    cnn_arch = models.vgg16(pretrained=False)
    # The region proposal network is built on the VGG16 conv feature extractor.
    self.rpn = RPN(features=cnn_arch.features)
    # Two fully connected layers over the pooled 512x7x7 RoI features; this
    # mirrors the VGG16 classifier without its final 1000-way Linear layer.
    self.fcs = nn.Sequential(
        nn.Linear(512 * 7 * 7, 4096),
        nn.ReLU(True),
        nn.Dropout(),
        nn.Linear(4096, 4096),
        nn.ReLU(True),
        nn.Dropout(),
    )
    # 7x7 pooled output. The 1/16 spatial scale is presumably the stride of
    # the feature map produced by the RPN backbone -- confirm against RPN.
    self.roi_pool = RoIPool(7, 7, 1.0/16)
    # Classification head: one score per class.
    self.score_fc = FC(4096, self.n_classes, relu=False)
    # Box-regression head: 4 outputs per class.
    self.bbox_fc = FC(4096, self.n_classes * 4, relu=False)

    # Loss slots, initialised empty here (presumably filled during the
    # forward/training pass -- confirm).
    self.cross_entropy = None
    self.loss_box = None

    # Debug flag, kept for logging.
    self.debug = debug
评论列表
文章目录