def warp_loss(tf_prediction, tf_y, **kwargs):
# TODO JK: implement WARP loss
tf_positive_mask = tf.greater(tf_y, 0.0)
tf_negative_mask = tf.less_equal(tf_y, 0.0)
tf_positive_predictions = tf.boolean_mask(tf_prediction, tf_positive_mask) # noqa
tf_negative_predictions = tf.boolean_mask(tf_prediction, tf_negative_mask) # noqa
评论列表
文章目录