def op_unitary_loss(s_re_, s_im_, axes_=None, size_=None):
    '''
    Unitary-matrix loss of a complex matrix given by its real/imag parts,
    used to regularize a parameter towards a unitary matrix.

    With S = s_re_ + 1j * s_im_, this returns the mean over all entries of
    |S^H S - I|^2, i.e. the squared real part plus the squared imaginary
    part of (S^H S - I). It is zero exactly when S^H S == I.

    Args:
        s_re_: real part of S, a symbolic tensor with ndim >= 2 whose two
            matrix axes (see axes_) form a square matrix
        s_im_: imag part of S, same ndim as s_re_
        axes_: tuple of two integers, specify which axes to be for matrix,
            defaults to the last two axes. Any remaining axes are treated
            as batch axes and the loss is averaged over them as well.
        size_: specify args to be (size_ x size_) matrices; if None, the
            size is read symbolically from s_re_'s shape along axes_[0]

    Returns:
        symbolic scalar, the mean squared deviation from unitarity

    Raises:
        ValueError: if s_re_ and s_im_ differ in ndim, ndim < 2, or axes_
            does not name two distinct valid axes
    '''
    if axes_ is None:
        axes_ = (-2, -1)
    ndim = s_re_.ndim
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if s_im_.ndim != ndim:
        raise ValueError(
            'real and imag parts must have the same ndim, got %d and %d'
            % (ndim, s_im_.ndim))
    if ndim < 2:
        raise ValueError('expected at least 2 dimensions, got %d' % ndim)
    if len(axes_) != 2 or not all(-ndim <= ax < ndim for ax in axes_):
        raise ValueError(
            'axes_ must be two valid axes for ndim=%d, got %r'
            % (ndim, tuple(axes_)))
    ax0, ax1 = [ax % ndim for ax in axes_]
    if ax0 == ax1:
        raise ValueError('axes_ must name two distinct axes, got %r'
                         % (tuple(axes_),))

    if size_ is None:
        size = T.shape(s_re_)[ax0]
    else:
        size = size_

    if ndim == 2:
        # Plain matrices: T.dot is an ordinary matrix product. In 2-D any
        # pair of distinct axes is a full transpose, so this matches the
        # original computation for every valid axes_.
        re_t = s_re_.T
        im_t = s_im_.T
        s_y_re_ = T.dot(re_t, s_re_) + T.dot(im_t, s_im_)
        s_tmp = T.dot(re_t, s_im_)
        s_y_im_ = s_tmp - s_tmp.T
        eye = T.eye(size)
    else:
        # For ndim > 2, T.dot follows numpy.dot semantics (contracts the
        # last axis of a with the second-to-last of b, forming an outer
        # product over the batch axes), which is NOT a batched matrix
        # product. Instead move the matrix axes last, flatten the batch
        # axes into one, and use T.batched_dot. Because the result is
        # averaged over every entry, it never needs to be reshaped back.
        perm = [i for i in range(ndim) if i not in (ax0, ax1)] + [ax0, ax1]

        def _to_batch(x):
            # (batch..., rows, cols) -> (B, rows, cols)
            x = x.dimshuffle(*perm)
            return x.reshape((-1, x.shape[-2], x.shape[-1]), ndim=3)

        s_re3 = _to_batch(s_re_)
        s_im3 = _to_batch(s_im_)
        re_t = s_re3.dimshuffle(0, 2, 1)
        im_t = s_im3.dimshuffle(0, 2, 1)
        s_y_re_ = T.batched_dot(re_t, s_re3) + T.batched_dot(im_t, s_im3)
        s_tmp = T.batched_dot(re_t, s_im3)
        s_y_im_ = s_tmp - s_tmp.dimshuffle(0, 2, 1)
        eye = T.eye(size).dimshuffle('x', 0, 1)

    # Real part of S^H S is R^T R + I^T I; imaginary part is R^T I - I^T R.
    return T.mean(T.sqr(s_y_re_ - eye) + T.sqr(s_y_im_))
评论列表
文章目录