def safe_exp(w, thresh):
    """Safe exponential function for tensors.

    Element-wise, returns ``exp(w)`` where ``w <= thresh`` and, where
    ``w > thresh``, the tangent line of ``exp`` at ``thresh``::

        exp(thresh) * (w - thresh + 1)

    The two pieces agree in value and slope at ``thresh``. Beyond the
    threshold the output grows only linearly instead of exponentially.

    Args:
        w: Tensor-like of floating-point values; converted with
            ``tf.convert_to_tensor``.
        thresh: Switch-over point. It is passed to ``np.exp``, so it is
            presumably a Python/NumPy scalar (TODO confirm with callers).
            ``np.exp(thresh)`` must be finite in ``w``'s dtype.

    Returns:
        A tensor of ``w``'s dtype holding the element-wise safe exponential.
    """
    # exp(thresh) is both the value and the derivative of exp at thresh.
    slope = np.exp(thresh)
    with tf.variable_scope('safe_exponential'):
        w = tf.convert_to_tensor(w)
        # Build the mask in w's own dtype. tf.to_float always produced
        # float32, which made the arithmetic below fail for float64 inputs.
        lin_region = tf.cast(w > thresh, w.dtype)
        # Clamp each branch's input to the region where it is selected.
        # Otherwise exp(w) overflows to inf for large w (and the linear
        # branch is -inf for w = -inf). The masked-out branch then
        # contributes 0 * inf = NaN to both the output and its gradient.
        # Clamping leaves the selected branch's value unchanged.
        lin_out = slope * (tf.maximum(w, thresh) - thresh + 1.)
        exp_out = tf.exp(tf.minimum(w, thresh))
        out = lin_region * lin_out + (1. - lin_region) * exp_out
    return out
# [scraped page label] 评论列表 (Comment list)
# [scraped page label] 文章目录 (Table of contents)