dm_classifier.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:deepmodels 作者: learningsociety 项目源码 文件源码
def train(self,
            train_input_batch,
            train_label_batch,
            train_params,
            preprocessed=True):
  """Builds the classifier training graph and runs training.

  Builds the model on `train_input_batch`, validates the labels, registers the
  classification losses against one-hot encoded labels, optionally prepares a
  checkpoint-restore function for fine-tuning, and hands the total loss to
  `base_model.train_model_given_loss`.

  Side effect: unless `train_params.resume_training` is set, the directory
  `train_params.train_log_dir` is removed before training starts.

  Args:
    train_input_batch: input batch for training.
    train_label_batch: class ids for training. Expected to be integer ids in
      [0, model_params.cls_num) (NOTE(review): dtype/shape come from the input
      pipeline, e.g. [batch] or [batch, 1] — confirm against the caller).
    train_params: commons.TrainTestParams object.
    preprocessed: whether the training data has already been preprocessed.
      If False, `self.dm_model.preprocess` is applied to the inputs first.

  Raises:
    AssertionError: if an input batch is None or `train_params` is not a
      `commons.TrainTestParams` instance.
  """
  assert train_input_batch is not None, "train input batch is none"
  assert train_label_batch is not None, "train label batch is none"
  assert isinstance(
      train_params,
      commons.TrainTestParams), "train params is not a valid type"
  self.check_dm_model_exist()
  model_params = self.dm_model.net_params
  if not preprocessed:
    train_input_batch = self.dm_model.preprocess(train_input_batch)
  pred_logits, endpoints = self.build_model(train_input_batch)
  self.set_key_vars(train_params.restore_scopes_exclude,
                    train_params.train_scopes)

  # tf.one_hot silently maps out-of-range ids to an all-zero row, so check that
  # every label lies in [0, cls_num). The assert ops only execute in graph mode
  # if something depends on them, hence the control dependency + identity.
  train_label_batch = tf.convert_to_tensor(train_label_batch)
  label_checks = [
      tf.assert_non_negative(
          train_label_batch,
          message="train label batch contains negative class ids"),
      tf.assert_less(
          tf.reduce_max(train_label_batch),
          tf.cast(model_params.cls_num, train_label_batch.dtype),
          message="train label batch contains class ids >= cls_num"),
  ]
  with tf.control_dependencies(label_checks):
    train_label_batch = tf.identity(train_label_batch)

  comp_train_accuracy(pred_logits, train_label_batch)
  onehot_labels = tf.one_hot(
      train_label_batch, model_params.cls_num, on_value=1.0, off_value=0.0)
  # Labels shaped e.g. [batch, 1] produce one-hot tensors of rank 3. Collapse
  # to [batch, cls_num] explicitly: an axis-less tf.squeeze would also drop the
  # batch dimension when batch size is 1 (or the class axis when cls_num is 1).
  if onehot_labels.shape.ndims != 2:
    onehot_labels = tf.reshape(onehot_labels, [-1, model_params.cls_num])
  self.compute_losses(onehot_labels, pred_logits, endpoints)

  init_fn = None
  if train_params.fine_tune and not train_params.resume_training:
    # Fresh fine-tuning run: restore pretrained weights from the checkpoint.
    init_fn = slim.assign_from_checkpoint_fn(train_params.custom["model_fn"],
                                             self.vars_to_restore)
  if not train_params.resume_training:
    # Start from a clean log directory. Deleting it may fail while another
    # process (e.g. a running tensorboard) still holds files in it.
    data_manager.remove_dir(train_params.train_log_dir)

  if train_params.use_regularization:
    # tf.add_n raises on an empty list, so only log the summary when the model
    # actually registered regularization losses.
    regularization_losses = tf.losses.get_regularization_losses()
    if regularization_losses:
      regularization_loss = tf.add_n(regularization_losses)
      tf.summary.scalar("losses/regularization_loss", regularization_loss)

  total_loss = tf.losses.get_total_loss(
      add_regularization_losses=train_params.use_regularization)
  base_model.train_model_given_loss(
      total_loss, self.vars_to_train, train_params, init_fn=init_fn)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号