netsurgery.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:material-seg 作者: paulu 项目源码 文件源码
def netsurgery(protofile, paramfile, params, protofile_full_conv, params_full_conv, opath):
    '''
    See https://github.com/BVLC/caffe/blob/master/examples/net_surgery.ipynb for
    details.

    protofile: prototxt of fully-connected model (read-only pathname)
    paramfile: weights of fully-connected model (read-only pathname)
    params: list of fully-connected blob names (list of str)
    protofile_full_conv: prototxt of fully-convolutional model (read-only pathname)
    params_full_conv: list of fully-convolutional blob names (list of str)
    opath: weights of fully-connected model (write pathname)
    '''
    caffe.set_mode_cpu()
    net = caffe.Net(protofile, paramfile, caffe.TEST)
    fc_params = {pr: (net.params[pr][0].data, net.params[pr][1].data) for pr in params}
    for fc in params:
        print('{} weights are {} dimensional and biases are {} dimensional'.format(fc, fc_params[fc][0].shape, fc_params[fc][1].shape))

    net_full_conv = caffe.Net(protofile_full_conv, paramfile, caffe.TEST)
    conv_params = {pr: (net_full_conv.params[pr][0].data, net_full_conv.params[pr][1].data) for pr in params_full_conv}
    for conv in params_full_conv:
        print('{} weights are {} dimensional and biases are {} dimensional'.format(conv, conv_params[conv][0].shape, conv_params[conv][1].shape))

    for pr, pr_conv in zip(params, params_full_conv):
        conv_params[pr_conv][0].flat = fc_params[pr][0].flat  # flat unrolls the arrays
        conv_params[pr_conv][1][...] = fc_params[pr][1]

    net_full_conv.save(opath)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号