def save_h5(args, net):
    """Restore a TensorFlow checkpoint and dump its weights to an HDF5 file.

    The checkpoint ``os.path.join(args.save_dir, args.ckpt_name)`` is restored
    into a fresh session.

    Every variable in ``tf.GraphKeys.VARIABLES`` is exported unless its name
    contains one of the substrings in ``exclude``. Each exported variable is
    written as a dataset keyed by ``v.name``. The datasets go into a new group
    ``'iter00001'`` of the HDF5 file ``args.h5_name``, which is opened in
    append mode.

    Args:
        args: options object; this function reads ``ckpt_name``,
            ``save_dir`` and ``h5_name`` from it.
        net: not referenced by this function.
            NOTE(review): presumably the caller builds the graph from ``net``
            before calling — confirm.

    Returns:
        None. If ``args.ckpt_name`` is empty or None, a message is printed
        and nothing is restored or written.

    Raises:
        Whatever ``saver.restore`` or h5py raise. In particular, h5py's
        ``create_group`` refuses to create ``'iter00001'`` if that group
        already exists in the file (e.g. on a second run with the same file).
    """
    # Validate before touching TensorFlow so the early exit costs nothing.
    # A truthiness test also covers ckpt_name=None, which len() would crash on.
    if not args.ckpt_name:
        print('checkpoint name not specified... exiting.')
        return

    # Substrings identifying training/optimizer state rather than model weights.
    # NOTE(review): 'beta' also matches variables such as batch-norm offsets
    # named '.../beta' — confirm excluding those is intended.
    exclude = ['learning_rate', 'beta', 'Adam']

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        saver.restore(sess, os.path.join(args.save_dir, args.ckpt_name))
        # Filter before fetching so excluded variables (e.g. Adam slots) are
        # never copied out of the session.
        keep = [v for v in tf.get_collection(tf.GraphKeys.VARIABLES)
                if not any(e in v.name for e in exclude)]
        vals = sess.run(keep)

    # The fetched values are plain arrays, so the session can already be closed.
    with h5py.File(args.h5_name, 'a') as f:
        group = f.create_group('iter00001')
        for v, val in safezip(keep, vals):
            group[v.name] = val
# Train network
评论列表
文章目录