def learn_comb_orth_rmsprop(poses, dm_shape, reuse=None, _float_type=tf.float32,
                            decay=0.99, epsilon=1.e-5):
    """Mix axis 2 of `poses` with a stored matrix and build its custom update.

    Two non-trainable variables are created (or reused) under the
    "learn_comb" variable scope: "matrix", the combination matrix, and
    "matrix_momentum", a running average of squared gradients. The
    combination matrix is also added to COMB_MATRIX_COLLECTION.

    Args:
        poses: Tensor whose axis 2 is contracted with axis 1 of the
            combination matrix. The transpose below uses a length-4
            permutation, so the contraction result must be rank 4.
        dm_shape: Indexable; `dm_shape[0]` and `dm_shape[1]` give the shape
            of the combination matrix. The update step adds a
            `dm_shape[0]`-sized identity to a product of the gradient and the
            matrix, so it presumably requires a square matrix and a gradient
            of the same shape -- TODO confirm against callers.
        reuse: Forwarded to `tf.variable_scope`.
        _float_type: dtype of both variables and of the identity constant.
        decay: Decay rate of the squared-gradient running average.
        epsilon: Constant added to the running average before the square
            root is taken.

    Returns:
        A tuple `(poses, comb_matrix_image, update_comb_mat)`:
        the transformed poses; the combination matrix min-max scaled to
        0..255, cast to uint8 and reshaped to
        `[1, dm_shape[0], dm_shape[1], 1]`; and a function
        `update_comb_mat(grad, lr)` returning the op that assigns the
        updated matrix.
    """
    rows, cols = dm_shape[0], dm_shape[1]

    with tf.variable_scope("learn_comb", reuse=reuse):
        comb_matrix = tf.get_variable(
            "matrix", [rows, cols],
            initializer=identity_initializer(0),
            dtype=_float_type, trainable=False
        )
        comb_matrix_m = tf.get_variable(
            "matrix_momentum", [rows, cols],
            initializer=tf.zeros_initializer(),
            dtype=_float_type, trainable=False
        )
        tf.add_to_collection(COMB_MATRIX_COLLECTION, comb_matrix)

        # tensordot appends the matrix's free axis last; the transpose swaps
        # it back into position 2, where the contracted axis used to be.
        poses = tf.tensordot(poses, comb_matrix, [[2], [1]])
        poses = tf.transpose(poses, [0, 1, 3, 2])

        def update_comb_mat(grad, lr):
            """Return the op applying one scaled-gradient Cayley-style step.

            Args:
                grad: Gradient with respect to the combination matrix.
                lr: Learning rate multiplying the scaled gradient.

            Returns:
                The `tf.assign` op that writes the updated matrix.
            """
            identity = tf.constant(np.eye(rows), dtype=_float_type)

            # Running average of squared gradients. tf.assign returns a
            # tensor holding the value *after* the assignment; using that
            # tensor below (instead of re-reading the variable) guarantees
            # the fresh accumulator is used. An implicit variable read is
            # not reliably ordered by control_dependencies for ref variables.
            second_moment = tf.assign(
                comb_matrix_m,
                comb_matrix_m * decay + (1 - decay) * tf.square(grad))

            with tf.control_dependencies([second_moment]):
                scaled_grad = lr * grad / tf.sqrt(second_moment + epsilon)

                # skew == -transpose(skew) by construction, so
                # inv(I + skew/2) @ (I - skew/2) is a Cayley transform of a
                # skew-symmetric matrix.
                skew = tf.matmul(tf.transpose(scaled_grad), comb_matrix) - \
                    tf.matmul(tf.transpose(comb_matrix), scaled_grad)
                plus_half = identity + 0.5 * skew
                minus_half = identity - 0.5 * skew
                updated = tf.matmul(
                    tf.matmul(tf.matrix_inverse(plus_half), minus_half),
                    comb_matrix)
                return tf.assign(comb_matrix, updated)

        # Visualization: min-max scale the matrix to 0..255 as a
        # single-channel, batch-of-one uint8 image.
        # NOTE(review): the denominator is zero when every entry of the
        # matrix is equal (e.g. a 1x1 matrix).
        cb_min = tf.reduce_min(comb_matrix)
        cb_max = tf.reduce_max(comb_matrix)
        comb_matrix_image = (comb_matrix - cb_min) / (cb_max - cb_min) * 255.0
        comb_matrix_image = tf.cast(comb_matrix_image, tf.uint8)
        comb_matrix_image = tf.reshape(comb_matrix_image, [1, rows, cols, 1])

        return poses, comb_matrix_image, update_comb_mat
# Comment list
# Table of contents