ops.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:dnc-theano 作者: khaotik 项目源码 文件源码
def op_ortho_loss(s_x_, axes_=(-2, -1), ndim_=None):
    '''
    orthogoal matrix loss
    used to regularize parameter to unitary

    Args:
        s_x_: (batch of) matrices
        axes_: tuple of two integers, specify which axes to be for matrix,
            defaults to last two axes
        ndim_: specify args to be (ndim_ x ndim_) matrices

    '''
    if ndim_ is None:
        ax = axes_[0]
        ndim = T.shape(s_x_)[ax]
    else:
        ndim = ndim_

    tpat = list(range(ndim))
    bpat = ['x'] * s_x_.ndim
    tpat[axes_[0]], tpat[axes_[1]] = tpat[axes_[1]], tpat[axes_[0]]
    bpat[axes_[0]] = 0
    bpat[axes_[1]] = 1
    s_y = T.dot(s_x_.transpose(*tpat), s_x_)
    return T.sqr(s_y - T.eye(ndim).dimshuffle(*bpat))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号