def _cosine_distance(M, k):
    """Return the cosine *similarity* between ``M`` and the key ``k``.

    Despite its name, this does not compute a distance. Both inputs are
    L2-normalized along their last axis and then combined with
    ``K.batch_dot``. The result is a normalized dot product, so larger
    values mean "more similar". The original comment identifies this as
    equation (6) of the model being implemented.

    Args:
        M: Tensor compared against ``k``. NOTE(review): presumably the
            memory matrix shaped (batch, slots, width); TODO confirm.
        k: Key tensor. NOTE(review): presumably shaped (batch, width);
            TODO confirm.

    Returns:
        ``K.batch_dot`` of the normalized ``M`` and ``k`` (default axes).
    """
    # Normalize each input first instead of dividing the dot product by the
    # product of norms afterwards; this is better conditioned.
    # NOTE(review): the original author saw NaNs come out of this function.
    # With the TensorFlow backend, l2_normalize clamps the squared norm by a
    # small epsilon, so all-zero rows should not cause a division by zero.
    # Any NaNs more likely originate from NaN/inf already present in M or k.
    # Confirm against the backend in use before adding NaN handling here.
    normed_key = K.l2_normalize(k, axis=-1)
    normed_memory = K.l2_normalize(M, axis=-1)
    return K.batch_dot(normed_memory, normed_key)
评论列表
文章目录