def testGradWrtRef(self):
    """Check gradients taken through a variable's ``_ref()`` for complex dtypes.

    For each dtype under test, this method does the following:

    * builds two single-element variables;
    * asks ``ctf.train.CplxGradientDescentOptimizer`` to compute the gradient
      of ``vars_[0]._ref() + vars_[1]`` with respect to both variables;
    * asserts that each gradient evaluates to ``[1.0]``.

    Each variable is an addend of a plain sum, so its partial derivative
    is 1.
    """
    for dtype in [tf.complex64]:
        with self.test_session(force_gpu=True):
            lr = 3.0 - 1.5j
            opt = ctf.train.CplxGradientDescentOptimizer(lr)
            # Use genuinely complex initial values so the complex64 variables
            # carry non-zero imaginary parts. A later real-valued reassignment
            # used to discard these, leaving the complex path untested.
            values = [1.0 + 2.0j, 2.0 + 1.0j]
            vars_ = [tf.Variable([v], dtype=dtype) for v in values]
            # Per the test name, the first addend goes through the variable's
            # private ``_ref()`` form. The point is to check that gradients
            # still reach the variable through it.
            grads_and_vars = opt.compute_gradients(
                vars_[0]._ref() + vars_[1], vars_)
            tf.global_variables_initializer().run()
            for grad, _ in grads_and_vars:
                self.assertAllCloseAccordingToType([1.0], grad.eval())
评论列表
文章目录