convert_weights.py source code

python

Project: c3d-tensorflow2  Author: chuckcho
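import numpy as np
import caffe  # pycaffe, built from the Caffe fork that produced the C3D weights

# `model` (prototxt path), `weights` (caffemodel path), `layers` (layer names to export)
# and `output` (.npy path) are assumed to be module-level globals defined elsewhere
# in convert_weights.py; they are not part of this snippet.

def main():
    # Per https://www.tensorflow.org/versions/r0.11/api_docs/python/nn.html#conv3d
    # Filter has shape: [filter_depth, filter_height, filter_width, in_channels, out_channels]
    net = caffe.Net(model, weights)  # mainline pycaffe also expects a phase: caffe.Net(model, weights, caffe.TEST)
    netdata = dict()
    for layer in layers:
        w_blob, b_blob = net.params[layer][0].data, net.params[layer][1].data
        print("{}: w_shape, b_shape={}, {}".format(layer, w_blob.shape, b_blob.shape))
        # per https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.conv3d.md
        # filter: A Tensor. Must have the same type as input. Shape [filter_depth, filter_height, filter_width, in_channels, out_channels]. in_channels must match between input and filter.
        if 'conv' in layer:
            # Caffe 3D conv blob (out, in, depth, height, width) -> (depth, height, width, in, out)
            w = np.transpose(w_blob, (2, 3, 4, 1, 0))
        elif 'fc' in layer:
            # 5-D FC blob with leading singleton axes: slice to (out, in), then transpose to (in, out)
            w = w_blob[0, 0, 0, :, :].T
        else:
            raise ValueError("unexpected layer type: {}".format(layer))
        b = np.squeeze(b_blob)
        netdata[layer] = (w, b)
    # The dict is saved as a pickled 0-d object array; reload with np.load(output, allow_pickle=True).item()
    np.save(output, netdata)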
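To see what the two reshaping branches do without installing Caffe or TensorFlow, here is a minimal NumPy sketch that applies the same transpose and slice to random arrays. All shapes and the layer names `conv1a`/`fc6` are illustrative assumptions, not values read from the C3D caffemodel. It also checks that the FC transpose preserves Caffe's `y = W @ x`, and that the dictionary written by `np.save` only comes back via `allow_pickle=True` plus `.item()`.

```python
import os
import tempfile

import numpy as np

rng = np.random.default_rng(0)

# Illustrative shapes only (assumptions, not read from a real caffemodel).
# Caffe 3D conv weight blob: (out_channels, in_channels, depth, height, width)
caffe_conv_w = rng.random((8, 3, 2, 4, 5))
caffe_conv_b = rng.random((1, 1, 1, 1, 8))
# FC weight blob with three leading singleton axes: (1, 1, 1, out_features, in_features)
caffe_fc_w = rng.random((1, 1, 1, 10, 20))
caffe_fc_b = rng.random((1, 1, 1, 1, 10))

# Same operations as the 'conv' and 'fc' branches of main()
tf_conv_w = np.transpose(caffe_conv_w, (2, 3, 4, 1, 0))  # -> (depth, height, width, in, out)
tf_fc_w = caffe_fc_w[0, 0, 0, :, :].T                     # -> (in_features, out_features)

print(tf_conv_w.shape)  # (2, 4, 5, 3, 8)
print(tf_fc_w.shape)    # (20, 10)

# Element (o, i, d, h, w) in the Caffe layout lands at (d, h, w, i, o) in the TF layout.
assert tf_conv_w[1, 3, 4, 2, 7] == caffe_conv_w[7, 2, 1, 3, 4]

# Caffe computes y = W @ x; with the transposed matrix the same result is x @ W.T.
x = rng.random(20)
assert np.allclose(x @ tf_fc_w, caffe_fc_w[0, 0, 0] @ x)

# Round-trip through np.save: the dict is stored as a pickled 0-d object array.
netdata = {
    "conv1a": (tf_conv_w, np.squeeze(caffe_conv_b)),  # hypothetical layer names
    "fc6": (tf_fc_w, np.squeeze(caffe_fc_b)),
}
path = os.path.join(tempfile.mkdtemp(), "c3d_weights.npy")
np.save(path, netdata)
loaded = np.load(path, allow_pickle=True).item()
print(sorted(loaded))  # ['conv1a', 'fc6']
```

Without `allow_pickle=True`, NumPy 1.16.3 and later refuse to load the object array, which is the most common stumbling block when reading the `output` file produced by `main()`.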