flow_resnet.py source code

python

Project: two-stream-pytorch  Author: bryanyzhu
import collections

import torch


def change_key_names(old_params, in_channels):
    """Adapt an RGB-pretrained state dict to a flow network with in_channels inputs.

    The first entry (the first conv weight) is averaged over its RGB channels
    and repeated in_channels times; the last two entries (fc) are dropped.
    """
    new_params = collections.OrderedDict()
    # Materialize the keys so len() and indexing by position are well defined.
    all_keys = list(old_params.keys())
    for layer_count, layer_key in enumerate(all_keys):
        if layer_count >= len(all_keys) - 2:
            # exclude fc layers (assumed to be the last two keys: weight and bias)
            continue
        if layer_count == 0:
            # First conv weight, shape (out_channels, 3, kH, kW).
            rgb_weight = old_params[layer_key]
            # Average over the RGB channel axis: (out_channels, kH, kW).
            rgb_weight_mean = torch.mean(rgb_weight, dim=1)
            # TODO: ugly fix here, why torch.mean() turn tensor to Variable
            # Restore the channel axis and tile it to in_channels.
            flow_weight = rgb_weight_mean.unsqueeze(1).repeat(1, in_channels, 1, 1)
            new_params[layer_key] = flow_weight
        else:
            # All other layers are copied unchanged.
            new_params[layer_key] = old_params[layer_key]

    return new_params
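
The first-layer step above can be tried in isolation. The following is a minimal sketch, not part of the original file: the helper name inflate_first_conv, the choice of 20 input channels, and the 7x7 stride-2 convolution are all assumptions made for illustration. It averages a random RGB-shaped kernel over its channel axis, repeats it to the flow channel count, and loads the result into a convolution.

```python
import torch
import torch.nn as nn


def inflate_first_conv(rgb_weight, in_channels):
    """Hypothetical helper: average an RGB kernel over channels, repeat to in_channels."""
    # (out, 3, kH, kW) -> (out, 1, kH, kW)
    mean = rgb_weight.mean(dim=1, keepdim=True)
    # (out, 1, kH, kW) -> (out, in_channels, kH, kW)
    return mean.repeat(1, in_channels, 1, 1)


torch.manual_seed(0)
# Stand-in for a pretrained first conv weight (random values, not real weights).
rgb_weight = torch.randn(64, 3, 7, 7)
in_channels = 20  # assumption: e.g. 10 flow frames x 2 components

flow_weight = inflate_first_conv(rgb_weight, in_channels)

# A conv layer that accepts the stacked flow input; layer sizes are illustrative.
flow_conv = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    flow_conv.weight.copy_(flow_weight)

out = flow_conv(torch.randn(1, in_channels, 224, 224))
print(tuple(flow_weight.shape), tuple(out.shape))
```

Every input channel of the new kernel holds the same channel-averaged filter, so the layer starts from the pretrained spatial patterns rather than from random initialization.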