utils.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:hred-latent-piecewise 作者: julianser 项目源码 文件源码
def BatchedDot(x, y, last_axis=False):
    if last_axis==False:
        return T.batched_dot(x, y)
    elif last_axis:
        if x.ndim == 2:
            shuffled_x = x.dimshuffle(1,0)
        elif x.ndim == 3:
            shuffled_x = x.dimshuffle(2,0,1)
        elif x.ndim == 4:
            shuffled_x = x.dimshuffle(3,0,1,2)
        else:
            raise ValueError('BatchedDot inputs must have between 2-4 dimensions, but x has ' + str(x.ndim) + ' dimensions')

        if y.ndim == 2:
            shuffled_y = y.dimshuffle(1,0)
        elif y.ndim == 3:
            shuffled_y = y.dimshuffle(2,0,1)
        elif y.ndim == 4:
            shuffled_y = y.dimshuffle(3,0,1,2)
        else:
            raise ValueError('BatchedDot inputs must have between 2-4 dimensions, but y has ' + str(y.ndim) + ' dimensions')

        dot = T.batched_dot(shuffled_x, shuffled_y)
        if dot.ndim == 2:
            return dot.dimshuffle(1,0)
        elif dot.ndim == 3:
            return dot.dimshuffle(1,2,0)
        elif dot.ndim == 4:
            return dot.dimshuffle(1,2,3,0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号