def learn_comb(poses, dm_shape, batch_size, max_length, n_dims, reuse=None, _float_type=tf.float32):
    """Mix the channels of ``poses`` with a learned, column-normalised matrix.

    A trainable variable ``learn_comb/matrix`` of shape
    ``[dm_shape[0], dm_shape[1]]`` is created (or reused).  Each column is
    divided by its own sum, so every output channel is a combination of the
    ``dm_shape[0]`` input channels whose weights sum to one.

    Args:
        poses: Tensor whose elements are laid out as
            ``[batch_size, max_length, dm_shape[0], n_dims]`` (assumed from
            the transpose/reshape below — TODO confirm against callers).
        dm_shape: Sequence whose first two entries are the number of input
            channels and the number of output channels.
        batch_size: Static batch size used in the reshapes.
        max_length: Static sequence length used in the reshapes.
        n_dims: Size of the per-channel feature axis.
        reuse: Forwarded to ``tf.variable_scope``.
        _float_type: dtype of the combination matrix.

    Returns:
        A tuple ``(poses, comb_matrix_image)``:
        ``poses`` has shape ``[batch_size, max_length, dm_shape[1], n_dims]``.
        ``comb_matrix_image`` is a ``uint8`` tensor of shape
        ``[1, dm_shape[0], dm_shape[1], 1]``.  It holds the normalised matrix
        min-max rescaled to ``[0, 255]``, presumably for an image summary.
    """
    n_in, n_out = dm_shape[0], dm_shape[1]
    with tf.variable_scope("learn_comb", reuse=reuse):
        comb_matrix = tf.get_variable(
            "matrix", [n_in, n_out],
            initializer=identity_initializer(0.01),
            dtype=_float_type, trainable=True
        )
        # Normalise each column (output channel) to sum to one.
        # NOTE(review): a column summing to zero would yield inf/NaN here.
        norm_comb_matrix = comb_matrix / tf.reduce_sum(comb_matrix, axis=0, keep_dims=True)

        # Move the channel axis last so one 2-D matmul mixes channels for
        # every (batch, time, dim) position at once.
        poses = tf.transpose(poses, [0, 1, 3, 2])
        poses = tf.reshape(poses, [batch_size * max_length * n_dims, n_in])
        poses = tf.matmul(poses, norm_comb_matrix)
        # After the matmul the channel axis has n_out entries, not n_in.
        # Using n_in here broke every non-square combination matrix.
        poses = tf.reshape(poses, [batch_size, max_length, n_dims, n_out])
        poses = tf.transpose(poses, [0, 1, 3, 2])

        # Min-max rescale the normalised matrix to [0, 255] for visualisation.
        cb_min = tf.reduce_min(norm_comb_matrix)
        cb_max = tf.reduce_max(norm_comb_matrix)
        value_range = cb_max - cb_min
        # A constant matrix has zero range.  In that case the numerator is
        # also zero, so dividing by one produces an all-zero image instead of
        # NaN (whose uint8 cast is undefined).
        safe_range = tf.where(value_range > 0, value_range, tf.ones_like(value_range))
        comb_matrix_image = (norm_comb_matrix - cb_min) / safe_range * 255.0
        comb_matrix_image = tf.cast(comb_matrix_image, tf.uint8)
        # NHWC layout: a single one-channel image of size n_in x n_out.
        comb_matrix_image = tf.reshape(comb_matrix_image, [1, n_in, n_out, 1])
        return poses, comb_matrix_image
评论列表
文章目录