def split_to_complex(x, xshp=None):
    """Pack a real tensor's two halves into a complex tensor.

    The split dimension of ``x`` is cut in half: the first half becomes the
    real part and the second half the imaginary part, i.e. a layout of
    ``[re_0 .. re_{k-1}, im_0 .. im_{k-1}]`` maps to ``re + 1j * im``.

    For a rank-2 input the split is along axis 1 (the second dimension);
    for any other rank the split is along axis 0.

    Args:
        x: Real-valued tensor to split. Its dtype must be one that
            ``tf.complex`` accepts (presumably float32/float64 — confirm
            against callers).
        xshp: Optional static shape of ``x`` as a list of ints. When omitted
            it is read from ``x.get_shape().as_list()``; pass it explicitly
            if the static shape of the split dimension is unknown.

    Returns:
        A complex tensor whose split dimension is half as long as ``x``'s.

    Raises:
        AssertionError: if the split dimension has odd length.
    """
    xshp = x.get_shape().as_list() if xshp is None else xshp
    # Rank-2 inputs are split along the second axis; everything else along
    # the first.
    axis = 1 if len(xshp) == 2 else 0
    size = xshp[axis]
    assert size % 2 == 0, \
        "Vector is not evenly divisible into complex: %d" % size
    # Floor division: under Python 3 ``/`` yields a float, which is not a
    # valid slice index.
    mid = size // 2
    if axis == 1:
        return tf.complex(x[:, :mid], x[:, mid:])
    return tf.complex(x[:mid], x[mid:])
评论列表
文章目录