save_weights.py source code

Project: tf_face    Author: ZhijianChan
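import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.GPUOptions, ...)


def main(args):
    # Bail out early if the meta-graph file is missing.
    if args.meta_file is None or not os.path.exists(args.meta_file):
        print("Invalid tensorflow meta-graph file:", args.meta_file)
        return

    # Grow GPU memory on demand; fall back to CPU for ops without a GPU kernel.
    gpu_options = tf.GPUOptions(allow_growth=True)
    config = tf.ConfigProto(
        gpu_options=gpu_options,
        log_device_placement=False,
        allow_soft_placement=True)
    # Entering the session also installs it as the default session,
    # and the session is closed automatically on exit.
    with tf.Session(config=config) as sess:
        # ---- load pretrained parameters ---- #
        # clear_devices=True drops device placements recorded in the meta-graph.
        saver = tf.train.import_meta_graph(args.meta_file, clear_devices=True)
        saver.restore(sess, args.ckpt_file)
        pretrained = {}
        var_ = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)
        print("total:", len(var_))
        for v in var_:
            print("process:", v.name)
            # Note: names look like 'Resnet/conv2d/bias:0'. The full name is
            # kept as the key; strip the scope prefix here (e.g. to
            # '/conv2d/bias:0') if the consumer expects unscoped names.
            pretrained[v.name] = sess.run(v)  # evaluate to a NumPy array
    # np.save pickles the dict into a 0-d object array.
    np.save(args.save_path, pretrained)
    print("done:", len(pretrained))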
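Reading the saved file back is not obvious, because `np.save` stores the dict as a pickled 0-d object array. The following is a minimal sketch, assuming NumPy 1.16.3 or later (where `np.load` defaults to `allow_pickle=False`). The helper names `strip_scope`, `save_weights` and `load_weights` are hypothetical, and the demo weights are dummy zero arrays rather than real tf_face parameters. `strip_scope` shows one way to do the scope-prefix removal described in the code comment (`'Resnet/conv2d/bias:0'` to `'/conv2d/bias:0'`). It runs without TensorFlow.

```python
import os
import tempfile

import numpy as np


def strip_scope(name):
    """Drop the leading scope, e.g. 'Resnet/conv2d/bias:0' -> '/conv2d/bias:0'.

    Names without a '/' are returned unchanged.
    """
    idx = name.find('/')
    return name[idx:] if idx >= 0 else name


def save_weights(path, weights):
    # np.save wraps the dict in a 0-d object array and pickles it.
    np.save(path, weights)


def load_weights(path):
    # allow_pickle=True is needed to read object arrays;
    # .item() unwraps the 0-d array back into the original dict.
    return np.load(path, allow_pickle=True).item()


if __name__ == "__main__":
    # Hypothetical weights standing in for the arrays sess.run(v) would return.
    weights = {
        "Resnet/conv2d/weights:0": np.zeros((3, 3, 3, 8), dtype=np.float32),
        "Resnet/conv2d/bias:0": np.zeros((8,), dtype=np.float32),
    }
    path = os.path.join(tempfile.mkdtemp(), "weights.npy")
    save_weights(path, weights)
    restored = load_weights(path)
    for name, value in restored.items():
        print(strip_scope(name), value.shape)
```

Keeping the full scoped names in the saved file, and stripping them only when loading, leaves the choice of naming to the consumer without having to regenerate the `.npy` file.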