import collections

import torch


def change_key_names(old_params, in_channels):
    """Adapt RGB-pretrained weights so the first conv layer accepts `in_channels` inputs."""
    new_params = collections.OrderedDict()
    layer_count = 0
    allKeyList = list(old_params.keys())
    for layer_key in allKeyList:
        if layer_count >= len(allKeyList) - 2:
            # Exclude the final fc layer's weight and bias.
            continue
        if layer_count == 0:
            # Average the RGB kernels over the input-channel dimension, then
            # replicate that mean across the new number of input channels.
            rgb_weight = old_params[layer_key]
            rgb_weight_mean = torch.mean(rgb_weight, dim=1)
            flow_weight = rgb_weight_mean.unsqueeze(1).repeat(1, in_channels, 1, 1)
            new_params[layer_key] = flow_weight
            layer_count += 1
        else:
            # All remaining layers keep their pretrained weights unchanged.
            new_params[layer_key] = old_params[layer_key]
            layer_count += 1
    return new_params
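A minimal usage sketch of the function above, assuming a torchvision ResNet-18 checkpoint is adapted to a 20-channel stacked-optical-flow input; the model choice and channel count are illustrative, not taken from the original post:

import torch
import torchvision

in_channels = 20  # e.g. 10 stacked flow frames x 2 (x/y) components
rgb_model = torchvision.models.resnet18(pretrained=True)
flow_params = change_key_names(rgb_model.state_dict(), in_channels)

flow_model = torchvision.models.resnet18()
# Replace the first conv so it matches the repeated-mean weights.
flow_model.conv1 = torch.nn.Conv2d(in_channels, 64, kernel_size=7,
                                   stride=2, padding=3, bias=False)
# strict=False because the fc layer was intentionally excluded above.
flow_model.load_state_dict(flow_params, strict=False)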