def _setupSparse(self, is_distributed, dtype):
    """Build two variables and a sparse, clipped gradient-descent update op.

    Args:
        is_distributed: When truthy, the variables are created under the
            "/job:ps" device and the gradient/optimizer ops under
            "/job:worker"; otherwise None is handed to _maybeWithDevice.
        dtype: dtype used for both variables and for the gradient values.

    Returns:
        A (var0, var1, update_op) tuple: the two 3x2 variables and the op
        returned by VariableClippingOptimizer.apply_gradients.
    """
    ps_device = "/job:ps" if is_distributed else None
    worker_device = "/job:worker" if is_distributed else None

    with self._maybeWithDevice(ps_device):
        first_var = tf.Variable(
            [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]], dtype=dtype)
        second_var = tf.Variable(
            [[0.0, 1.0], [0.0, 3.0], [0.0, 5.0]], dtype=dtype)

    with self._maybeWithDevice(worker_device):
        # Sparse gradient: values for rows 0 and 2 of a [3, 2] dense shape.
        grad_values = tf.constant([[0.1, 0.1], [0.1, 0.1]], dtype=dtype)
        sparse_grad = tf.IndexedSlices(grad_values, [0, 2], [3, 2])

        base_optimizer = tf.train.GradientDescentOptimizer(3.0)
        # Per-variable dimension lists and the 2.0 bound are passed through
        # as-is; presumably "dims to clip over" and "max norm" -- TODO confirm
        # against VariableClippingOptimizer's signature.
        clip_dims = {first_var: [1], second_var: [0]}
        clipping_optimizer = VariableClippingOptimizer(
            base_optimizer, clip_dims, 2.0)

        # The same sparse gradient is applied to both variables.
        grads_and_vars = [(sparse_grad, first_var), (sparse_grad, second_var)]
        update_op = clipping_optimizer.apply_gradients(grads_and_vars)

        # NOTE(review): the source's indentation was lost; the initializer is
        # assumed to be created and run inside the worker scope -- confirm.
        tf.global_variables_initializer().run()

    return first_var, second_var, update_op
评论列表
文章目录