def build_model(self, sess):
    """Initialise the model's variables, optionally restoring a checkpoint.

    Calls ``self.init_opt()`` and runs the variable initialiser. When
    ``self.model_path`` is non-empty, it then restores every variable from
    that checkpoint. It also recovers the training counter encoded in the
    checkpoint's file name.

    The counter is the text between the last ``'_'`` and the first ``'.'``
    after it, taken from the checkpoint's base name only. For example,
    ``'ckpt/model_82000.ckpt'`` -> ``82000`` and ``'ckpt/model_82000'``
    -> ``82000``.

    Args:
        sess: session used to run the initialiser and restore the
            checkpoint (presumably a ``tf.Session`` -- confirm with caller).

    Returns:
        int: the counter parsed from the checkpoint name, or ``0`` when the
        model starts from fresh parameters.

    Raises:
        ValueError: if no integer counter can be parsed from
            ``self.model_path``.
    """
    import os  # local import: the module's own import block is not visible here

    self.init_opt()
    sess.run(tf.initialize_all_variables())

    # An empty path (or None, which previously crashed in len()) means
    # "start from scratch".
    if not self.model_path:
        print("Created model with fresh parameters.")
        return 0

    print("Reading model parameters from %s" % self.model_path)
    # Restores every variable. To restore only a subset, filter by name,
    # e.g. [v for v in tf.all_variables() if v.name.startswith(('g_', 'd_'))].
    saver = tf.train.Saver(tf.all_variables())
    saver.restore(sess, self.model_path)

    # Parse only the base name so '_' / '.' in directory names are ignored.
    # Do not assume an extension exists: the old rfind('.') slice silently
    # dropped the last digit when there was no '.', e.g. 'model_100' -> 10.
    base_name = os.path.basename(self.model_path)
    digits = base_name.rpartition('_')[2].partition('.')[0]
    try:
        counter = int(digits)
    except ValueError:
        raise ValueError(
            "Cannot parse a training counter from checkpoint path %r; "
            "expected a name like 'model_<counter>.ckpt'" % self.model_path)
    return counter
trainer.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录