def load_pretrained_npy(faster_rcnn_model, fname):
    """Copy pretrained VGG16 conv and fc6/fc7 parameters from an ``.npy`` file.

    The file must contain a dict (stored by ``np.save`` as a 0-d object
    array) mapping layer keys such as ``'conv1_1'``, ``'fc6'`` and ``'fc7'``
    to dicts holding ``'weights'`` and ``'biases'`` arrays. Matching tensors
    of ``faster_rcnn_model`` are overwritten in place; nothing is returned.

    Args:
        faster_rcnn_model: Model exposing ``rpn.features`` (a conv stack whose
            parameter names look like ``conv<i>.<j>.conv.weight``; entries
            containing ``'bn.'`` are skipped) and ``fc6.fc`` / ``fc7.fc``
            layers reachable through ``state_dict()``.
        fname: Path to the ``.npy`` file.

    Raises:
        KeyError: If the file lacks a layer the model asks for, or the model
            lacks the ``fc6.fc`` / ``fc7.fc`` entries.
    """
    # The dict is stored as a pickled 0-d object array; NumPy >= 1.16.3 refuses
    # to load it unless allow_pickle=True is passed explicitly.
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    params = np.load(fname, allow_pickle=True).item()

    # VGG16 convolution layers.
    vgg16_dict = faster_rcnn_model.rpn.features.state_dict()
    for name, val in vgg16_dict.items():
        # Batch-norm parameters/buffers have no counterpart in the file.
        if 'bn.' in name:
            continue
        # 'conv<i>.<j>.conv.weight' -> file key 'conv<i>_<j+1>' (the file's
        # in-block index is 1-based, the nn.Sequential index is 0-based).
        # Splitting on '.' also handles multi-digit indices.
        block, layer = name.split('.')[:2]
        i, j = int(block[len('conv'):]), int(layer) + 1
        ptype = 'weights' if name.endswith('weight') else 'biases'
        key = 'conv{}_{}'.format(i, j)
        param = torch.from_numpy(params[key][ptype])
        if ptype == 'weights':
            # Reorder axes into PyTorch's (out, in, H, W) conv layout; the file
            # presumably stores (H, W, in, out) -- TODO confirm against exporter.
            param = param.permute(3, 2, 0, 1)
        val.copy_(param)

    # fc6 / fc7: the weight matrix is transposed to nn.Linear's (out, in)
    # layout (file presumably stores (in, out)).
    frcnn_dict = faster_rcnn_model.state_dict()
    pairs = {'fc6.fc': 'fc6', 'fc7.fc': 'fc7'}
    for module_name, layer in pairs.items():
        weight = torch.from_numpy(params[layer]['weights']).permute(1, 0)
        frcnn_dict['{}.weight'.format(module_name)].copy_(weight)
        bias = torch.from_numpy(params[layer]['biases'])
        frcnn_dict['{}.bias'.format(module_name)].copy_(bias)
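# NOTE: the two headings below are web-page navigation text that was scraped
# together with the code; kept here as comments, translated from Chinese.
# Comment list
# Table of contents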