def stable_softmax(y_hat, axis=2):
    """Calculate softmax and log-softmax in a numerically stable way.

    The maximum along ``axis`` is subtracted before exponentiating. Softmax
    is invariant to this shift, but it keeps every exponent <= 0 so ``exp``
    cannot overflow. The log-softmax is formed directly as
    ``shifted - log(sum(exp(shifted)))`` rather than as ``log(softmax)``,
    which avoids taking ``log`` of probabilities that underflowed to 0.

    Parameters
    ----------
    y_hat : tensor3 (input_seq_len, num_batch, num_classes+1)
        Class energies.
    axis : int, optional
        Axis to normalize over. Defaults to 2, the class axis of the 3-D
        layout above; pass ``-1`` to normalize the last axis of an input of
        any rank.

    Returns
    -------
    y_hat_softmax : tensor
        Softmax probabilities, same shape as ``y_hat``.
    log_y_hat_softmax : tensor
        Log of the softmax probabilities, same shape as ``y_hat``.
    """
    # NOTE(review): `T` is imported elsewhere in this file — presumably
    # `theano.tensor`; confirm.
    # Shift by the max along `axis` so the largest exponent is exp(0) = 1.
    y_hat_safe = y_hat - y_hat.max(axis=axis, keepdims=True)
    y_hat_safe_exp = T.exp(y_hat_safe)
    # For finite inputs the sum contains exp(0) = 1, so it is >= 1 and its
    # log is finite.
    y_hat_safe_normalizer = y_hat_safe_exp.sum(axis=axis, keepdims=True)
    log_y_hat_safe_normalizer = T.log(y_hat_safe_normalizer)
    y_hat_softmax = y_hat_safe_exp / y_hat_safe_normalizer
    # Log-softmax stays in the log domain instead of log(y_hat_softmax).
    log_y_hat_softmax = y_hat_safe - log_y_hat_safe_normalizer
    return y_hat_softmax, log_y_hat_softmax
评论列表
文章目录