def get_vgg19(num_classes, pretrained):
    """Build a torchvision VGG-19 whose final classifier layer emits ``num_classes`` outputs.

    Args:
        num_classes: Output width of the last ``nn.Linear`` in ``net.classifier``.
        pretrained: If truthy, load the state dict stored at
            ``pretrained_vgg19_path`` (a name defined outside this function)
            before the output layer is replaced.

    Returns:
        The ``models.vgg19()`` network with its last classifier layer swapped
        for a freshly initialised ``nn.Linear(in_features, num_classes)``.
    """
    net = models.vgg19()
    if pretrained:
        # map_location="cpu" lets a checkpoint saved from GPU tensors be loaded
        # on a CPU-only machine; load_state_dict copies the values into the
        # parameters of ``net`` regardless of where they were deserialised.
        state_dict = torch.load(pretrained_vgg19_path, map_location="cpu")
        net.load_state_dict(state_dict)
    # Replace only the output layer. Rebuilding the whole classifier would
    # discard the just-loaded fc6/fc7 weights (the bulk of VGG-19's parameters);
    # the hidden Linear/ReLU/Dropout layers already have the desired layout.
    last = net.classifier[-1]
    net.classifier[-1] = nn.Linear(last.in_features, num_classes)
    return net
# Web-page residue from the source article (not code):
#   评论列表 ("comment list")
#   文章目录 ("table of contents")