def getCoef(outputs):
    '''
    Extracts the means, standard deviations and correlation from the
    raw SRNN output.

    The parameters are read from the first five entries of the LAST
    dimension of ``outputs`` (any further entries are ignored):
    index 0 -> mux, 1 -> muy, 2 -> sx, 3 -> sy, 4 -> corr.

    params:
    outputs : Output tensor of the SRNN model. The original code assumed
              a 3-D tensor; any rank >= 1 is accepted, and each returned
              tensor has the shape of ``outputs`` without its last dim.

    returns:
    (mux, muy, sx, sy, corr) where
      mux, muy : means, returned unchanged
      sx, sy   : std devs, passed through exp so they are positive
      corr     : correlation, passed through tanh so it is in [-1, 1]
    '''
    # Index the last axis with ``...`` so inputs of any rank work; for a
    # 3-D tensor this is exactly ``outputs[:, :, k]``.
    mux = outputs[..., 0]
    muy = outputs[..., 1]
    # Exponential to get a positive value for std dev
    sx = torch.exp(outputs[..., 2])
    sy = torch.exp(outputs[..., 3])
    # tanh to get a value between [-1, 1] for correlation
    corr = torch.tanh(outputs[..., 4])
    return mux, muy, sx, sy, corr
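# Web-page residue from the scraped source, kept as a comment so the file parses:
# "评论列表" (comment list) and "文章目录" (table of contents); not part of the code.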