def def_netF(weights_path='vgg19.pth', num_classifier_layers=2):
    """Build a frozen, truncated VGG-19 network.

    Constructs ``M.vgg19()``, loads its weights from ``weights_path``, keeps
    only the first ``num_classifier_layers`` child modules of its
    ``classifier`` head, disables gradients for every parameter, and puts
    the module in eval mode.

    Args:
        weights_path: Path to a state dict compatible with ``M.vgg19()``.
            Defaults to ``'vgg19.pth'`` (the previously hard-coded path).
        num_classifier_layers: Number of leading ``classifier`` children to
            keep. The default, 2, matches the original behavior.
            NOTE(review): assuming ``M`` is ``torchvision.models``, 2 keeps
            the first Linear layer and its ReLU — confirm.

    Returns:
        The truncated VGG-19 module, with ``requires_grad`` set to False on
        every parameter and ``training`` set to False.
    """
    vgg19 = M.vgg19()
    # Load tensors onto CPU: load_state_dict copies the values into the
    # model's own parameters anyway, and this avoids failing on a machine
    # that lacks the device the checkpoint was saved from.
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints (consider weights_only=True if torch >= 1.13 is guaranteed).
    state_dict = torch.load(weights_path, map_location='cpu')
    vgg19.load_state_dict(state_dict)
    vgg19.classifier = nn.Sequential(
        *list(vgg19.classifier.children())[:num_classifier_layers]
    )
    # Freeze every weight so the network acts as a fixed feature extractor.
    for param in vgg19.parameters():
        param.requires_grad = False
    # A frozen network should not apply training-time behavior (e.g. any
    # Dropout kept when num_classifier_layers is larger than the default).
    vgg19.eval()
    return vgg19
评论列表
文章目录