dm_matcher.py source code

python

Project: deepmodels    Author: learningsociety
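import tensorflow as tf
import tensorflow.contrib.slim as slim  # TF 1.x only; tf.contrib was removed in TF 2.
# data_manager and base_model are deepmodels modules; their imports are not part of this excerpt.

def train(self,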
            train_anchor_batch,
            train_pos_batch,
            train_neg_batch,
            train_params,
            preprocessed=True):
    """Training process of the matcher.

    All three input batches must have the same shape.

    Args:
      train_anchor_batch: anchor batch.
      train_pos_batch: positive batch.
      train_neg_batch: negative batch.
      train_params: commons.TrainTestParams object.
      preprocessed: whether the input data has already been preprocessed.
    """
    self.check_dm_model_exist()
    self.dm_model.use_graph()
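    # Embed all three batches in a single pass through the shared network.
    # TF 1.x signature: tf.concat(values, axis), consistent with tf.split below.
    all_batches = tf.concat(
        [train_anchor_batch, train_pos_batch, train_neg_batch], axis=0)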
    if not preprocessed:
      all_batches = self.dm_model.preprocess(all_batches)
    all_feats, _ = self.build_model(all_batches)
    anchor_feats, pos_feats, neg_feats = tf.split(all_feats, 3, axis=0)
    self.set_key_vars(train_params.restore_scopes_exclude,
                      train_params.train_scopes)
    self.compute_losses(anchor_feats, pos_feats, neg_feats, train_params)
    init_fn = None
    if train_params.fine_tune:
      # self.vars_to_restore is supposed to be set in set_key_vars
      print("[dm_matcher.train: info] Trying to restore variables: {}".format(
          self.vars_to_restore))
      init_fn = slim.assign_from_checkpoint_fn(train_params.custom["model_fn"],
                                               self.vars_to_restore)
    if not train_params.resume_training:
      data_manager.remove_dir(train_params.train_log_dir)
    if train_params.use_regularization:
      regularization_loss = tf.add_n(tf.losses.get_regularization_losses())
      tf.summary.scalar("losses/regularization_loss", regularization_loss)

    total_loss = tf.losses.get_total_loss(
        add_regularization_losses=train_params.use_regularization)
    base_model.train_model_given_loss(
        total_loss, self.vars_to_train, train_params, init_fn=init_fn)

  # TODO(jiefeng): to load weights from file.
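The core of `train` is a shared-weight forward pass: the anchor, positive and negative batches are concatenated along axis 0, embedded by a single `build_model` call, and split back into three equal parts. This is why all batches must have the same shape. `compute_losses` is not shown in this excerpt, so the sketch below assumes a standard triplet margin loss. The names `embed_triplets`, `triplet_loss`, `squared_distance` and `margin` are hypothetical, and plain Python lists stand in for TensorFlow tensors so the sketch runs without TF 1.x.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def embed_triplets(
    embed: Callable[[Sequence[Vector]], List[Vector]],
    anchors: Sequence[Vector],
    positives: Sequence[Vector],
    negatives: Sequence[Vector],
) -> Tuple[List[Vector], List[Vector], List[Vector]]:
    """Concat -> embed once -> split, mirroring tf.concat / build_model / tf.split.

    `embed` is a hypothetical stand-in for self.build_model.
    """
    n = len(anchors)
    if len(positives) != n or len(negatives) != n:
        raise ValueError("anchor, positive and negative batches must have equal size")
    # One call over the stacked batch, so all three branches share weights.
    all_feats = embed(list(anchors) + list(positives) + list(negatives))
    # Split back into three equal parts along the batch axis.
    return all_feats[:n], all_feats[n:2 * n], all_feats[2 * n:]


def squared_distance(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def triplet_loss(
    anchor: Sequence[Vector],
    pos: Sequence[Vector],
    neg: Sequence[Vector],
    margin: float = 1.0,
) -> float:
    """Assumed loss: mean of max(0, d(a, p) - d(a, n) + margin) over the batch."""
    losses = [
        max(0.0, squared_distance(a, p) - squared_distance(a, n) + margin)
        for a, p, n in zip(anchor, pos, neg)
    ]
    return sum(losses) / len(losses)


def double(batch: Sequence[Vector]) -> List[Vector]:
    """Toy 'network' for the demo: scales every input vector by 2."""
    return [[2.0 * x for x in v] for v in batch]


if __name__ == "__main__":
    a, p, n = embed_triplets(double,
                             [[0.0, 0.0], [1.0, 1.0]],
                             [[0.0, 1.0], [1.0, 1.0]],
                             [[3.0, 0.0], [1.0, 1.0]])
    # Only the second triplet violates the margin (loss 1.0), so the mean is 0.5.
    print(triplet_loss(a, p, n))  # → 0.5
```

Embedding the concatenated batch in one call keeps the three branches on identical weights and lets batch-level operations (for example batch normalization) see the whole batch; the price is that the three inputs must match in shape so the split recovers them exactly.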