def normalize(z, *, axis=None, eps=1e-6):
    """Divide a complex tensor by its (epsilon-stabilized) L2 norm.

    Args:
        z: Complex tensor to normalize.
        axis: Axis or axes over which the L2 norm is computed. The default
            ``None`` reduces over *all* elements of ``z``, giving one global
            norm for the whole tensor (the original behavior). Pass e.g.
            ``axis=-1`` to normalize each row independently.
        eps: Small constant added to the norm so an all-zero input does not
            cause a division by zero.

    Returns:
        A complex tensor with the same shape as ``z``. Its real and imaginary
        parts are each divided by ``norm + eps``.
    """
    # sqrt(sum |z|^2). tf.abs of a complex tensor is real-valued, so the norm
    # is real. keepdims=True lets the norm broadcast back against z when a
    # specific axis is reduced.
    norm = tf.sqrt(tf.reduce_sum(tf.abs(z) ** 2, axis=axis, keepdims=True))
    factor = norm + eps
    # Scale the real and imaginary parts separately by the real-valued factor.
    # tf.math.real / tf.math.imag replace the top-level tf.real / tf.imag
    # aliases, which no longer exist in TensorFlow 2.
    return tf.complex(tf.math.real(z) / factor, tf.math.imag(z) / factor)


# z: complex[batch_sz, num_units]
# bias: real[num_units]