learn_comb.py source code

Project: DMNN    Author: magnux

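```python
import tensorflow as tf  # TensorFlow 1.x graph API (tf.variable_scope, tf.get_variable)

# identity_initializer is a helper from the DMNN project; its definition is not shown here.


def learn_comb(poses, dm_shape, batch_size, max_length, n_dims, reuse=None, _float_type=tf.float32):
    # poses: [batch_size, max_length, dm_shape[0], n_dims]
    with tf.variable_scope("learn_comb", reuse=reuse):
        # Learnable joint-combination matrix, initialized with identity_initializer(0.01).
        comb_matrix = tf.get_variable(
            "matrix", [dm_shape[0], dm_shape[1]],
            initializer=identity_initializer(0.01),
            dtype=_float_type, trainable=True
        )
        # Normalize each column to sum to 1 (keepdims replaces the deprecated keep_dims).
        norm_comb_matrix = comb_matrix / tf.reduce_sum(comb_matrix, axis=0, keepdims=True)

        # Put the joint (dm_shape[0]) axis last, flatten, and mix joints with one matmul.
        poses = tf.transpose(poses, [0, 1, 3, 2])
        poses = tf.reshape(poses, [batch_size * max_length * n_dims, dm_shape[0]])
        poses = tf.matmul(poses, norm_comb_matrix)
        # The matmul output has dm_shape[1] columns, so reshape with dm_shape[1].
        poses = tf.reshape(poses, [batch_size, max_length, n_dims, dm_shape[1]])
        poses = tf.transpose(poses, [0, 1, 3, 2])
        poses = tf.reshape(poses, [batch_size, max_length, dm_shape[1], n_dims])

        # Min-max rescale the normalized matrix to [0, 255] and pack it as a
        # [1, height, width, 1] uint8 image, e.g. for tf.summary.image.
        cb_min = tf.reduce_min(norm_comb_matrix)
        cb_max = tf.reduce_max(norm_comb_matrix)
        comb_matrix_image = (norm_comb_matrix - cb_min) / (cb_max - cb_min) * 255.0
        comb_matrix_image = tf.cast(comb_matrix_image, tf.uint8)
        comb_matrix_image = tf.reshape(comb_matrix_image, [1, dm_shape[0], dm_shape[1], 1])
        return poses, comb_matrix_image
```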
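For readers without a TensorFlow 1.x setup, the forward pass of `learn_comb` can be re-expressed in NumPy as a minimal sketch. `learn_comb_np` is a hypothetical name, the random inputs are made up, and the near-identity matrix is only an assumption about what `identity_initializer(0.01)` produces. The final check shows that the transpose/reshape/matmul sequence is the same computation as a single `np.einsum` contraction over the joint axis.

```python
import numpy as np


def learn_comb_np(poses, comb_matrix):
    """Hypothetical NumPy re-expression of the learn_comb forward pass.

    poses:       [batch, time, n_joints_in, n_dims]
    comb_matrix: [n_joints_in, n_joints_out]
    Returns (mixed poses [batch, time, n_joints_out, n_dims],
             uint8 image of the normalized matrix [1, n_joints_in, n_joints_out, 1]).
    """
    batch, time, n_in, n_dims = poses.shape
    n_out = comb_matrix.shape[1]

    # Column-normalize so each output joint's weights sum to 1.
    norm = comb_matrix / comb_matrix.sum(axis=0, keepdims=True)

    # Same transpose -> flatten -> matmul -> unflatten -> transpose steps as the TF code.
    x = poses.transpose(0, 1, 3, 2).reshape(batch * time * n_dims, n_in)
    x = x @ norm
    mixed = x.reshape(batch, time, n_dims, n_out).transpose(0, 1, 3, 2)

    # Min-max rescale to [0, 255]; astype truncates toward zero, like tf.cast.
    lo, hi = norm.min(), norm.max()
    image = ((norm - lo) / (hi - lo) * 255.0).astype(np.uint8)
    return mixed, image.reshape(1, n_in, n_out, 1)


rng = np.random.default_rng(0)
poses = rng.normal(size=(2, 5, 4, 3))              # batch=2, time=5, 4 joints, 3-D coordinates
comb = np.eye(4) + 0.01 * rng.normal(size=(4, 4))  # assumed near-identity initialization
mixed, image = learn_comb_np(poses, comb)

# The reshape/matmul sequence equals one einsum over the joint axis.
norm = comb / comb.sum(axis=0, keepdims=True)
assert np.allclose(mixed, np.einsum("btjd,jk->btkd", poses, norm))
```

The einsum form makes the design visible: the matrix mixes joints only and applies the same weights to every coordinate, frame, and batch element, which is why the TF code can flatten those three axes into one before the matmul.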