rotationTransform.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:dizzy_layer 作者: Pastromhaug 项目源码 文件源码
def rotationTransform(X, n, scope, num_rots=None):
    """Apply a learned chain of sparse rotation-style matrices to each input.

    For every ``(name, x)`` pair in ``X`` a fresh angle variable is created.
    The angles are expanded into ``num_rots`` sparse ``n x n`` matrices whose
    layout comes from ``rotationPreprocess``. Each matrix is then multiplied
    into ``x`` from the left, one after another.

    Args:
        X: iterable of ``(name, x)`` pairs.
            ``name`` is a string suffix for the angle variable's name.
            ``x`` is the dense tensor being transformed. Because an
            ``n x n`` sparse matrix multiplies it from the left, ``x``
            presumably has ``n`` rows (TODO confirm against callers).
        n: size of the transformed dimension.
        scope: variable-scope name. Falls back to ``"RotationTransform"``
            when falsy.
        num_rots: number of sparse matrices applied in sequence. Any falsy
            value (including ``None``) means ``n - 1``.

    Returns:
        list of transformed tensors, in the same order as ``X``.

    Raises:
        ZeroDivisionError: when ``n == 1``, as in the original formula.
    """
    num_rots = num_rots or (n - 1)
    # Angle count: n*(n-1)/2 pairs scaled by num_rots/(n-1).
    # Integer arithmetic matches the previous int(float division) for
    # positive inputs, but avoids float rounding.
    n_prime = n * (n - 1) // 2 * num_rots // (n - 1)
    outputs = []

    with vs.variable_scope(scope or "RotationTransform"):
        # Dense shape shared by every sparse matrix. It does not depend on
        # the input, so it is built once rather than twice per input.
        shape = tf.constant([n, n], dtype=tf.int64)

        for i, (name, x) in enumerate(X):
            # NOTE(review): kept per input because rotationPreprocess is
            # defined elsewhere. Hoist it only if it is confirmed pure
            # (no randomness, no variable creation).
            (indices, values_idxs) = rotationPreprocess(n, num_rots)
            thetas = vs.get_variable(
                initializer=tf.random_uniform([n_prime, 1], 0, 2 * math.pi),
                name="Thetas" + str(i) + name, dtype=tf.float32)

            cos = tf.cos(thetas)
            sin = tf.sin(thetas)
            nsin = tf.neg(sin)

            # Stack [cos; sin; -sin] along axis 0 (legacy concat(dim, values)
            # signature). values_idxs then picks, for each sparse slot,
            # which of these entries to use.
            thetas_concat = tf.concat(0, [cos, sin, nsin])
            gathered_values = tf.squeeze(tf.gather(thetas_concat, values_idxs))

            # Slice values and indices into num_rots equal chunks:
            # one chunk per sparse matrix.
            splt_values = tf.split(0, num_rots, gathered_values)
            splt_indices = tf.split(0, num_rots, indices)

            # Distinct loop names so the outer index `i` is not shadowed.
            for rot_indices, rot_values in zip(splt_indices, splt_values):
                sparse_rot = tf.SparseTensor(indices=rot_indices,
                                             values=rot_values, shape=shape)
                x = tf.sparse_tensor_dense_matmul(sparse_rot, x)
            outputs.append(x)
    return outputs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号