def clipped_error(x):
# Huber loss
try:
return tf.select(tf.abs(x) < 1.0, 0.5 * tf.square(x), tf.abs(x) - 0.5)
except:
return tf.where(tf.abs(x) < 1.0, 0.5 * tf.square(x), tf.abs(x) - 0.5)
# return 0.5 * tf.square(x)
评论列表
文章目录