buildRotations.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:dizzy_layer 作者: Pastromhaug 项目源码 文件源码
def buildRotations(n, rand_or_identity, num_rots=None):
    """Build ``num_rots`` sparse ``[n, n]`` tensors parameterized by learned angles.

    A trainable angle variable ``thetas`` is created under the
    ``Build_Rotations`` variable scope. Each angle is expanded into its
    cos, sin and -sin values. Those values are then scattered into sparse
    tensors using the (indices, value-index) layout returned by
    ``rotationPreprocess`` — presumably a set of Givens-style rotations;
    the exact placement is defined by ``rotationPreprocess``.

    Args:
        n: Side length of each square sparse tensor. Must be greater than 1,
            otherwise the angle-count computation divides by zero.
        rand_or_identity: If truthy, angles are initialized uniformly in
            [0, 2*pi). Otherwise they are initialized to 0.
        num_rots: Number of sparse tensors to build. ``None`` (or any falsy
            value such as 0) falls back to ``n - 1``.

    Returns:
        A list of ``num_rots`` ``tf.SparseTensor`` objects, each with dense
        shape ``[n, n]``.
    """
    # Apply the default BEFORE formatting it: "%d" % None raises TypeError,
    # which previously made the default argument value unusable.
    num_rots = num_rots or (n - 1)
    print("num_rots: %d" % num_rots)

    # Total number of angles: the n*(n-1)/2 pairwise count scaled by
    # num_rots / (n-1). Raises ZeroDivisionError when n == 1.
    n_prime = int(n * (n - 1) // 2 * num_rots / (n - 1))

    # NOTE: tf.neg, axis-first tf.concat/tf.split and SparseTensor(shape=...)
    # are pre-1.0 TensorFlow APIs; this function targets that API surface.
    with vs.variable_scope("Build_Rotations"):
        indices, values_idxs = rotationPreprocess(n, num_rots)

        # Variable names are part of saved checkpoints; keep them unchanged
        # (including "Thetas_OnesInit", even though it is zero-initialized).
        if rand_or_identity:
            print("Initialization: Random")
            thetas = vs.get_variable(
                initializer=tf.random_uniform([n_prime, 1], 0, 2 * math.pi),
                name="Thetas_RandInit", dtype=tf.float32)
        else:
            print("Initialization: Identity")
            thetas = vs.get_variable(
                initializer=tf.zeros([n_prime, 1]),
                name="Thetas_OnesInit", dtype=tf.float32)

        cos = tf.cos(thetas)
        sin = tf.sin(thetas)
        nsin = tf.neg(sin)

        # Stack [cos; sin; -sin] along axis 0 so values_idxs can select any
        # of the three values for any angle with a single gather.
        thetas_concat = tf.concat(0, [cos, sin, nsin])
        gathered_values = tf.squeeze(tf.gather(thetas_concat, values_idxs))

        # Partition the flat value/index lists into one chunk per tensor.
        splt_values = tf.split(0, num_rots, gathered_values)
        splt_indices = tf.split(0, num_rots, indices)

        shape = tf.constant([n, n], dtype=tf.int64)
        outputs = [
            tf.SparseTensor(indices=curr_indices, values=curr_values, shape=shape)
            for curr_indices, curr_values in zip(splt_indices, splt_values)
        ]

    print("buildRotations output length: %d" % len(outputs))
    return outputs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号