def testNesterovMomentum(self):
    """Check CplxMomentumOptimizer(use_nesterov=True) against a NumPy reference.

    Minimizes ``cost = 5 * var0 * var0 + 3 * var1`` for ``num_steps`` steps
    and, after every step, compares the TensorFlow variables with values
    produced by ``self._update_nesterov_momentum_numpy`` run with the same
    learning rate and momentum.
    """
    # Single source of truth for the hyperparameters: the optimizer and the
    # NumPy reference must use identical values, otherwise the comparison
    # silently stops testing what it claims to test.
    learning_rate = 2.0
    momentum = 0.9
    num_steps = 4
    for dtype in [tf.complex64]:
        with self.test_session():
            var0 = tf.Variable([1.0, 2.0], dtype=dtype)
            var1 = tf.Variable([3.0, 4.0], dtype=dtype)
            np_dtype = dtype.as_numpy_dtype
            var0_np = np.array([1.0, 2.0], dtype=np_dtype)
            var1_np = np.array([3.0, 4.0], dtype=np_dtype)
            # Momentum accumulators start at zero, mirroring the optimizer's
            # freshly-created slots.
            accum0_np = np.zeros_like(var0_np)
            accum1_np = np.zeros_like(var1_np)
            cost = 5 * var0 * var0 + 3 * var1
            global_step = tf.Variable(
                tf.zeros([], tf.int64), name='global_step')
            mom_op = ctf.train.CplxMomentumOptimizer(
                learning_rate=learning_rate,
                momentum=momentum,
                use_nesterov=True)
            opt_op = mom_op.minimize(cost, global_step, [var0, var1])
            tf.global_variables_initializer().run()
            # NOTE(review): all initial values have zero imaginary part, so the
            # gradients below stay real and this test cannot detect errors in
            # how imaginary components (or any gradient conjugation) are
            # handled. Consider adding a case with non-zero imaginary parts.
            for _ in range(num_steps):
                opt_op.run()
                # Analytic gradients of cost, evaluated at the pre-update
                # values: d(5*v*v)/dv = 10*v and d(3*v)/dv = 3.
                var0_np, accum0_np = self._update_nesterov_momentum_numpy(
                    var0_np, accum0_np, var0_np * 10, learning_rate, momentum)
                var1_np, accum1_np = self._update_nesterov_momentum_numpy(
                    var1_np, accum1_np, 3, learning_rate, momentum)
                self.assertAllClose(var0_np, var0.eval())
                self.assertAllClose(var1_np, var1.eval())
评论列表
文章目录