model.py source code

python
Project: chatbot-generative    Author: DeeChat
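import time

import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 under TF 2.x)

import config  # project config module (assumed) providing LR, BUCKETS, MAX_GRAD_NORM


def _create_optimizer(self):
    """Method excerpt from the model class: builds one clipped-gradient
    SGD training op per bucket."""
    print('Create optimizer... ')
    with tf.variable_scope('training'):
        # Step counter shared by every bucket's train op; never trained itself.
        self.global_step = tf.Variable(
            0, dtype=tf.int32, trainable=False, name='global_step')

        # Forward-only (inference) mode needs no optimizer or train ops.
        if not self.fw_only:
            self.optimizer = tf.train.GradientDescentOptimizer(config.LR)
            trainable_vars = tf.trainable_variables()
            self.gradient_norms = []
            self.train_ops = []
            start = time.time()
            for bucket_id in range(len(config.BUCKETS)):
                # Rescale this bucket's gradients so that their global L2 norm
                # does not exceed config.MAX_GRAD_NORM.
                clipped_grads, norm = tf.clip_by_global_norm(
                    tf.gradients(self.losses[bucket_id], trainable_vars),
                    config.MAX_GRAD_NORM)
                self.gradient_norms.append(norm)  # unclipped norm, for monitoring
                # apply_gradients also increments global_step by one per run.
                self.train_ops.append(self.optimizer.apply_gradients(
                    zip(clipped_grads, trainable_vars),
                    global_step=self.global_step))
                print('Creating opt for bucket {:d} took {:.2f} seconds.'.format(
                    bucket_id, time.time() - start))
                start = time.time()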
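Each entry of `self.train_ops` performs one clipped gradient-descent update when run. The following is a minimal framework-free sketch of that arithmetic, assuming plain Python lists stand in for gradient tensors; `clip_grads_by_global_norm` and `sgd_step` are hypothetical helpers written for illustration, not TensorFlow functions. The clipping rule follows the one documented for `tf.clip_by_global_norm`: every tensor is multiplied by `clip_norm / max(global_norm, clip_norm)`.

```python
import math
from typing import List, Tuple

Vector = List[float]  # hypothetical stand-in for one gradient/parameter tensor


def clip_grads_by_global_norm(grads: List[Vector],
                              clip_norm: float) -> Tuple[List[Vector], float]:
    """Scale all gradients jointly so their combined L2 norm is <= clip_norm.

    Assumes clip_norm > 0. Returns the clipped gradients and the *unclipped*
    global norm, matching the two values the TensorFlow op returns.
    """
    # Global norm: one L2 norm over every entry of every gradient tensor.
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # scale == 1.0 when the gradients are already within the limit.
    scale = clip_norm / max(global_norm, clip_norm)
    clipped = [[g * scale for g in grad] for grad in grads]
    return clipped, global_norm


def sgd_step(params: List[Vector], grads: List[Vector], lr: float,
             max_grad_norm: float,
             global_step: int) -> Tuple[List[Vector], float, int]:
    """One clipped SGD update: param -= lr * clipped_grad, then step += 1."""
    clipped, norm = clip_grads_by_global_norm(grads, max_grad_norm)
    new_params = [[p - lr * g for p, g in zip(param, grad)]
                  for param, grad in zip(params, clipped)]
    return new_params, norm, global_step + 1


params = [[1.0, 2.0], [3.0]]
grads = [[3.0, 0.0], [4.0]]  # global norm = sqrt(9 + 0 + 16) = 5.0
new_params, norm, step = sgd_step(params, grads, lr=0.1,
                                  max_grad_norm=1.0, global_step=0)
print(norm, step)  # 5.0 1
```

Because the scale factor is shared across all tensors, clipping by the global norm shrinks the update's length while preserving its direction, unlike clipping each tensor or each entry separately.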