test_var_clip_opt.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def _setupSparse(self, is_distributed, dtype):
        with self._maybeWithDevice("/job:ps" if is_distributed else None):
            var0 = tf.Variable(
                [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]], dtype=dtype)
            var1 = tf.Variable(
                [[0.0, 1.0], [0.0, 3.0], [0.0, 5.0]], dtype=dtype)
        with self._maybeWithDevice("/job:worker" if is_distributed else None):
            grads = tf.IndexedSlices(
                tf.constant(
                    [[0.1, 0.1], [0.1, 0.1]], dtype=dtype), [0, 2], [3, 2])
            sgd = tf.train.GradientDescentOptimizer(3.0)
            clip_opt = VariableClippingOptimizer(
                sgd, {var0: [1],
                      var1: [0]}, 2.0)
            update_op = clip_opt.apply_gradients(
                list(zip([grads, grads], [var0, var1])))
            tf.global_variables_initializer().run()
        return var0, var1, update_op
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号