def testWithGlobalStep(self):
    """Check one complex SGD step together with the global_step counter.

    Two complex variables receive constant complex gradients through
    ``ctf.train.CplxGradientDescentOptimizer`` with a complex learning
    rate. After a single run of the update op each variable must equal
    ``initial - rate * gradient`` element-wise and ``global_step`` must
    evaluate to 1.
    """
    # NOTE(review): the pasted source lost its indentation; this layout
    # assumes only ``global_step`` lives inside the '/cpu' device scope
    # while everything else runs under the force_gpu session -- confirm.
    for dtype in [tf.complex64]:
        with self.test_session(force_gpu=True):
            with tf.device('/cpu'):
                global_step = tf.Variable(0, trainable=False)

            # Plain Python complex values used both to build the graph
            # and to compute the expected results below.
            init0 = [1.0+2.0j, 2.0+1.0j]
            init1 = [3.0-4.0j, 4.0-3.0j]
            grad_vals0 = [0.1+0.1j, 0.1-0.1j]
            grad_vals1 = [0.01-0.01j, 0.01+0.01j]
            rate = 3.0-1.5j

            var0 = tf.Variable(init0, dtype=dtype)
            var1 = tf.Variable(init1, dtype=dtype)
            grads0 = tf.constant(grad_vals0, dtype=dtype)
            grads1 = tf.constant(grad_vals1, dtype=dtype)

            optimizer = ctf.train.CplxGradientDescentOptimizer(rate)
            grads_and_vars = zip([grads0, grads1], [var0, var1])
            sgd_op = optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)

            tf.global_variables_initializer().run()

            # The variables must hold their initial values before the step.
            self.assertAllCloseAccordingToType(init0, var0.eval())
            self.assertAllCloseAccordingToType(init1, var1.eval())

            # Apply exactly one gradient-descent update.
            sgd_op.run()

            # Each element moves by rate * gradient; the counter advances.
            expected0 = [v - rate * g for v, g in zip(init0, grad_vals0)]
            expected1 = [v - rate * g for v, g in zip(init1, grad_vals1)]
            self.assertAllCloseAccordingToType(expected0, var0.eval())
            self.assertAllCloseAccordingToType(expected1, var1.eval())
            self.assertAllCloseAccordingToType(1, global_step.eval())
### Sparse complex tensors are not supported yet, so the sparse test below stays disabled.
# def testSparseBasic(self):
# for dtype in [tf.complex64]:
# with self.test_session(force_gpu=True):
# var0 = tf.Variable([[1.0], [2.0]], dtype=dtype)
# var1 = tf.Variable([[3.0], [4.0]], dtype=dtype)
# grads0 = tf.IndexedSlices(
# tf.constant([0.1], shape=[1, 1], dtype=dtype),
# tf.constant([0]),
# tf.constant([2, 1]))
# grads1 = tf.IndexedSlices(
# tf.constant([0.01], shape=[1, 1], dtype=dtype),
# tf.constant([1]),
# tf.constant([2, 1]))
# sgd_op = ctf.train.CplxGradientDescentOptimizer(3.0).apply_gradients(
# zip([grads0, grads1], [var0, var1]))
# tf.initialize_all_variables().run()
# # Fetch params to validate initial values
# self.assertAllCloseAccordingToType([[1.0], [2.0]], var0.eval())
# self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval())
# # Run 1 step of sgd
# sgd_op.run()
# # Validate updated params
# self.assertAllCloseAccordingToType(
# [[1.0 - 3.0 * 0.1], [2.0]], var0.eval())
# self.assertAllCloseAccordingToType(
# [[3.0], [4.0 - 3.0 * 0.01]], var1.eval())
评论列表
文章目录