def __init__(self, num_classes=1000, activation_fn=None, drop_rate=0., global_pool='avg'):
    """Build the ResNet-200 feature trunk and its linear classifier head.

    Args:
        num_classes: Number of output units of the final ``nn.Linear``.
        activation_fn: Activation module passed to ``fbresnet200_features``.
            ``None`` (the default) creates a fresh ``nn.ReLU()`` for this
            instance rather than sharing one module object across every
            model constructed with the default.
        drop_rate: Stored on ``self.drop_rate``; not used in this method.
        global_pool: Stored on ``self.global_pool`` and passed to
            ``pooling_factor`` to size the classifier input.
    """
    super(ResNet200, self).__init__()
    # A module instance as a default argument is evaluated once at function
    # definition time and would be shared by every model; build one per call.
    if activation_fn is None:
        activation_fn = nn.ReLU()
    self.drop_rate = drop_rate
    self.global_pool = global_pool
    self.features = fbresnet200_features(activation_fn=activation_fn)
    # NOTE(review): assumes the trunk emits 2048 channels and that
    # pooling_factor() returns the channel multiplier implied by the chosen
    # pooling mode -- confirm against those helpers.
    self.fc = nn.Linear(2048 * pooling_factor(global_pool), num_classes)
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            # Kaiming (He) normal init for conv weights; conv biases, if
            # any, keep PyTorch's default initialization.
            init.kaiming_normal_(m.weight)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            # BatchNorm2d(affine=False) has weight/bias set to None, so only
            # affine layers are initialized to the identity transform.
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
评论列表
文章目录