cplx_gradient_descent_test.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:complex_tf 作者: woodshop 项目源码 文件源码
def testGradWrtRef(self):
    for dtype in [tf.complex64]:
      with self.test_session(force_gpu=True):
        values = [1.0+2.0j, 2.0+1.0j]
        lr = 3.0-1.5j 
        opt = ctf.train.CplxGradientDescentOptimizer(lr)
        values = [1.0, 3.0]
        vars_ = [tf.Variable([v], dtype=dtype) for v in values]
        grads_and_vars = opt.compute_gradients(
          vars_[0]._ref() + vars_[1], vars_)
        tf.global_variables_initializer().run()
        for grad, _ in grads_and_vars:
          self.assertAllCloseAccordingToType([1.0], grad.eval())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号