matchnn.py source code

python

Project: MatchingNetwork    Author: cnichkawde
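import tensorflow as tf


def call(self, inputs):
    """Compute Matching Network predictions from cosine-similarity attention.

    `inputs` is a list laid out as:
      inputs[0 .. n-1]: embeddings of the n support-set members, each (batch, dim)
      inputs[-2]:       embedding of the target (query) image, (batch, dim)
      inputs[-1]:       one-hot labels of the support set, (batch, n, nway)

    This is a method of a layer class (presumably a Keras Layer subclass);
    self.eps and self.nway are set in that class's __init__, not shown here.
    """
    similarities = []

    targetembedding = inputs[-2]  # embedding of the query image
    numsupportset = len(inputs) - 2

    # Reciprocal of the query's L2 norm. It does not depend on the support
    # member, so it is computed once instead of inside the loop.
    sum_query = tf.reduce_sum(tf.square(targetembedding), 1, keepdims=True)
    querymagnitude = tf.math.rsqrt(tf.clip_by_value(sum_query, self.eps, float("inf")))

    for ii in range(numsupportset):
        supportembedding = inputs[ii]  # embedding of the i-th member of the support set

        # Reciprocal of the support member's L2 norm; clipping at eps avoids division by zero
        sum_support = tf.reduce_sum(tf.square(supportembedding), 1, keepdims=True)
        supportmagnitude = tf.math.rsqrt(tf.clip_by_value(sum_support, self.eps, float("inf")))

        # Batched dot product: (batch, 1, dim) x (batch, dim, 1) -> (batch, 1, 1) -> (batch, 1)
        dot_product = tf.matmul(tf.expand_dims(targetembedding, 1), tf.expand_dims(supportembedding, 2))
        dot_product = tf.squeeze(dot_product, [1])

        cosine_similarity = dot_product * supportmagnitude * querymagnitude
        similarities.append(cosine_similarity)

    # (batch, n) similarities -> attention weights over the support set
    similarities = tf.concat(axis=1, values=similarities)
    softmax_similarities = tf.nn.softmax(similarities)

    # Attention-weighted sum of the support labels:
    # (batch, 1, n) x (batch, n, nway) -> (batch, 1, nway) -> (batch, nway).
    # Squeeze only axis 1 so that a batch of size 1 is not squeezed away as well.
    preds = tf.squeeze(tf.matmul(tf.expand_dims(softmax_similarities, 1), inputs[-1]), [1])

    preds.set_shape((inputs[0].shape[0], self.nway))

    return preds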
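The layer computes preds = sum_i softmax_i(cos(q, s_i)) * y_i, where q is the query embedding, s_i the i-th support embedding and y_i its one-hot label. The sketch below is a minimal, framework-free version of that computation for a single (unbatched) example in plain Python, meant only to make the math easy to follow. The function names `cosine_similarity`, `softmax` and `match_predict` are hypothetical and not part of matchnn.py, and `eps=1e-10` is an assumed default because the real value of `self.eps` is set outside this snippet.

```python
import math


def cosine_similarity(u, v, eps=1e-10):
    """Cosine similarity with each squared norm clipped below at eps,
    mirroring the clip_by_value + rsqrt guard in call(). eps is assumed."""
    dot = sum(a * b for a, b in zip(u, v))
    inv_u = 1.0 / math.sqrt(max(sum(a * a for a in u), eps))
    inv_v = 1.0 / math.sqrt(max(sum(b * b for b in v), eps))
    return dot * inv_u * inv_v


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def match_predict(support, query, support_labels, eps=1e-10):
    """Hypothetical unbatched analogue of call().

    support:        list of n embedding vectors
    query:          one embedding vector
    support_labels: n one-hot label vectors, each of length nway
    Returns an nway-length list: softmax attention over the support set
    applied to the support labels.
    """
    sims = [cosine_similarity(query, s, eps) for s in support]
    weights = softmax(sims)
    nway = len(support_labels[0])
    return [sum(w * y[k] for w, y in zip(weights, support_labels)) for k in range(nway)]


support = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
labels = [[1, 0], [0, 1], [0, 1]]
query = [2.0, 0.1]
preds = match_predict(support, query, labels)
print(preds)
```

Because the prediction is a weighted sum of labels, it is a probability distribution over the nway classes. In this toy input the query is most similar to the single class-0 member, yet class 1 receives slightly more total weight (about 0.54 vs 0.46) because two support members carry that label.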