trainer.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:DialogueBreakdownDetection2016 作者: icoxfog417 项目源码 文件源码
def set_optimizer(self, session, learning_rate=0.1, learning_rate_decay_factor=0.99, max_gradient_norm=5.0, load_if_exist=True):
        """Build the training ops, initialise variables and optionally restore a checkpoint.

        Creates the global step counter, a non-trainable learning-rate variable
        together with an op that decays it, and a plain gradient-descent
        optimizer. For every bucket index, the gradients of that bucket's loss
        are clipped by global norm and an update op is registered.

        Args:
            session: TensorFlow session used to run the variable initialiser
                and, when applicable, to restore the checkpoint.
            learning_rate: Initial learning rate (converted to ``float``).
            learning_rate_decay_factor: Multiplier applied to the learning rate
                each time ``self.learning_rate_opr`` is run.
            max_gradient_norm: Global-norm threshold used for gradient clipping.
            load_if_exist: When true and ``self.train_dir`` is set, restore the
                checkpoint recorded in that directory if its file exists.

        NOTE(review): ``self.gradient_norms`` and ``self.updates`` are appended
        to but not created here -- presumably initialised elsewhere in the
        class; confirm before calling this method more than once.
        """
        # Keep the creation order of these variables stable: it determines the
        # auto-generated variable names that end up in saved checkpoints.
        self.global_step = tf.Variable(0, trainable=False)
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
        decayed_rate = self.learning_rate * learning_rate_decay_factor
        self.learning_rate_opr = self.learning_rate.assign(decayed_rate)
        self.optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)

        self.outputs, self.losses = self.calc_loss()

        trainable = tf.trainable_variables()
        bucket_count = len(self.buckets)
        for bucket_id in range(bucket_count):
            bucket_loss = self.losses[bucket_id]
            raw_grads = tf.gradients(bucket_loss, trainable)
            clipped, global_norm = tf.clip_by_global_norm(raw_grads, max_gradient_norm)

            self.gradient_norms.append(global_norm)
            # apply_gradients also increments global_step each time the op runs.
            apply_op = self.optimizer.apply_gradients(
                zip(clipped, trainable),
                global_step=self.global_step,
            )
            self.updates.append(apply_op)

        self.saver = tf.train.Saver(tf.all_variables())
        # Initialise first; a restore below overwrites the fresh values.
        session.run(tf.initialize_all_variables())

        if not (load_if_exist and self.train_dir):
            return

        checkpoint = tf.train.get_checkpoint_state(self.train_dir)
        if not checkpoint:
            return
        if tf.gfile.Exists(checkpoint.model_checkpoint_path):
            self.saver.restore(session, checkpoint.model_checkpoint_path)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号