helper.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:social-lstm-pytorch 作者: vvanirudh 项目源码 文件源码
def getCoef(outputs):
    '''
    Extracts the mean, standard deviation and correlation
    params:
    outputs : Output of the SRNN model
    '''
    mux, muy, sx, sy, corr = outputs[:, :, 0], outputs[:, :, 1], outputs[:, :, 2], outputs[:, :, 3], outputs[:, :, 4]

    sx = torch.exp(sx)
    sy = torch.exp(sy)
    corr = torch.tanh(corr)
    return mux, muy, sx, sy, corr
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号