convert_caffe_weights_to_npy.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:flownet2-tf 作者: sampepose 项目源码 文件源码
def main():
    """Convert Caffe model parameters to a NumPy file of TF-style tensors.

    Steps:
      1. Render the deploy prototxt template at ``arch['DEPLOY_PROTOTXT']`` by
         replacing every ``$key$`` placeholder with ``str(value)`` for each
         ``key, value`` pair in the module-level ``vars`` mapping.
      2. Load the rendered prototxt together with ``arch['CAFFEMODEL']`` into a
         ``caffe.Net`` in ``caffe.TEST`` mode.
      3. For every ``caffe_param -> tf_param`` entry in ``arch['PARAMS']``,
         store the transposed weights under ``tf_param + '/weights'`` and the
         biases under ``tf_param + '/biases'``.
      4. Save the resulting dict to ``FLAGS.out`` with ``np.save``.

    Relies on ``arch``, ``vars``, ``FLAGS``, ``caffe``, ``np`` and ``tempfile``
    being available at module level.
    """
    # Build the placeholder table once; it does not depend on the line.
    substitutions = [("$%s$" % key, str(value)) for key, value in vars.items()]

    # The temporary file is deleted as soon as it is closed, so the network
    # must be instantiated while the `with` block is still open.
    with tempfile.NamedTemporaryFile(mode='w', delete=True) as tmp:
        # Parse prototxt and inject `vars`
        with open(arch['DEPLOY_PROTOTXT']) as proto:
            for line in proto:
                for tag, replacement in substitutions:
                    line = line.replace(tag, replacement)
                tmp.write(line)
        # Make sure the rendered prototxt is on disk before Caffe opens it by name.
        tmp.flush()

        # Instantiate Caffe Model
        net = caffe.Net(tmp.name, arch['CAFFEMODEL'], caffe.TEST)

    out = {}
    for (caffe_param, tf_param) in arch['PARAMS'].items():
        blobs = net.params[caffe_param]
        # Caffe stores weights as (channels_out, channels_in, h, w)
        # but TF expects (h, w, channels_in, channels_out)
        out[tf_param + '/weights'] = blobs[0].data.transpose((2, 3, 1, 0))
        out[tf_param + '/biases'] = blobs[1].data

    # NOTE(review): `out` is a dict, so np.save presumably pickles it as an
    # object array; readers likely need allow_pickle=True -- confirm.
    np.save(FLAGS.out, out)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号