helper.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:srnn-pytorch 作者: vvanirudh 项目源码 文件源码
def getCoef(outputs):
    '''
    Extracts the mean, standard deviation and correlation
    params:
    outputs : Output of the SRNN model
    '''
    mux, muy, sx, sy, corr = outputs[:, :, 0], outputs[:, :, 1], outputs[:, :, 2], outputs[:, :, 3], outputs[:, :, 4]

    # Exponential to get a positive value for std dev
    sx = torch.exp(sx)
    sy = torch.exp(sy)
    # tanh to get a value between [-1, 1] for correlation
    corr = torch.tanh(corr)

    return mux, muy, sx, sy, corr
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号