cplx_momentum_test.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:complex_tf 作者: woodshop 项目源码 文件源码
def testNesterovMomentum(self):
    """Compare Nesterov-mode CplxMomentumOptimizer updates with a NumPy reference.

    Two complex variables are minimized against the cost
    ``5 * var0 * var0 + 3 * var1`` for four optimizer steps. After every
    step the TensorFlow variables are checked against values produced by
    ``self._update_nesterov_momentum_numpy``.
    """
    # Hyper-parameters shared by the optimizer and the NumPy reference.
    learning_rate = 2.0
    momentum = 0.9
    num_steps = 4

    for dtype in (tf.complex64,):
      with self.test_session():
        np_dtype = dtype.as_numpy_dtype

        # Graph side: variables, cost and the training op.
        var0 = tf.Variable([1.0, 2.0], dtype=dtype)
        var1 = tf.Variable([3.0, 4.0], dtype=dtype)
        cost = 5 * var0 * var0 + 3 * var1
        global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
        optimizer = ctf.train.CplxMomentumOptimizer(
            learning_rate=learning_rate,
            momentum=momentum,
            use_nesterov=True)
        train_op = optimizer.minimize(cost, global_step, [var0, var1])

        # NumPy side: mirrored variable values and zeroed accumulators.
        expected0 = np.array([1.0, 2.0], dtype=np_dtype)
        expected1 = np.array([3.0, 4.0], dtype=np_dtype)
        velocity0 = np.zeros(2, dtype=np_dtype)
        velocity1 = np.zeros(2, dtype=np_dtype)

        tf.global_variables_initializer().run()
        for _ in range(num_steps):
          train_op.run()
          # Third argument is presumably the gradient of the cost w.r.t. the
          # variable (10 * var0 and the constant 3) -- confirm against the
          # helper's signature.
          expected0, velocity0 = self._update_nesterov_momentum_numpy(
              expected0, velocity0, expected0 * 10, learning_rate, momentum)
          expected1, velocity1 = self._update_nesterov_momentum_numpy(
              expected1, velocity1, 3, learning_rate, momentum)
          self.assertAllClose(expected0, var0.eval())
          self.assertAllClose(expected1, var1.eval())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号