matchnn.py source code

python

Project: MatchingNetwork  Author: cnichkawde
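def call(self, inputs):
    """
    `inputs` is a list holding the support-set embeddings first, then the
    target embedding as the second-to-last element, and the one-hot labels
    of the support set, shaped (batch, numsupportset, nway), as the last
    element. Returns the attention-weighted label distribution for the target.
    """
    # Assumes `import tensorflow as tf` at module level and that self.nway
    # (number of classes per episode) is set in the layer's __init__.
    similarities = []

    targetembedding = inputs[-2]
    numsupportset = len(inputs) - 2
    for ii in range(numsupportset):
        supportembedding = inputs[ii]
        # Negative Euclidean distance: closer support items get higher scores.
        dd = tf.negative(tf.sqrt(tf.reduce_sum(
            tf.square(supportembedding - targetembedding), 1, keepdims=True)))
        similarities.append(dd)

    # (batch, numsupportset): one similarity score per support item.
    similarities = tf.concat(axis=1, values=similarities)
    softmax_similarities = tf.nn.softmax(similarities)
    # (batch, 1, numsupportset) x (batch, numsupportset, nway) -> (batch, 1, nway).
    # Squeeze only axis 1 so that a batch of size 1 keeps its batch dimension.
    preds = tf.squeeze(tf.matmul(tf.expand_dims(softmax_similarities, 1), inputs[-1]), axis=1)

    preds.set_shape((inputs[0].shape[0], self.nway))

    return preds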
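To make the attention step concrete without a TensorFlow dependency, here is a minimal pure-Python sketch of the same computation for a single target (no batch dimension): score each support item by negative Euclidean distance, turn the scores into weights with a softmax, and take the weighted sum of the support items' one-hot labels. The helper names `euclidean`, `softmax` and `match_predict` are hypothetical and introduced only for illustration; they are not part of the MatchingNetwork project.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def euclidean(a: Vector, b: Vector) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def match_predict(support: Sequence[Vector], target: Vector,
                  support_labels: Sequence[Vector]) -> List[float]:
    """Single-example sketch of call(): attention over the support set.

    support        -- one embedding per support item (hypothetical helper)
    target         -- embedding of the target item
    support_labels -- one-hot label row per support item (length nway)
    """
    # Negative distance as similarity, mirroring the TensorFlow code above.
    similarities = [-euclidean(s, target) for s in support]
    weights = softmax(similarities)
    nway = len(support_labels[0])
    # Weighted sum of one-hot rows == (1 x n_support) @ (n_support x nway).
    return [sum(w * label[k] for w, label in zip(weights, support_labels))
            for k in range(nway)]


if __name__ == "__main__":
    support = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
    labels = [[1, 0], [1, 0], [0, 1]]
    print(match_predict(support, [0.5, 0.5], labels))
```

With this toy support set the target sits between the two class-0 items and far from the class-0/class-1 boundary, so class 0 receives nearly all of the probability mass, and the outputs sum to 1 because the softmax weights do.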