def top_K_loss(self, sentence, image, K=50, margin=0.5, img_input_feat=None, text_input_feat=None):
    """Build bidirectional top-K hinge losses between sentence and image embeddings.

    A pairwise squared Euclidean distance matrix ``d`` is computed with
    ``d[i, j] = |sentence[i]|^2 - 2 * sentence[i] . image[j] + |image[j]|^2``.
    The diagonal is taken as the positive-pair distance; for every row
    (sentence -> images) and every column (image -> sentences) the K smallest
    off-diagonal (optionally re-weighted) distances are used as negatives in
    ``relu(positive - negative + margin)``.

    Args:
        sentence: 2-D tensor of sentence embeddings, one row per sample.
        image: 2-D tensor of image embeddings, one row per sample.
            NOTE(review): row ``i`` of ``sentence`` and ``image`` is presumably
            a matching pair and both share the same batch size -- confirm
            against the caller.
        K: number of negatives kept per anchor. ``tf.nn.top_k`` requires this
            to be no larger than the batch size.
        margin: hinge margin added to ``positive - negative``.
        img_input_feat: optional raw image features. When given, images whose
            cosine similarity to the anchor's own image is >= 0.99 have their
            distance multiplied by 8 before the sentence-side top-K selection.
        text_input_feat: optional raw text features. Same treatment for the
            image-side selection, with a cosine threshold of 0.98.

    Returns:
        A tuple ``(sentence_loss, image_loss)`` of scalar tensors: the sums of
        the sentence-anchored and image-anchored hinge terms.

    Side effects:
        Sets ``self.d_neg`` and ``self.d_pos`` and stores intermediate tensors
        in ``self.endpoint`` under ``debug/...`` keys.
    """
    # Keeps the L2 normalisation finite for all-zero feature rows; previously
    # only the text branch had this guard, so zero image features gave NaN.
    eps = 1e-10

    def _near_duplicate_penalty(features, threshold):
        """Return (cosine similarity matrix, per-pair distance multiplier)."""
        unit = features / (tf.norm(features, axis=-1, keep_dims=True) + eps)
        cosine = tf.matmul(unit, unit, transpose_b=True)
        # sign(relu(threshold - cosine)) is 1 where cosine < threshold and 0
        # otherwise, so the multiplier is 1 for distinct pairs and 8 for
        # near-duplicate pairs.
        penalty = 8 - 7 * tf.sign(tf.nn.relu(threshold - cosine))
        return cosine, penalty

    # Pairwise squared Euclidean distances via |a|^2 - 2ab + |b|^2.
    sim_matrix = tf.matmul(sentence, image, transpose_b=True)
    s_square = tf.reduce_sum(tf.square(sentence), axis=1)
    im_square = tf.reduce_sum(tf.square(image), axis=1)
    d = tf.reshape(s_square, [-1, 1]) - 2 * sim_matrix + tf.reshape(im_square, [1, -1])

    # Positive distance of every anchor, repeated K times so it lines up with
    # the K negatives selected below.
    positive = tf.stack([tf.matrix_diag_part(d)] * K, axis=1)

    # Overwrite the diagonal with a constant 8 so that the matching pair is not
    # picked as a negative (presumably chosen to exceed any real distance --
    # TODO confirm the embeddings are normalised upstream).
    length = tf.shape(d)[-1]
    d = tf.matrix_set_diag(d, 8 * tf.ones([length]))
    neg_d = -1.0 * d

    # Sentence anchors: top_k over negated distances == K smallest distances.
    # The selected values are therefore negative (and include the multiplier).
    if img_input_feat is not None:
        S_input_img, img_coeff = _near_duplicate_penalty(img_input_feat, 0.99)
        sen_scores = neg_d * img_coeff
        self.endpoint['debug/S_input_img'] = S_input_img
        self.endpoint['debug/img_coeff'] = img_coeff
    else:
        sen_scores = neg_d
    sen_loss_K, _ = tf.nn.top_k(sen_scores, K, sorted=False)

    # Image anchors: same selection on the transposed matrix.
    if text_input_feat is not None:
        S_input_text, text_coeff = _near_duplicate_penalty(text_input_feat, 0.98)
        im_scores = neg_d * text_coeff
        self.endpoint['debug/S_input_text'] = S_input_text
        self.endpoint['debug/text_coeff'] = text_coeff
    else:
        im_scores = neg_d
    im_loss_K, _ = tf.nn.top_k(tf.transpose(im_scores), K, sorted=False)

    # *_loss_K hold negated distances, so "+" here is positive - negative.
    sentence_center_loss = tf.nn.relu(positive + sen_loss_K + margin)
    image_center_loss = tf.nn.relu(positive + im_loss_K + margin)

    # Mean of the two directions' negative distances, flipped back to >= 0.
    self.d_neg = (sen_loss_K + im_loss_K) / -2.0
    self.d_pos = positive

    self.endpoint['debug/im_loss_topK'] = -1.0 * im_loss_K
    self.endpoint['debug/sen_loss_topK'] = -1.0 * sen_loss_K
    self.endpoint['debug/d_Matrix'] = d
    self.endpoint['debug/positive'] = positive
    self.endpoint['debug/s_center_loss'] = sentence_center_loss
    self.endpoint['debug/i_center_loss'] = image_center_loss
    self.endpoint['debug/S'] = sim_matrix
    self.endpoint['debug/sentence_square'] = s_square
    self.endpoint['debug/image_square'] = im_square

    return tf.reduce_sum(sentence_center_loss), tf.reduce_sum(image_center_loss)
BidirectionNet_4wtfidf.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录