net_surgery.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:ifp 作者: morris-frank 项目源码 文件源码
def _weights_and_biases(net, layer_names):
    """Return ``{name: (weights, biases)}`` for the named layers of ``net``.

    The arrays are the ``.data`` arrays of each layer's first two param
    blobs; writing into them is how the parameters get transplanted below.
    """
    return {name: (net.params[name][0].data, net.params[name][1].data)
            for name in layer_names}


def _print_shapes(layer_params, layer_names):
    """Print the weight and bias shapes of each named layer."""
    for name in layer_names:
        weights, biases = layer_params[name]
        print('{} weights are {} dimensional and biases are {} dimensional'.format(
            name, weights.shape, biases.shape))


def perform_surgery(inp_proto, inp_model, fcn_proto, fcn_model,
                    fc_layers=('fc6', 'fc7_', 'fc8_output'),
                    conv_layers=('fc6-conv', 'fc7-conv', 'fc8-score'),
                    phase=None):
    """Transplant fully connected weights into a fully convolutional net.

    Loads the original network from ``inp_proto``/``inp_model``. Copies the
    weights and biases of each layer in ``fc_layers`` into the layer at the
    same position in ``conv_layers`` of the network defined by ``fcn_proto``.
    Saves the resulting network to ``fcn_model``.

    Args:
        inp_proto: network definition (prototxt) of the original network.
        inp_model: trained weights of the original network. Also used when
            instantiating the fully convolutional network.
        fcn_proto: network definition (prototxt) of the fully convolutional
            network.
        fcn_model: output path for the converted weights. Its parent
            directory is created if it does not exist.
        fc_layers: source layer names in the original network.
        conv_layers: destination layer names in the fully convolutional
            network, paired positionally with ``fc_layers``.
        phase: Caffe phase used to instantiate both networks. ``None`` keeps
            the previous hard-coded value, ``caffe.TRAIN``.

    Raises:
        ValueError: if ``fc_layers`` and ``conv_layers`` differ in length, or
            if a source/destination weight pair holds a different number of
            elements.
    """
    state = caffe.TRAIN if phase is None else phase

    # zip() would silently drop unpaired layers, so reject a mismatch upfront.
    if len(fc_layers) != len(conv_layers):
        raise ValueError('fc_layers ({}) and conv_layers ({}) must have the '
                         'same length'.format(len(fc_layers), len(conv_layers)))

    # Load the original network and extract the fully connected layers'
    # parameters. `net` stays referenced until return, so the arrays in
    # fc_params remain backed by a live network.
    net = caffe.Net(inp_proto, inp_model, state)
    fc_params = _weights_and_biases(net, fc_layers)
    _print_shapes(fc_params, fc_layers)

    # Load the fully convolutional network that will receive the parameters.
    net_full_conv = caffe.Net(fcn_proto, inp_model, state)
    conv_params = _weights_and_biases(net_full_conv, conv_layers)
    _print_shapes(conv_params, conv_layers)

    for fc_name, conv_name in zip(fc_layers, conv_layers):
        print('{} = {}'.format(conv_name, fc_name))
        fc_weights, fc_biases = fc_params[fc_name]
        conv_weights, conv_biases = conv_params[conv_name]
        # `.flat` assignment repeats or truncates the source on a size
        # mismatch instead of failing, which would corrupt the weights.
        if fc_weights.size != conv_weights.size:
            raise ValueError(
                'cannot transplant {} (shape {}) into {} (shape {}): '
                'element counts differ'.format(
                    fc_name, fc_weights.shape, conv_name, conv_weights.shape))
        conv_weights.flat = fc_weights.flat  # flat unrolls the arrays
        conv_biases[...] = fc_biases

    print('Finished unrolling.....')

    out_dir = os.path.dirname(fcn_model)
    # An empty dirname means "current directory"; os.makedirs('') would raise.
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    net_full_conv.save(fcn_model)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号