def _initialize_weights(self):
    """Initialize parameters, mostly by copying from a pretrained VGG16.

    Three passes, in this order:

    1. Every ``nn.ConvTranspose2d`` in ``self.modules()`` has its weight
       re-drawn with ``weight_init.kaiming_normal``. Its kernel must be
       square, which is checked with ``assert``. Biases are not touched.
    2. Conv weights and biases from torchvision's pretrained VGG16
       ``features`` are copied into ``self.features``. Layers are paired by
       position with ``zip``, so pairing stops at the shorter sequence. A
       pair is copied only when BOTH layers are ``nn.Conv2d``.
    3. VGG16 ``classifier[0]`` and ``classifier[3]`` are copied into
       ``self.classifier[0]`` and ``self.classifier[3]``. Each is reshaped
       with ``view_as`` to the target parameter's shape, so the target may
       use a different layout with the same number of elements.

    Pretrained values are copied *into* the existing parameter tensors
    (``copy_``) instead of rebinding ``.data``. Each parameter therefore
    keeps its own storage, device and dtype, and a layout mismatch fails
    loudly instead of silently resizing a parameter.

    Raises:
        AssertionError: a ``ConvTranspose2d`` has a non-square kernel.
        ValueError: a paired conv layer's weight or bias size differs from
            the pretrained layer's.
        RuntimeError: a classifier tensor cannot be viewed as the target
            parameter's shape (element-count mismatch).
    """
    # Built first, as in the original ordering. Constructing the model runs
    # layer constructors that may consume RNG state, so moving this line
    # would change seeded Kaiming draws below.
    vgg16 = torchvision.models.vgg16(pretrained=True)

    for m in self.modules():
        if isinstance(m, nn.ConvTranspose2d):
            # NOTE(review): nothing below uses the square-kernel property.
            # This is presumably left over from a bilinear-upsampling
            # initializer; confirm before removing.
            assert m.kernel_size[0] == m.kernel_size[1]
            m.weight.data = weight_init.kaiming_normal(m.weight.data)

    for src, dst in zip(vgg16.features, self.features):
        if not (isinstance(src, nn.Conv2d) and isinstance(dst, nn.Conv2d)):
            continue
        # copy_ would broadcast some mismatched shapes silently, so check
        # the sizes explicitly before copying.
        if (dst.weight.size() != src.weight.size()
                or dst.bias.size() != src.bias.size()):
            raise ValueError(
                "pretrained VGG16 conv (weight {}, bias {}) does not match "
                "target conv (weight {}, bias {})".format(
                    tuple(src.weight.size()), tuple(src.bias.size()),
                    tuple(dst.weight.size()), tuple(dst.bias.size())))
        dst.weight.data.copy_(src.weight.data)
        dst.bias.data.copy_(src.bias.data)

    # In torchvision's VGG16, classifier[0] and classifier[3] are the first
    # two Linear layers. view_as adapts them to the target's parameter
    # shapes (presumably fc layers re-expressed as convolutions; confirm
    # against the class definition).
    for idx in (0, 3):
        src, dst = vgg16.classifier[idx], self.classifier[idx]
        dst.weight.data.copy_(src.weight.data.view_as(dst.weight))
        dst.bias.data.copy_(src.bias.data.view_as(dst.bias))
# 评论列表 (scraped page residue: "comment list")
# 文章目录 (scraped page residue: "table of contents")