def op_ortho_loss(s_x_, axes_=(-2, -1), ndim_=None):
    '''
    Orthogonal matrix loss.

    Used to regularize a parameter towards being orthogonal (unitary) by
    computing the element-wise squared residual ``(X^T X - I) ** 2``.
    The result is NOT reduced to a scalar; callers sum/mean it as needed.

    Args:
        s_x_: symbolic (batch of) square matrices
        axes_: tuple of two integers, specifying which axes hold the
            matrix; defaults to the last two axes
        ndim_: if given, the matrices are taken to be (ndim_ x ndim_);
            otherwise the size is read symbolically from
            ``s_x_.shape[axes_[0]]``

    Returns:
        symbolic tensor of element-wise squared deviations of X^T X
        from the identity
    '''
    if ndim_ is None:
        ndim = T.shape(s_x_)[axes_[0]]
    else:
        ndim = ndim_
    # The transpose pattern must list every axis of the tensor (its rank,
    # s_x_.ndim), NOT the matrix size; then swap the two matrix axes.
    tpat = list(range(s_x_.ndim))
    tpat[axes_[0]], tpat[axes_[1]] = tpat[axes_[1]], tpat[axes_[0]]
    # Broadcast pattern that lifts the 2-D identity to s_x_'s rank, placing
    # its two axes on the matrix axes and broadcasting ('x') everywhere else.
    bpat = ['x'] * s_x_.ndim
    bpat[axes_[0]] = 0
    bpat[axes_[1]] = 1
    # NOTE(review): T.dot is an ordinary matrix product for 2-D inputs. For
    # higher-rank inputs it sums over the last axis of the first argument and
    # the second-to-last axis of the second (numpy.dot semantics), which is
    # not a per-batch product -- confirm the batched case before relying on it.
    s_y = T.dot(s_x_.transpose(*tpat), s_x_)
    return T.sqr(s_y - T.eye(ndim).dimshuffle(*bpat))
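# 评论列表 ("comment list") -- scraped web-page residue, not code
# 文章目录 ("table of contents") -- scraped web-page residue, not code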