rbm_train_by_pair_layers.py source code

Project: rbm_based_autoencoders_with_tensorflow    Author: ikhlestov
# Module-level imports this method relies on
import datetime
import time

import tensorflow as tf


def _train_layer_pair(self):
    # Build the graph for the current RBM layer pair
    self.build_model()
    prev_run_no = self.params.get('run_no', None)
    self.define_runner_folders()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        self.sess = sess

        if prev_run_no:
            # Continue from an earlier run: restore the variables trained
            # previously, then initialize only the newly created ones.
            print("Restore variables from previous run:")
            restore_vars_dict = self._get_restored_variables_names()
            for var_name in restore_vars_dict.keys():
                print("\t%s" % var_name)
            restorer = tf.train.Saver(restore_vars_dict)
            restorer.restore(sess, self.saves_path)
            print("Initialize not restored variables:")
            new_variables = self._get_new_variables_names()
            for var in new_variables:
                print("\t%s" % var.name)
            sess.run(tf.initialize_variables(new_variables))
        else:
            # First run: initialize everything from scratch
            print("Initialize new variables")
            tf.initialize_all_variables().run()
        self.summary_writer = tf.train.SummaryWriter(
            self.logs_dir, sess.graph)
        for epoch in range(self.params['epochs']):
            start = time.time()
            self._epoch_train_step()
            time_cons = time.time() - start
            time_cons = str(datetime.timedelta(seconds=time_cons))
            print("Epoch: %d, time consumption: %s" % (epoch, time_cons))

        # Save all trained variables
        saver = tf.train.Saver()
        saver.save(sess, self.saves_path)
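
The method above restores the variables trained for earlier layer pairs by name and initializes only the variables that are new to the current graph, so each RBM pair continues from the previous checkpoint. The helper methods _get_restored_variables_names and _get_new_variables_names are not shown on this page; the following is a minimal sketch using the same TF 0.x/1.x-era API as the snippet, assuming the project's variables are grouped under per-layer name scopes such as rbm_1/ (the scope naming and the standalone function signatures are assumptions for illustration, not the project's actual code).

# Hypothetical helpers (illustrative only, not the project's implementation).
# Assumes each layer pair's variables live under a name scope like "rbm_<n>/".
import tensorflow as tf


def get_restored_variables_names(prev_layer_ids):
    """Map checkpoint names to variables belonging to already-trained layers."""
    restore_dict = {}
    for var in tf.all_variables():  # tf.global_variables() in TF >= 1.0
        if any(var.name.startswith('rbm_%d/' % i) for i in prev_layer_ids):
            # Saver dict keys are checkpoint names, without the ":0" suffix
            restore_dict[var.name.split(':')[0]] = var
    return restore_dict


def get_new_variables_names(prev_layer_ids):
    """Variables absent from the previous checkpoint that still need init."""
    restored_names = set(
        v.name for v in get_restored_variables_names(prev_layer_ids).values())
    return [var for var in tf.all_variables() if var.name not in restored_names]

In _train_layer_pair these would correspond to the self._get_restored_variables_names() and self._get_new_variables_names() calls; the real project presumably derives the layer indices from self.params rather than taking them as arguments.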