validation_script.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:devise-keras 作者: priyamtejaswin 项目源码 文件源码
def custom_for_keras(self, ALL_word_embeds):
        """Return a metric function that closes over ``ALL_word_embeds``.

        The returned ``top_accuracy`` takes two arguments,
        ``(true_word_indices, image_vectors)``, and hands its result to
        ``metrics.top_k_categorical_accuracy``.

        NOTE(review): ``K`` and ``metrics`` are not defined in this snippet;
        presumably they are ``keras.backend`` and ``keras.metrics`` imported
        at module level -- confirm.

        Args:
            ALL_word_embeds: Tensor of word embeddings captured by the
                closure. It is L2-normalised along axis 1 on every call, so
                presumably it is (n_words, embed_dim) -- TODO confirm.

        Returns:
            The nested ``top_accuracy`` function.
        """
        ## only the top 20 rows from word_vectors is legit!
        def top_accuracy(true_word_indices, image_vectors):
            """Compare normalised image vectors against normalised word vectors.

            NOTE(review): this body looks unfinished; see the inline notes
            below before relying on the value it returns.

            Args:
                true_word_indices: Not read anywhere in this body.
                image_vectors: Tensor that is L2-normalised along axis 1;
                    presumably (batch, embed_dim) -- TODO confirm.

            Returns:
                Whatever ``metrics.top_k_categorical_accuracy`` returns for
                ``(tiled_images, diff)``.
            """
            # l2: Euclidean norm along `axis`, keeping the reduced axis.
            l2 = lambda x, axis: K.sqrt(K.sum(K.square(x), axis=axis, keepdims=True))
            # l2norm: divide x by its norm along `axis`; there is no epsilon,
            # so an all-zero slice divides by zero.
            l2norm = lambda x, axis: x/l2(x, axis)

            l2_words = l2norm(ALL_word_embeds, axis=1)
            l2_images = l2norm(image_vectors, axis=1)

            # Insert a new axis at position 1 and repeat along it: 200 copies
            # for the words, 20 for the images. The 200 / 20 constants are
            # hard-coded; presumably batch size and number of valid words
            # (see the "top 20 rows" remark above) -- TODO confirm.
            # NOTE(review): tiled_words is never read after this assignment.
            tiled_words = K.tile(K.expand_dims(l2_words, axis=1) , (1, 200, 1))
            tiled_images = K.tile(K.expand_dims(l2_images, axis=1), (1, 20, 1))

            # NOTE(review): the subtraction uses the untiled l2_words and
            # l2_images, yet the norm is taken over axis=2, which the tiled
            # tensors above would have -- possibly the tiled tensors were
            # intended here; confirm.
            # NOTE(review): K.squeeze is called with no axis argument; the
            # Keras backend documents squeeze(x, axis) -- verify this runs.
            diff = K.squeeze(l2(l2_words - l2_images, axis=2))

            # slice_top3 = lambda x: x[:, 0:3]
            # slice_top1 = lambda x: x[:, 0:1]

            # No k is passed, so the library default applies (the "top5" in
            # the name suggests k=5 -- confirm against the Keras version used).
            # NOTE(review): tiled_images is passed in the first (y_true)
            # position and true_word_indices is unused; also `diff` is a
            # distance, so smaller means closer, while top-k accuracy ranks
            # larger scores first -- confirm the intended arguments.
            diff_top5 = metrics.top_k_categorical_accuracy(tiled_images, diff)
            return diff_top5

        return top_accuracy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号