cplx_gradient_descent_test.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:complex_tf 作者: woodshop 项目源码 文件源码
def testWithGlobalStep(self):
    """Run one complex-valued SGD step and check variables and global_step.

    Two complex variables receive constant complex gradients through
    ``CplxGradientDescentOptimizer.apply_gradients`` with a complex learning
    rate. After a single run of the returned op, each variable is expected to
    equal ``initial - learning_rate * gradient`` element-wise, and the global
    step counter is expected to equal 1.
    """
    for dtype in (tf.complex64,):
      with self.test_session(force_gpu=True):
        # The step counter is created under an explicit CPU device scope even
        # though the session itself is opened with force_gpu=True.
        # NOTE(review): presumably the integer counter cannot be placed on
        # the GPU -- confirm.
        with tf.device('/cpu'):
          global_step = tf.Variable(0, trainable=False)

        # Plain-Python reference data: initial values, gradients, step size.
        init0, init1 = [1.0+2.0j, 2.0+1.0j], [3.0-4.0j, 4.0-3.0j]
        delta0, delta1 = [0.1+0.1j, 0.1-0.1j], [0.01-0.01j, 0.01+0.01j]
        learning_rate = 3.0-1.5j

        # Graph construction (variables first, then the gradient constants).
        var0 = tf.Variable(init0, dtype=dtype)
        var1 = tf.Variable(init1, dtype=dtype)
        grads0 = tf.constant(delta0, dtype=dtype)
        grads1 = tf.constant(delta1, dtype=dtype)
        optimizer = ctf.train.CplxGradientDescentOptimizer(learning_rate)
        sgd_op = optimizer.apply_gradients(
            [(grads0, var0), (grads1, var1)],
            global_step=global_step)

        init_op = tf.global_variables_initializer()
        init_op.run()

        # Before any update the variables must hold their initial values.
        self.assertAllCloseAccordingToType(init0, var0.eval())
        self.assertAllCloseAccordingToType(init1, var1.eval())

        # Apply exactly one update.
        sgd_op.run()

        # Expected result of one step, computed with Python complex numbers.
        expected0 = [v - learning_rate * g for v, g in zip(init0, delta0)]
        expected1 = [v - learning_rate * g for v, g in zip(init1, delta1)]
        self.assertAllCloseAccordingToType(expected0, var0.eval())
        self.assertAllCloseAccordingToType(expected1, var1.eval())

        # The optimizer op must also have bumped the step counter once.
        self.assertAllCloseAccordingToType(1, global_step.eval())

  ### Currently no support for sparse complex tensors
  # def testSparseBasic(self):
  #   for dtype in [tf.complex64]:
  #     with self.test_session(force_gpu=True):
  #       var0 = tf.Variable([[1.0], [2.0]], dtype=dtype)
  #       var1 = tf.Variable([[3.0], [4.0]], dtype=dtype)
  #       grads0 = tf.IndexedSlices(
  #           tf.constant([0.1], shape=[1, 1], dtype=dtype),
  #           tf.constant([0]),
  #           tf.constant([2, 1]))
  #       grads1 = tf.IndexedSlices(
  #           tf.constant([0.01], shape=[1, 1], dtype=dtype),
  #           tf.constant([1]),
  #           tf.constant([2, 1]))
  #       sgd_op = ctf.train.CplxGradientDescentOptimizer(3.0).apply_gradients(
  #           zip([grads0, grads1], [var0, var1]))
  #       tf.initialize_all_variables().run()
  #       # Fetch params to validate initial values
  #       self.assertAllCloseAccordingToType([[1.0], [2.0]], var0.eval())
  #       self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval())
  #       # Run 1 step of sgd
  #       sgd_op.run()
  #       # Validate updated params
  #       self.assertAllCloseAccordingToType(
  #           [[1.0 - 3.0 * 0.1], [2.0]], var0.eval())
  #       self.assertAllCloseAccordingToType(
  #           [[3.0], [4.0 - 3.0 * 0.01]], var1.eval())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号