def train(self,
train_input_batch,
train_label_batch,
train_params,
preprocessed=True):
"""Training process of the classifier.
Args:
train_input_batch: input batch for training.
train_label_batch: class id for training.
train_params: commons.TrainTestParams object.
preprocessed: if train data has been preprocessed.
"""
assert train_input_batch is not None, "train input batch is none"
assert train_label_batch is not None, "train label batch is none"
assert isinstance(
train_params,
commons.TrainTestParams), "train params is not a valid type"
self.check_dm_model_exist()
# self.dm_model.use_graph()
model_params = self.dm_model.net_params
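# Apply the model's own input preprocessing unless the caller already did.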
if not preprocessed:
train_input_batch = self.dm_model.preprocess(train_input_batch)
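# Forward pass: build the network graph to obtain prediction logits and the
# intermediate endpoints that are later passed to the loss computation.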
pred_logits, endpoints = self.build_model(train_input_batch)
self.set_key_vars(train_params.restore_scopes_exclude,
train_params.train_scopes)
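# Compute the training accuracy of the current predictions.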
comp_train_accuracy(pred_logits, train_label_batch)
# Sanity check: labels are assumed 0-indexed here, so every id must be
# smaller than cls_num to match the one-hot depth below (the original
# equality check against cls_num would reject valid 0-indexed labels).
# The assert op only runs when it is a control dependency of downstream ops.
assert_op = tf.assert_less(
    tf.reduce_max(train_label_batch),
    tf.convert_to_tensor(
        model_params.cls_num, dtype=tf.int64))
with tf.control_dependencies([assert_op]):
  onehot_labels = tf.one_hot(
      train_label_batch, model_params.cls_num, on_value=1.0, off_value=0.0)
# onehot_labels = slim.one_hot_encoding(train_label_batch,
#                                       model_params.cls_num)
onehot_labels = tf.squeeze(onehot_labels)
self.compute_losses(onehot_labels, pred_logits, endpoints)
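# When fine-tuning from a pretrained checkpoint (and not resuming a previous
# run), initialize the restorable variables from the given checkpoint file.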
init_fn = None
if train_params.fine_tune and not train_params.resume_training:
init_fn = slim.assign_from_checkpoint_fn(train_params.custom["model_fn"],
self.vars_to_restore)
# Start from a clean log directory. Note: removing the directory can fail
# if a TensorBoard instance still has it open.
if not train_params.resume_training:
data_manager.remove_dir(train_params.train_log_dir)
# display regularization loss.
if train_params.use_regularization:
regularization_loss = tf.add_n(tf.losses.get_regularization_losses())
tf.summary.scalar("losses/regularization_loss", regularization_loss)
total_loss = tf.losses.get_total_loss(
add_regularization_losses=train_params.use_regularization)
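# Hand off to the shared training loop with the total loss and the
# variables selected for training.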
base_model.train_model_given_loss(
total_loss, self.vars_to_train, train_params, init_fn=init_fn)
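# Example usage (a minimal sketch; `clf`, `input_pipeline`, and the
# checkpoint path are hypothetical and depend on the surrounding codebase):
#
#   train_params = commons.TrainTestParams(...)
#   train_params.fine_tune = True
#   train_params.resume_training = False
#   train_params.custom["model_fn"] = "/path/to/pretrained.ckpt"
#   inputs, labels = input_pipeline()  # tensors from your data loader
#   clf.train(inputs, labels, train_params, preprocessed=False)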