def build_matchnet(self):
    """Build the sentence/image matching network and, when training, its losses.

    Runs the sentence branch on ``self.tfidf_feat`` and the image branch on
    ``self.image_feat`` and stores their outputs as ``self.sentence_fc2`` and
    ``self.image_fc2``.

    When ``self.is_training`` is truthy, it also builds the two directional
    top-K triplet losses (``self.sentence_center_triplet_loss`` and
    ``self.image_center_triplet_loss``). It then builds ``self.total_loss``:
    their sum plus every tensor in the ``REGULARIZATION_LOSSES`` collection
    (also kept as ``self.reg_loss``).

    In both modes it then creates ``self.saver`` plus the variable lists
    ``self.t_var`` (trainable), ``self.g_var`` (global) and ``self.img_var``
    (trainable variables whose name contains ``'image'``).

    Returns:
        None. All results are stored as attributes on ``self``.
    """
    # Both branches are created with reuse=False, so this is presumably the
    # first call that creates their variables. (An alternative sentence branch,
    # self.sentence_concat over the tf-idf and LDA features, was tried earlier.)
    self.sentence_fc2 = self.sentencenet(self.tfidf_feat, reuse=False)
    self.image_fc2 = self.imagenet(
        self.image_feat, skip=self.is_skip, reuse=False)

    if self.is_training:
        # Top-K triplet losses in both directions: one centred on sentences,
        # one centred on images. This replaces an earlier variant that used
        # explicit negative features (self.sentence_feat_neg /
        # self.image_feat_neg) with self.triplet_loss.
        (self.sentence_center_triplet_loss,
         self.image_center_triplet_loss) = self.top_K_loss(
            self.sentence_fc2, self.image_fc2)
        # Whatever has been registered in the regularization collection
        # (presumably weight penalties from the sub-networks); may be empty.
        self.reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        self.total_loss = tf.add_n(
            [self.image_center_triplet_loss,
             self.sentence_center_triplet_loss] + self.reg_loss)

    # Built outside the training branch so an inference graph can also
    # restore checkpoints; keeps at most the 30 most recent checkpoints.
    self.saver = tf.train.Saver(max_to_keep=30)
    self.t_var = tf.trainable_variables()
    self.g_var = tf.global_variables()
    # NOTE(review): selects by substring 'image' in the variable name,
    # presumably the image branch's variable scope. Confirm that no
    # sentence-side variable name also contains 'image'.
    self.img_var = [var for var in self.t_var if 'image' in var.name]
BidirectionNet_4wtfidf.py 文件源码
python
阅读 32
收藏 0
点赞 0
评论 0
评论列表
文章目录