def __init__(self):
    """Build the CNN image encoder from a VGG-16 backbone.

    Constructs a VGG-16 network with ``models.vgg16()``, loads its weights
    from the checkpoint file named by the module-level ``vgg_checkpoint``,
    and removes the last module of the network's ``classifier`` so the
    encoder produces the classifier's penultimate features instead of the
    output of its final layer.

    NOTE(review): ``models`` is presumably ``torchvision.models`` and
    ``vgg_checkpoint`` a path to a VGG-16 ``state_dict`` written with
    ``torch.save``; both are defined outside this method, so confirm.
    """
    super(EncoderCNN, self).__init__()
    self.vgg = models.vgg16()
    # map_location="cpu": the freshly built model's parameters are on the CPU
    # and load_state_dict copies values into them, so there is no reason to
    # materialise the checkpoint on the device it was saved from. Without
    # this, a checkpoint saved from CUDA tensors fails to load on a CPU-only
    # machine (and temporarily occupies GPU memory otherwise).
    # NOTE: torch.load unpickles the file; only point vgg_checkpoint at
    # trusted checkpoints.
    state_dict = torch.load(vgg_checkpoint, map_location="cpu")
    self.vgg.load_state_dict(state_dict)
    # Drop the final module of VGG's classifier. In torchvision's VGG-16 this
    # is the 1000-way output Linear layer, so the classifier now ends with
    # the preceding ReLU/Dropout and emits the penultimate features.
    # NOTE(review): the original comment here was mis-encoded but mentioned
    # ReLU; confirm the post-ReLU penultimate features are what callers expect.
    self.vgg.classifier = nn.Sequential(*list(self.vgg.classifier.children())[:-1])
评论列表
文章目录