def top_K_loss(self, sentence, image, K=30, margin=0.3):
    """Bidirectional top-K hardest-negative hinge loss for matched pairs.

    Row i of ``sentence`` is treated as the positive match for row i of
    ``image``: the positives are the diagonal of the pairwise squared
    Euclidean distance matrix. Every other row of the opposite modality
    is a candidate negative. For each sentence, the K closest
    non-matching images are selected; likewise, for each image, the K
    closest non-matching sentences. Each selected negative contributes
    ``relu(d_pos - d_neg + margin)``.

    Args:
        sentence: embedding tensor, presumably 2-D ``[n, dim]`` -- TODO confirm.
        image: embedding tensor with the same layout as ``sentence``. The
            diagonal-as-positive logic assumes as many images as sentences
            (a square distance matrix).
        K: number of hardest negatives per anchor. It is clamped to
            ``n - 1``, the number of negatives that actually exist, so that
            small batches do not fail in ``tf.nn.top_k``. The clamp also
            stops the masked positive from being picked as a negative.
        margin: hinge margin added to ``d_pos - d_neg``.

    Returns:
        ``(sentence_loss, image_loss)``: the sums of the hinge terms
        anchored on sentences and on images respectively.

    Side effects:
        Sets ``self.d_neg`` (mean selected-negative distance, averaged over
        both directions; NaN if the batch has a single pair and thus no
        negatives) and ``self.d_pos`` (mean positive distance). Also stores
        intermediate tensors under ``'debug/...'`` keys in ``self.endpoint``.
    """
    sim_matrix = tf.matmul(sentence, image, transpose_b=True)
    s_square = tf.reduce_sum(tf.square(sentence), axis=1)
    im_square = tf.reduce_sum(tf.square(image), axis=1)
    # Squared distances via ||s||^2 - 2 s.i + ||i||^2. Floating-point
    # cancellation can make this expansion slightly negative, so clamp at 0.
    d = (tf.reshape(s_square, [-1, 1])
         - 2 * sim_matrix
         + tf.reshape(im_square, [1, -1]))
    d = tf.maximum(d, 0.0)

    diag_d = tf.matrix_diag_part(d)
    n = tf.shape(d)[-1]
    # Never request more negatives than exist (n - 1 per anchor).
    k = tf.minimum(K, n - 1)
    # Positive distance of each anchor, repeated once per selected negative.
    positive = tf.tile(tf.expand_dims(diag_d, 1), tf.stack([1, k]))

    # Exclude the positives from the negative search by overwriting the
    # diagonal with a value strictly larger than every off-diagonal distance.
    # The historical sentinel 100 is kept unless the distances exceed it.
    # ones_like keeps the mask in d's dtype; tf.ones() would force float32.
    mask_value = tf.maximum(tf.stop_gradient(tf.reduce_max(d)) + 1.0, 100.0)
    d = tf.matrix_set_diag(d, tf.ones_like(diag_d) * mask_value)

    # top_k on -d selects the k smallest distances (the hardest negatives).
    # Row-wise: sentence anchors; on the transpose: image anchors.
    sen_loss_K, _ = tf.nn.top_k(-d, k, sorted=False)
    im_loss_K, _ = tf.nn.top_k(tf.transpose(-d), k, sorted=False)
    # sen_loss_K / im_loss_K hold -d_neg, so this is relu(d_pos - d_neg + margin).
    sentence_center_loss = tf.nn.relu(sen_loss_K + positive + margin)
    image_center_loss = tf.nn.relu(im_loss_K + positive + margin)

    self.d_neg = tf.reduce_mean(-sen_loss_K - im_loss_K) / 2
    # Mean of the diagonal; equal to the mean of the tiled `positive`.
    self.d_pos = tf.reduce_mean(diag_d)

    self.endpoint['debug/sentence_center_loss'] = sentence_center_loss
    self.endpoint['debug/image_center_loss'] = image_center_loss
    self.endpoint['debug/sim_matrix'] = sim_matrix
    self.endpoint['debug/sen_loss_K'] = -sen_loss_K
    self.endpoint['debug/image_loss_K'] = -im_loss_K
    self.endpoint['debug/distance'] = d
    self.endpoint['debug/positive'] = positive
    return tf.reduce_sum(sentence_center_loss), tf.reduce_sum(image_center_loss)
评论列表
文章目录