def create_or_load_model(model, model_dir, session, name):
    """Restore ``model`` from its newest checkpoint, or initialize it fresh.

    Looks up the most recent checkpoint in ``model_dir`` via
    ``tf.train.latest_checkpoint``. If one is found, loading is delegated to
    ``load_model`` and its return value replaces ``model``. Otherwise all
    global variables and lookup tables are initialized in ``session`` and
    the time spent initializing is printed.

    Args:
        model: The model object. It must expose a ``global_step`` that can
            be evaluated with ``.eval(session=session)``.
        model_dir: Directory searched for checkpoints.
        session: TensorFlow session used to run the initializers and to
            evaluate the global step.
        name: Label for the model; printed in the log message and forwarded
            to ``load_model``.

    Returns:
        A ``(model, global_step)`` tuple, where ``global_step`` is the value
        of ``model.global_step`` evaluated in ``session``.
    """
    latest_ckpt = tf.train.latest_checkpoint(model_dir)
    if latest_ckpt:
        # load_model's return value replaces the caller's model object.
        # NOTE(review): assumes load_model returns the restored model — confirm.
        model = load_model(model, latest_ckpt, session, name)
    else:
        start_time = time.time()
        session.run(tf.global_variables_initializer())
        # Lookup tables have their own initializer op; the global-variables
        # initializer does not cover them.
        session.run(tf.tables_initializer())
        # Function-call form of print: valid on both Python 2 and Python 3
        # (the bare print statement is a SyntaxError on Python 3).
        print(" created %s model with fresh parameters, time %.2fs"
              % (name, time.time() - start_time))

    global_step = model.global_step.eval(session=session)
    return model, global_step
评论列表
文章目录