behavioral_cloning.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:gail-driver 作者: sisl 项目源码 文件源码
def save_h5(args, net):
    """Restore a TensorFlow checkpoint and dump its variables to an HDF5 file.

    The checkpoint at ``os.path.join(args.save_dir, args.ckpt_name)`` is
    restored into a fresh session. Every variable in the graph's
    ``tf.GraphKeys.VARIABLES`` collection is evaluated. Each variable whose
    name does not contain any of the ``exclude`` substrings is written as a
    dataset (keyed by the variable's name) under the group ``iter00001`` of
    ``args.h5_name``, which is opened in append mode.

    Args:
        args: namespace providing ``ckpt_name``, ``save_dir`` and ``h5_name``.
        net: not used by this function; presumably kept so the signature
            matches the script's other entry points -- verify with callers.

    Returns:
        None. If ``args.ckpt_name`` is empty or None, a message is printed and
        the function returns before opening a session or touching the file.
    """
    # Validate the checkpoint name up front, so we don't build a session and
    # initialize every variable only to bail out. Truthiness also covers None,
    # which previously crashed in len().
    if not args.ckpt_name:
        print('checkpoint name not specified... exiting.')
        return

    # Variables whose names contain any of these substrings are not exported.
    # Presumably these are optimizer/schedule state (Adam slots, beta power
    # accumulators, learning rate). Note that 'beta' also matches any other
    # variable with "beta" in its name -- verify this is intended.
    exclude = ('learning_rate', 'beta', 'Adam')

    # Begin tf session.
    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())

        # Load weights from the previous save.
        saver.restore(sess, os.path.join(args.save_dir, args.ckpt_name))

        variables = tf.get_collection(tf.GraphKeys.VARIABLES)
        values = sess.run(variables)

        # NOTE: h5py's create_group raises if 'iter00001' already exists in
        # the file. Re-running against the same file therefore fails rather
        # than overwriting.
        with h5py.File(args.h5_name, 'a') as f:
            group = f.create_group('iter00001')
            for var, val in safezip(variables, values):
                if not any(e in var.name for e in exclude):
                    group[var.name] = val

# Train network
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号