def main(args):
    """Dump the model variables of a TensorFlow 1.x checkpoint to a NumPy file.

    Imports the meta-graph at ``args.meta_file``, restores its weights from
    ``args.ckpt_file``, evaluates every variable in the
    ``tf.GraphKeys.MODEL_VARIABLES`` collection, and saves a dict
    ``{variable_name: [value]}`` to ``args.save_path`` with ``np.save``.

    Args:
        args: namespace with attributes ``meta_file`` (path to the ``.meta``
            graph file), ``ckpt_file`` (checkpoint path handed to
            ``Saver.restore``) and ``save_path`` (output path for ``np.save``).

    Returns:
        None. Prints a message and returns early when the meta-graph file is
        missing, or when the imported graph contains no variables to restore.
    """
    if args.meta_file is None or not os.path.exists(args.meta_file):
        print("Invalid tensorflow meta-graph file:", args.meta_file)
        return

    gpu_options = tf.GPUOptions(allow_growth=True)
    config = tf.ConfigProto(
        gpu_options=gpu_options,
        log_device_placement=False,
        allow_soft_placement=True,
    )
    # Using the session as a context manager installs it as the default
    # session (like ``as_default()``) and also closes it on exit, releasing
    # the resources it holds; the original never closed it.
    with tf.Session(config=config) as sess:
        # ---- load pretrained parameters ---- #
        saver = tf.train.import_meta_graph(args.meta_file, clear_devices=True)
        if saver is None:
            # import_meta_graph returns None when the graph has no variables;
            # calling .restore on it would raise AttributeError.
            print("No variables to restore in meta-graph file:", args.meta_file)
            return
        saver.restore(sess, args.ckpt_file)

        model_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)
        print("total:", len(model_vars))
        pretrained = {}
        for v in model_vars:
            print("process:", v.name)
            # NOTE: keys are the FULL variable names (e.g.
            # 'Resnet/conv2d/bias:0'); the scope prefix is not stripped.
            # TODO(review): strip the prefix here if consumers expect
            # '/conv2d/bias:0'-style keys.
            # sess.run([v]) returns a one-element list, so each stored value
            # is [ndarray]; consumers read pretrained[name][0].
            pretrained[v.name] = sess.run([v])

    # The dict is stored by np.save as a pickled 0-d object array; load it
    # with np.load(path, allow_pickle=True).item().
    np.save(args.save_path, pretrained)
    print("done:", len(pretrained))
评论列表
文章目录