def change_key_names(old_params, in_channels, *, max_layers=26):
    """Copy the leading entries of a parameter mapping, widening the first one.

    The first entry of ``old_params`` is averaged over its channel axis
    (dim 1) and that mean is tiled ``in_channels`` times along the same axis.
    Subsequent entries are copied by reference, unchanged. Entries beyond the
    first ``max_layers`` are dropped.

    Args:
        old_params: Ordered mapping of parameter name to tensor (e.g. a
            ``state_dict``). The first entry is assumed to be a 4-D
            convolution weight laid out as (out, in, kH, kW) -- TODO confirm
            against the caller.
        in_channels: Number of input channels the first entry's weight
            should have after conversion.
        max_layers: Maximum number of leading entries to keep. The default
            of 26 is the limit the original code hard-coded; presumably it
            matches one specific checkpoint layout -- verify before reusing
            with a different architecture.

    Returns:
        A new ``collections.OrderedDict`` with at most ``max_layers`` entries,
        in the same order as ``old_params``.
    """
    new_params = collections.OrderedDict()
    for index, (layer_key, weight) in enumerate(old_params.items()):
        if index >= max_layers:
            # Everything past the limit is intentionally left out.
            break
        if index == 0:
            # keepdim=True keeps the reduced channel axis as size 1 so that
            # repeat() tiles along it. Without it the mean is 3-D and
            # repeat() would prepend a new axis, producing a tensor of shape
            # (1, out * in_channels, kH, kW) instead of
            # (out, in_channels, kH, kW).
            channel_mean = torch.mean(weight, dim=1, keepdim=True)
            weight = channel_mean.repeat(1, in_channels, 1, 1)
        new_params[layer_key] = weight
    return new_params
# (Web-page residue from the source article: "Comment list" / "Table of contents")