def _train_layer_pair(self):
    """Build the model, then train it for ``self.params['epochs']`` epochs.

    Steps:
      1. ``self.build_model()`` builds the graph, and ``self.params['run_no']``
         is read. The read happens before ``self.define_runner_folders()``,
         presumably because that call may set up a new run; TODO confirm.
      2. A ``tf.Session`` is opened with GPU ``allow_growth`` enabled and
         stored on ``self.sess``.
      3. If ``run_no`` is truthy, the variables returned by
         ``self._get_restored_variables_names()`` are restored from
         ``self.saves_path``. The variables returned by
         ``self._get_new_variables_names()`` are then initialized.
         Otherwise, every variable is initialized from scratch.
      4. A summary writer for ``self.logs_dir`` is stored on
         ``self.summary_writer``.
      5. ``self._epoch_train_step()`` runs once per epoch, and each epoch's
         wall-clock duration is printed.
      6. The summary writer is flushed, and all variables are saved to
         ``self.saves_path``.
    """
    self.build_model()
    prev_run_no = self.params.get('run_no', None)
    self.define_runner_folders()
    config = tf.ConfigProto()
    # Let TensorFlow grow GPU memory usage on demand instead of
    # pre-allocating the whole device.
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        self.sess = sess
        # NOTE(review): a run_no of 0 is falsy and takes the "fresh
        # variables" branch; confirm that 0 is never a real previous run.
        if prev_run_no:
            print("Restore variables from previous run:")
            # Mapping handed to tf.train.Saver; its keys are printed as names.
            restore_vars_dict = self._get_restored_variables_names()
            for var_name in restore_vars_dict:
                print("\t%s" % var_name)
            restorer = tf.train.Saver(restore_vars_dict)
            restorer.restore(sess, self.saves_path)
            print("Initialize not restored variables:")
            # Despite the helper's name, these are variable objects
            # (they expose .name and are passed to the initializer).
            new_variables = self._get_new_variables_names()
            for var in new_variables:
                print("\t%s" % var.name)
            # NOTE(review): tf.initialize_variables is deprecated in favour
            # of tf.variables_initializer (TF >= 0.12).
            sess.run(tf.initialize_variables(new_variables))
        else:
            print("Initialize new variables")
            # .run() uses the default session, which the enclosing
            # `with tf.Session(...)` installs.
            # NOTE(review): deprecated in favour of
            # tf.global_variables_initializer (TF >= 0.12).
            tf.initialize_all_variables().run()
        # NOTE(review): tf.train.SummaryWriter was removed in TF 1.0
        # (replaced by tf.summary.FileWriter); confirm the TF version used.
        self.summary_writer = tf.train.SummaryWriter(
            self.logs_dir, sess.graph)
        for epoch in range(self.params['epochs']):
            start = time.time()
            self._epoch_train_step()
            elapsed = time.time() - start
            time_cons = str(datetime.timedelta(seconds=elapsed))
            print("Epoch: %d, time consumption: %s" % (epoch, time_cons))
        # The writer buffers events and only flushes periodically; flush now
        # so that summaries from the final epochs are not lost.
        self.summary_writer.flush()
        # Save all trained variables
        saver = tf.train.Saver()
        saver.save(sess, self.saves_path)
rbm_train_by_pair_layers.py source file
python
Views 28
Favorites 0
Likes 0
Comments 0
Comment list
Table of contents