sim_matcher.py source code

python

Project: deepmodels    Author: learningsociety
  def train_model(self, train_anchor_batch, train_pos_batch, train_neg_batch,
                  model_params, train_params):
    """Trains the matcher with a triplet loss on anchor/positive/negative batches."""
    # Get embeddings for all batches in a single forward pass by stacking them
    # along the batch axis (TF >= 1.0 order: values first, then axis).
    all_batch = tf.concat(
        [train_anchor_batch, train_pos_batch, train_neg_batch], axis=0)
    with tf.variable_scope("matcher"):
      all_feats, _ = self.build_model(all_batch, model_params)
      # Split the stacked features back into anchor, positive and negative parts.
      anchor_feats, pos_feats, neg_feats = tf.split(all_feats, 3, axis=0)
    # Compute the triplet loss on the split embeddings.
    triplet_loss = dm_losses.triplet_loss(
        anchor_feats,
        pos_feats,
        neg_feats,
        0.2,
        loss_type=commons.LossType.TRIPLET_L2)
    tf.summary.scalar("losses/triplet_loss", triplet_loss)
    # Run training.
    base_model.train_model_given_loss(triplet_loss, None, train_params)

  # TODO (jiefeng): use proper evaluation for matcher and test.
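The `train_model` method above passes the anchor, positive and negative batches through a single `build_model` call. It concatenates them along the batch axis and then splits the features back into thirds, so all three are embedded with the same weights. The NumPy sketch below illustrates that pattern together with a generic hinge triplet loss.

It is a minimal sketch under stated assumptions. `embed` (a linear projection followed by L2 normalization) is a hypothetical stand-in for `build_model`. `triplet_loss_l2` uses squared L2 distances and treats the snippet's `0.2` as the margin. That is an assumption about what `dm_losses.triplet_loss` with `TRIPLET_L2` computes, not its actual implementation.

```python
import numpy as np


def embed(batch, weights):
    """Hypothetical shared embedding: linear projection + L2 normalization per row."""
    feats = batch @ weights
    norms = np.linalg.norm(feats, axis=1, keepdims=True)
    return feats / np.maximum(norms, 1e-12)  # guard against division by zero


def triplet_loss_l2(anchor, pos, neg, margin=0.2):
    """Hinge triplet loss on squared L2 distances, averaged over the batch.

    Assumption: a generic formulation, not necessarily what
    dm_losses.triplet_loss(..., loss_type=TRIPLET_L2) computes.
    """
    pos_dist = np.sum((anchor - pos) ** 2, axis=1)
    neg_dist = np.sum((anchor - neg) ** 2, axis=1)
    # Penalize triplets where the negative is not at least `margin` farther away.
    return float(np.mean(np.maximum(pos_dist - neg_dist + margin, 0.0)))


def matcher_step(anchor_batch, pos_batch, neg_batch, weights, margin=0.2):
    """Mirrors train_model's forward pass: one shared embedding call for all batches."""
    # Stack along the batch axis so every example goes through the same weights.
    all_batch = np.concatenate([anchor_batch, pos_batch, neg_batch], axis=0)
    all_feats = embed(all_batch, weights)
    # Split back into three equal parts (assumes equal batch sizes).
    anchor_feats, pos_feats, neg_feats = np.split(all_feats, 3, axis=0)
    return triplet_loss_l2(anchor_feats, pos_feats, neg_feats, margin)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    weights = rng.normal(size=(8, 4))        # hypothetical 8-d input, 4-d embedding
    anchor = rng.normal(size=(5, 8))
    pos = anchor + 0.01 * rng.normal(size=(5, 8))  # near-duplicates of the anchors
    neg = rng.normal(size=(5, 8))
    print("triplet loss:", matcher_step(anchor, pos, neg, weights))
```

`embed` works row by row, so stacking the batches produces the same features as embedding each batch separately. The concat/split form simply does it in one call. The split only recovers the original batches when all three have the same size, and the same holds for `tf.split(all_feats, 3, axis=0)` in the original method.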