def __init__(self, embedding_size, num_classes, checkpoint=None):
    """Build a ResNet-18 embedding network with a classifier head and class centers.

    A backbone is created with ``resnet18()``, its ``avgpool`` is set to
    ``None``, and three linear layers are attached to ``self.model``:
    ``fc1`` (512*3*3 -> 512), ``fc2`` (512 -> ``embedding_size``) and
    ``classifier`` (``embedding_size`` -> ``num_classes``). ``self.weights_init``
    is then applied to every submodule, and a ``(num_classes, embedding_size)``
    float tensor of zeros is created as ``self.centers``.

    Args:
        embedding_size: Output dimension of ``fc2`` and width of each center.
        num_classes: Output dimension of ``classifier`` and number of rows
            in ``self.centers``.
        checkpoint: Optional dict with a ``'state_dict'`` entry. If every
            saved ``classifier`` tensor has ``num_classes`` rows, the whole
            state dict is loaded and ``self.centers`` is replaced by
            ``checkpoint['centers']``. Otherwise every non-classifier tensor
            is copied into this model and the centers stay zero.

    Raises:
        KeyError: if the checkpoint lacks ``'state_dict'``, lacks
            ``'centers'`` when the class counts match, or contains a
            non-classifier entry this model does not have.
        RuntimeError: if ``load_state_dict`` finds missing/unexpected keys
            or a copied tensor's shape is incompatible.
    """
    super(FaceModelCenter, self).__init__()
    self.model = resnet18()
    # Global average pooling is removed; fc1 consumes the flattened layer4 map.
    # NOTE(review): 512*3*3 assumes a 3x3 layer4 output, which depends on the
    # input image resolution -- confirm against the data pipeline.
    self.model.avgpool = None
    self.model.fc1 = nn.Linear(512 * 3 * 3, 512)
    self.model.fc2 = nn.Linear(512, embedding_size)
    self.model.classifier = nn.Linear(embedding_size, num_classes)
    # Plain tensor attribute, not a registered buffer: it is not part of
    # state_dict(), which is why checkpoints carry it under 'centers'.
    self.centers = torch.zeros(num_classes, embedding_size).type(torch.FloatTensor)
    self.num_classes = num_classes
    self.apply(self.weights_init)

    if checkpoint is None:
        return

    saved_state = checkpoint['state_dict']
    # Derive the saved class count from the classifier tensors themselves
    # instead of trusting whichever entry happens to be last in the dict.
    saved_class_counts = {
        tensor.size(0)
        for name, tensor in saved_state.items()
        if "classifier" in name
    }
    if saved_class_counts == {num_classes}:
        # Same head size: restore every weight and the class centers.
        self.load_state_dict(saved_state)
        self.centers = checkpoint['centers']
        return

    # Different (or missing) classifier: transfer everything except the head.
    # state_dict() tensors share storage with the parameters, so copy_()
    # updates the model in place.
    own_state = self.state_dict()
    for name, param in saved_state.items():
        if "classifier" in name:
            continue
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        own_state[name].copy_(param)
评论列表
文章目录