def mul(self, z):
    """Gather entries of ``z`` with ``self.P`` along its last-to-first axis.

    ``tf.gather`` indexes along the first axis, so ``z`` is transposed
    before the lookup and the gathered result is transposed back. For a
    2-D ``z`` and a 1-D ``self.P`` this selects columns of ``z`` in the
    order given by ``self.P``.

    NOTE(review): the comment accompanying this line suggests ``z`` is
    complex[batch_sz, num_units] -- confirm against the callers.

    Args:
        z: Tensor to be indexed.

    Returns:
        The transposed result of gathering ``self.P`` from ``z`` transposed.
    """
    # Move the axis to be indexed to the front so tf.gather can address it.
    z_transposed = tf.transpose(z)
    gathered = tf.gather(z_transposed, self.P)
    # Restore the original axis order.
    return tf.transpose(gathered)

# FFTs
# z: complex[batch_sz, num_units]