def get_pretrained_resnet(new_fc_dim=None):
    """
    Fetch a pretrained ResNet-152 (downloading weights if necessary) and
    replace its top linear layer.

    If ``new_fc_dim`` is None, the head becomes an identity module, so the
    model outputs the features that would have fed the original top layer.
    Otherwise a fresh ``nn.Linear(ENCODING_SIZE, new_fc_dim)`` is attached and
    initialised with ``_init_fc``.

    :param new_fc_dim: Output width of the new linear head, or None to drop
        the head entirely.
    :return: The modified ``resnet152`` model.
    """
    # NOTE(review): ``pretrained=`` is deprecated in newer torchvision in
    # favour of ``weights=``; kept here for compatibility with older versions.
    resnet152 = models.resnet152(pretrained=True)
    # Remove the pretrained top linear layer.
    del resnet152.fc
    if new_fc_dim is not None:
        # NOTE(review): assumes ENCODING_SIZE equals the backbone's feature
        # width feeding ``fc`` (2048 for torchvision's ResNet-152) -- confirm.
        resnet152.fc = nn.Linear(ENCODING_SIZE, new_fc_dim)
        _init_fc(resnet152.fc)
    else:
        # Register a real identity submodule instead of a lambda. A lambda
        # makes the whole model unpicklable (torch.save(model) fails) and is
        # invisible to module traversal (.modules(), repr).
        resnet152.fc = nn.Identity()
    return resnet152
评论列表
文章目录