def getCoef(outputs):
    """Extract the mean, standard deviation and correlation from SRNN outputs.

    The last dimension of ``outputs`` is read as five channels, in order:
    ``mux``, ``muy``, raw ``sx``, raw ``sy``, raw ``corr``. Presumably these
    parameterise a 2-D Gaussian over (x, y); confirm against the model.

    params:
        outputs : Output of the SRNN model. Any tensor whose last dimension
            has at least 5 entries is accepted; the original 3-D
            ``(dim0, dim1, >=5)`` layout behaves exactly as before.

    returns:
        Tuple ``(mux, muy, sx, sy, corr)``, each shaped like
        ``outputs[..., 0]``:
        * ``mux``, ``muy`` -- channels 0 and 1, returned unchanged.
        * ``sx``, ``sy`` -- ``exp`` of channels 2 and 3, so they are strictly
          positive (the raw channels act as log-standard-deviations).
        * ``corr`` -- ``tanh`` of channel 4, so it lies in (-1, 1).
    """
    # Index the trailing axis with an Ellipsis so any number of leading
    # dimensions works. For 3-D input this is identical to ``[:, :, k]``.
    mux = outputs[..., 0]
    muy = outputs[..., 1]
    # exp maps unconstrained network outputs to positive standard deviations.
    sx = torch.exp(outputs[..., 2])
    sy = torch.exp(outputs[..., 3])
    # tanh squashes the raw value into the valid correlation range (-1, 1).
    corr = torch.tanh(outputs[..., 4])
    return mux, muy, sx, sy, corr
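# Scraped web-page navigation labels, kept as a comment so the file parses:
# "Comment list", "Table of contents"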