MIL.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:image-text-matching 作者: llltttppp 项目源码 文件源码
def top_K_loss(self, sentence, image, K=30, margin=0.3):
        """Build a bidirectional top-K hinge loss over pairwise squared distances.

        For every sentence/image pair the squared Euclidean distance
        ``|s|^2 - 2 * s.i + |i|^2`` is computed. The diagonal of that matrix
        is used as the "positive" distance, and for each row (sentence side)
        and each column (image side) the ``K`` smallest off-diagonal
        distances are used as negatives in
        ``relu(positive - negative + margin)``.

        Args:
            sentence: sentence embeddings. Assumed to be a 2-D tensor whose
                row ``i`` corresponds to row ``i`` of ``image`` -- TODO
                confirm against the caller.
            image: image embeddings with the same embedding width as
                ``sentence`` (required by the matmul); assumed to have the
                same number of rows so the distance matrix is square.
            K: number of closest negatives kept per row/column. Must be a
                Python int (it is used to replicate a list) and must not
                exceed the batch size, since ``tf.nn.top_k`` needs
                ``k <= last dimension``.
            margin: hinge margin added before the ReLU.

        Returns:
            A tuple ``(sentence_loss, image_loss)`` of scalar tensors, each
            the sum of the corresponding hinge terms.

        Side effects:
            Sets ``self.d_neg`` and ``self.d_pos`` (mean negative / positive
            distance) and stores several intermediate tensors in
            ``self.endpoint`` under ``debug/...`` keys.
        """
        sim_matrix = tf.matmul(sentence, image, transpose_b=True)

        # Squared norms of each row, used to expand |s - i|^2.
        s_square = tf.reduce_sum(tf.square(sentence), axis=1)
        im_square = tf.reduce_sum(tf.square(image), axis=1)

        # d[a, b] = |sentence[a]|^2 - 2 * sentence[a].image[b] + |image[b]|^2
        d = tf.reshape(s_square, [-1, 1]) - 2 * sim_matrix + tf.reshape(im_square, [1, -1])

        # Distance of each matching pair, repeated K times so it lines up
        # with the [batch, K] negatives selected below.
        positive = tf.stack([tf.matrix_diag_part(d)] * K, 1)

        # Overwrite the diagonal with a large distance (100) so the matching
        # pair is pushed out of the smallest-distance selection; this
        # presumes real distances stay well below 100. The fill vector takes
        # d's dtype so non-float32 embeddings do not hit a dtype mismatch.
        length = tf.shape(d)[-1]
        d = tf.matrix_set_diag(d, 100 * tf.ones([length], dtype=d.dtype))

        # top_k on -d picks the K smallest distances; the returned values are
        # the NEGATED distances of those negatives.
        sen_loss_K, _ = tf.nn.top_k(-d, K, sorted=False)
        im_loss_K, _ = tf.nn.top_k(tf.transpose(-d), K, sorted=False)

        # relu(positive - negative + margin), written with the negated values.
        sentence_center_loss = tf.nn.relu(sen_loss_K + positive + margin)
        image_center_loss = tf.nn.relu(im_loss_K + positive + margin)

        # Summary statistics: mean selected-negative and mean positive distance.
        self.d_neg = tf.reduce_mean(-sen_loss_K - im_loss_K) / 2
        self.d_pos = tf.reduce_mean(positive)

        self.endpoint['debug/sentence_center_loss'] = sentence_center_loss
        self.endpoint['debug/image_center_loss'] = image_center_loss
        self.endpoint['debug/sim_matrix'] = sim_matrix
        self.endpoint['debug/sen_loss_K'] = -sen_loss_K
        self.endpoint['debug/image_loss_K'] = -im_loss_K
        # Note: this is the distance matrix AFTER the diagonal was overwritten.
        self.endpoint['debug/distance'] = d
        self.endpoint['debug/positive'] = positive

        return tf.reduce_sum(sentence_center_loss), tf.reduce_sum(image_center_loss)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号