ntm.py source code

python

Project: ntm_keras  Author: flomlo
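from keras import backend as K  # assumed import; not shown in this excerpt


def _cosine_distance(M, k):
    # This is equation (6), or as I like to call it: the NaN factory.
    # Note: despite the name, the dot product of L2-normalized vectors
    # computed here is the cosine similarity.
    # TODO: Find it in a library (keras cosine loss?)
    # Normalize first, as that is better conditioned.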
    nk = K.l2_normalize(k, axis=-1)
    nM = K.l2_normalize(M, axis=-1)
    cosine_distance = K.batch_dot(nM, nk)
    # TODO: Handle NaNs properly. Earlier attempt, kept for reference:
    #cosine_distance_error_handling = tf.Print(cosine_distance, [cosine_distance], message="NaN occurred in _cosine_distance")
    #cosine_distance_error_handling = K.ones(cosine_distance_error_handling.shape)
    #cosine_distance = tf.case({K.any(tf.is_nan(cosine_distance)) : (lambda: cosine_distance_error_handling)},
    #        default = lambda: cosine_distance, strict=True)
    return cosine_distance
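
The comments above call this function a NaN factory but do not say where the NaNs come from. One common source in a cosine similarity is a zero-norm vector, which gives 0/0. The following is a minimal standard-library sketch of that failure mode and one possible guard. It is not the project's code: the names cosine_similarity and memory_similarities and the eps parameter are hypothetical, and clamping the denominator is an assumption of this sketch.

```python
import math


def cosine_similarity(u, v, eps=1e-12):
    """Cosine similarity of two equal-length vectors, guarded against zero norms."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    # Assumption of this sketch: clamp the denominator so that an all-zero
    # vector gives 0.0 rather than the NaN of 0/0.
    return dot / max(norm_u * norm_v, eps)


def memory_similarities(M, k, eps=1e-12):
    """Similarity of key k to every row of memory M (a list of rows)."""
    return [cosine_similarity(row, k, eps) for row in M]


# Hypothetical memory with one all-zero row, the case that would give 0/0.
M = [[1.0, 0.0], [0.0, 0.0], [1.0, 1.0]]
k = [1.0, 0.0]
sims = memory_similarities(M, k)
```

With this guard the all-zero row contributes a similarity of 0.0 and the result stays finite. Whether that is the right fallback for the memory addressing in this project is a design choice the source leaves open.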