optimizer.py source code

python

Project: relaax  Author: deeplearninc
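# Method of a gradients/optimizer graph node; tf, utils and graph are module-level imports.
def build_graph(self, weights, loss=None, optimizer=None, norm=False, batch_size=None, grad_ys=None):
    if loss is not None:
        # Differentiate the loss with respect to the flattened list of weight tensors.
        gradients = tf.gradients(loss.node, list(utils.Utils.flatten(weights.node)), grad_ys)
        # Fail fast if any gradient contains NaN or Inf.
        gradients = [tf.check_numerics(g, 'gradient_%d' % i) for i, g in enumerate(gradients)]
        if batch_size is not None:
            # Average the gradients over the batch.
            gradients = [g / float(batch_size) for g in gradients]

        # Store the global norm of the gradients before clipping.
        self.global_norm = tf.global_norm(gradients)

        # Clip only after the global norm has been stored; norm doubles as the clipping threshold.
        if norm:
            gradients, _ = tf.clip_by_global_norm(gradients, norm)
        # Restore the nested structure of the weights for the computed gradients.
        self.calculate = graph.TfNode(utils.Utils.reconstruct(gradients, weights.node))
    if optimizer is not None:
        # Placeholders shaped like the weights, used to feed gradients in from outside.
        self.ph_gradients = graph.Placeholders(weights)
        # Pair each fed gradient with its weight and apply the optimizer update.
        self.apply = graph.TfNode(optimizer.node.apply_gradients(
            utils.Utils.izip(self.ph_gradients.checked, weights.node)))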
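
The arithmetic behind the batch averaging and global-norm clipping steps can be sketched in plain Python, without TensorFlow. This is only an illustration: the helper names `global_norm` and `clip_by_global_norm` below are stand-ins written for this example (not the relaax or TensorFlow implementations), and gradients are modelled as simple lists of floats. The scaling rule follows the documented behaviour of `tf.clip_by_global_norm`, which multiplies every gradient by `clip_norm / max(global_norm, clip_norm)`.

```python
import math


def global_norm(grads):
    """L2 norm over all gradient entries, treated as one long vector."""
    return math.sqrt(sum(x * x for g in grads for x in g))


def clip_by_global_norm(grads, clip_norm):
    """Rescale all gradients jointly so their global norm is at most clip_norm.

    Returns the clipped gradients and the norm measured before clipping.
    """
    norm = global_norm(grads)
    # The scale is 1.0 when the norm is already within the threshold.
    scale = clip_norm / max(norm, clip_norm)
    return [[x * scale for x in g] for g in grads], norm


# Hypothetical gradients for two weight tensors (illustrative values only).
grads = [[3.0, 0.0], [0.0, 4.0]]
batch_size = 2

# Average over the batch, as build_graph does when batch_size is given.
grads = [[x / float(batch_size) for x in g] for g in grads]

# The norm is recorded before clipping, mirroring self.global_norm.
clipped, norm_before = clip_by_global_norm(grads, 1.0)
```

Because every gradient is multiplied by the same factor, the direction of the overall update is preserved and only its magnitude is bounded, which is why the pre-clipping norm is still useful as a diagnostic.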