ops.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目: dnc-theano 作者: khaotik 项目源码 文件源码
def op_unitary_loss(s_re_, s_im_, axes_=None, size_=None):
    '''
    Unitary matrix loss of a complex matrix given as real/imag parts,
    used to regularize a parameter towards being unitary.

    With U = s_re_ + i * s_im_, this measures how far U^H U is from the
    identity:

        U^H U = (re^T re + im^T im) + i * (re^T im - (re^T im)^T)

    and returns the mean, over every residual entry, of
    (Re(U^H U) - I)^2 + Im(U^H U)^2.

    Args:
        s_re_: real part, square matrix (symbolic tensor)
        s_im_: imag part, square matrix; must have the same ndim as s_re_
        axes_: tuple of two integers, specify which axes to be for matrix,
            defaults to last two axes. Any other axes are treated as batch
            axes; every matrix in the batch is penalized independently.
        size_: specify args to be (size_ x size_) matrices; defaults to the
            symbolic length of axis axes_[0] of s_re_

    Returns:
        Symbolic scalar loss (mean over all residual entries).

    Raises:
        ValueError: if s_re_ and s_im_ differ in ndim, if ndim < 2, or if
            axes_ does not name two distinct in-range axes.
    '''
    if axes_ is None:
        axes_ = (-2, -1)

    ndim = s_re_.ndim
    if s_im_.ndim != ndim:
        raise ValueError(
            'real and imag parts must have the same ndim, got %d and %d'
            % (ndim, s_im_.ndim))
    if ndim < 2:
        raise ValueError('expected at least 2 dimensions, got %d' % ndim)

    ax0, ax1 = axes_
    if not (-ndim <= ax0 < ndim and -ndim <= ax1 < ndim):
        raise ValueError(
            'axes_ %r out of range for ndim=%d' % (tuple(axes_), ndim))
    # Keep the matrix axes in their original relative order so that the
    # 2-D case computes exactly the same graph as before (identity perm).
    row_ax, col_ax = sorted((ax0 % ndim, ax1 % ndim))
    if row_ax == col_ax:
        raise ValueError(
            'axes_ must name two distinct axes, got %r' % (tuple(axes_),))

    size = T.shape(s_re_)[axes_[0]] if size_ is None else size_

    # Move the matrix axes to the end; everything else becomes batch axes.
    batch_axes = [ax for ax in range(ndim) if ax not in (row_ax, col_ax)]
    perm = batch_axes + [row_ax, col_ax]
    s_re = s_re_.dimshuffle(*perm)
    s_im = s_im_.dimshuffle(*perm)

    # Pattern that transposes the last two axes, leaving batch axes alone.
    swap_last = list(range(ndim - 2)) + [ndim - 1, ndim - 2]

    def _xtx(a, b):
        """Return a^T b over the last two axes, batched over leading axes."""
        a_t = a.dimshuffle(*swap_last)
        if ndim == 2:
            return T.dot(a_t, b)
        # T.dot on N-d tensors is a tensordot (contracts a's last axis with
        # b's second-to-last), NOT a batched matmul, so flatten the batch
        # axes and use batched_dot instead.
        a_shp = T.shape(a_t)
        b_shp = T.shape(b)
        a3 = a_t.reshape((-1, a_shp[ndim - 2], a_shp[ndim - 1]), ndim=3)
        b3 = b.reshape((-1, b_shp[ndim - 2], b_shp[ndim - 1]), ndim=3)
        prod = T.batched_dot(a3, b3)
        out_shape = ([b_shp[i] for i in range(ndim - 2)]
                     + [a_shp[ndim - 2], b_shp[ndim - 1]])
        return prod.reshape(out_shape, ndim=ndim)

    s_y_re = _xtx(s_re, s_re) + _xtx(s_im, s_im)
    s_tmp = _xtx(s_re, s_im)
    s_y_im = s_tmp - s_tmp.dimshuffle(*swap_last)

    # Broadcast the identity across all batch axes.
    eye = T.eye(size).dimshuffle(*(['x'] * (ndim - 2) + [0, 1]))
    return T.mean(T.sqr(s_y_re - eye) + T.sqr(s_y_im))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号