seq2seq_model.py source code


Project: seq2seq (author: eske)
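```python
import re

import tensorflow as tf  # TensorFlow 1.x API (tf.trainable_variables, tf.gradients, tf.variable_scope)


# Method of the model class (the class definition is not part of this snippet).
def get_update_op(self, loss, opts, global_step=None, max_gradient_norm=None, freeze_variables=None):
    """Build one update op per optimizer in `opts`, skipping frozen variables."""
    if loss is None:
        return None

    freeze_variables = freeze_variables or []

    # A variable is frozen if its name matches any regular expression in
    # `freeze_variables`; re.match anchors the pattern at the start of the name.
    frozen_parameters = {var.name for var in tf.trainable_variables()
                         if any(re.match(var_, var.name) for var_ in freeze_variables)}
    # Compute gradients only for variables that are not frozen.
    params = [var for var in tf.trainable_variables() if var.name not in frozen_parameters]
    self.params = params

    gradients = tf.gradients(loss, params)
    if max_gradient_norm:
        # Rescale all gradients jointly so that their global L2 norm is at most max_gradient_norm.
        gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm)

    update_ops = []
    scope_name = 'gradients' if self.name is None else 'gradients_{}'.format(self.name)
    for opt in opts:
        with tf.variable_scope(scope_name):
            # apply_gradients increments global_step by one each time its op runs,
            # and every optimizer here is given the same global_step variable.
            update_op = opt.apply_gradients(list(zip(gradients, params)), global_step=global_step)
        update_ops.append(update_op)

    return update_ops
```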
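The freezing logic in `get_update_op` selects variables by matching each TensorFlow variable name against regular-expression patterns with `re.match`, which anchors at the start of the name but not at the end. The sketch below reproduces only that filter on plain strings so it runs without TensorFlow; the helper `split_frozen` and the variable names are hypothetical stand-ins for `tf.trainable_variables()`, not part of the original project.

```python
import re
from typing import List, Sequence, Tuple


def split_frozen(names: Sequence[str], freeze_patterns: Sequence[str]) -> Tuple[List[str], List[str]]:
    """Split names into (trainable, frozen), mirroring get_update_op's filter (hypothetical helper).

    A name is frozen if re.match (anchored at the start only) succeeds for any pattern.
    """
    frozen = [name for name in names
              if any(re.match(pattern, name) for pattern in freeze_patterns)]
    frozen_set = set(frozen)  # set membership keeps the second pass linear
    trainable = [name for name in names if name not in frozen_set]
    return trainable, frozen


# Hypothetical TensorFlow-style variable names, for illustration only.
names = [
    "encoder_src/embedding:0",
    "encoder_src/cell/kernel:0",
    "decoder_trg/embedding:0",
    "decoder_trg/cell/kernel:0",
]

trainable, frozen = split_frozen(names, ["encoder_"])
print(trainable)  # ['decoder_trg/embedding:0', 'decoder_trg/cell/kernel:0']
print(frozen)     # ['encoder_src/embedding:0', 'encoder_src/cell/kernel:0']

# "embedding" alone matches nothing, because re.match only matches at the start
# of the name; a leading ".*" is needed to freeze every embedding matrix.
print(split_frozen(names, ["embedding"])[1])    # []
print(split_frozen(names, [".*embedding"])[1])  # ['encoder_src/embedding:0', 'decoder_trg/embedding:0']
```

Because `re.match` is not anchored at the end, a pattern such as `encoder_` also freezes any variable whose name merely begins with that prefix; append `$` or use `re.fullmatch` semantics in your own patterns if an exact match is intended.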
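When `max_gradient_norm` is set, `get_update_op` passes the gradients through `tf.clip_by_global_norm`. According to the TensorFlow documentation, that function computes the global norm as the square root of the sum of squared L2 norms of all gradient tensors, multiplies every tensor by `clip_norm / max(global_norm, clip_norm)`, and ignores `None` entries, so the direction of the combined gradient is preserved. The pure-Python analogue below reproduces that arithmetic on flat lists of floats; representing tensors as lists is an assumption made so the sketch runs without TensorFlow.

```python
import math
from typing import List, Optional, Sequence, Tuple

Vector = List[float]


def clip_by_global_norm(grads: Sequence[Optional[Vector]],
                        clip_norm: float) -> Tuple[List[Optional[Vector]], float]:
    """Pure-Python analogue of tf.clip_by_global_norm for flat float lists.

    Every gradient is multiplied by clip_norm / max(global_norm, clip_norm), so
    nothing changes when global_norm <= clip_norm. None entries pass through.
    clip_norm is assumed to be > 0. Returns (clipped, original global norm).
    """
    global_norm = math.sqrt(sum(x * x for g in grads if g is not None for x in g))
    scale = clip_norm / max(global_norm, clip_norm)
    clipped = [None if g is None else [x * scale for x in g] for g in grads]
    return clipped, global_norm


# Two gradient "tensors" plus one missing gradient:
# global norm = sqrt(3^2 + 4^2 + 12^2) = sqrt(169) = 13.
grads = [[3.0, 4.0], [12.0], None]
clipped, norm = clip_by_global_norm(grads, clip_norm=6.5)
print(norm)     # 13.0
print(clipped)  # [[1.5, 2.0], [6.0], None]
```

Scaling all gradients by one shared factor, rather than clipping each tensor separately, keeps their relative sizes intact; after clipping, the global norm in this example is exactly 6.5.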