def doRotationsSigmas(X, rotations, num_units):
    """Apply a chain of sparse rotations to ``X`` with a learned scale in the middle.

    A trainable variable ``Sigma`` of shape ``[num_units, 1]``, initialized to
    1.0, is created (or fetched) under the ``Do_Rotations`` variable scope.
    ``X`` is left-multiplied by each sparse matrix in ``rotations`` in order.
    Immediately before the rotation at index ``len(rotations) // 2``, ``X`` is
    multiplied (broadcast) by ``Sigma``. If ``rotations`` is empty, ``X`` is
    only scaled by ``Sigma``.

    Args:
        X: Dense tensor to transform. Presumably shaped ``[num_units, k]`` so
            that ``Sigma`` scales each row — TODO confirm against callers.
        rotations: Python sequence of sparse tensors, each used as the left
            operand of ``tf.sparse_tensor_dense_matmul``.
        num_units: Number of rows of the ``Sigma`` variable.

    Returns:
        Tuple ``(X_out, sigma)``: the transformed tensor and the ``Sigma``
        variable.
    """
    with vs.variable_scope("Do_Rotations"):
        sigma = vs.get_variable(
            "Sigma", [num_units, 1],
            dtype=tf.float32,
            initializer=init_ops.constant_initializer(value=1.0, dtype=tf.float32))
        # Integer midpoint; avoids the float round-trip of int(len / 2).
        sigma_spot = len(rotations) // 2
        for i, sparse_rot in enumerate(rotations):
            if i == sigma_spot:
                X = X * sigma
            X = tf.sparse_tensor_dense_matmul(sparse_rot, X)
        if len(rotations) == 0:
            # With no rotations the loop never reaches sigma_spot; still apply
            # the scale so Sigma actually participates in the output.
            X = X * sigma
    return X, sigma
# Scraped web-page residue (not code): "评论列表" (comment list), "文章目录" (table of contents)