# Assumes the usual module-level imports: tensorflow as tf, slim
# (e.g. tf.contrib.slim), plus the project's data_manager and base_model modules.
def train(self,
          train_anchor_batch,
          train_pos_batch,
          train_neg_batch,
          train_params,
          preprocessed=True):
"""Training process of the matcher.
Each input data should have same shape.
Args:
train_anchor_batch: anchor batch.
train_pos_batch: positive batch.
train_neg_batch: negative batch.
train_params: commons.TrainTestParams object.
preprocessed: if data has been preprocessed.
"""
  self.check_dm_model_exist()
  self.dm_model.use_graph()
  # Compute embeddings for all three batches in a single forward pass.
  all_batches = tf.concat(
      [train_anchor_batch, train_pos_batch, train_neg_batch], axis=0)
  if not preprocessed:
    all_batches = self.dm_model.preprocess(all_batches)
  all_feats, _ = self.build_model(all_batches)
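  # Split the stacked embeddings back into anchor/positive/negative features.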
  anchor_feats, pos_feats, neg_feats = tf.split(all_feats, 3, axis=0)
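  # Select which variables to restore and which to train based on the scopes.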
  self.set_key_vars(train_params.restore_scopes_exclude,
                    train_params.train_scopes)
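  # Build the matching losses from the anchor/positive/negative embeddings.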
  self.compute_losses(anchor_feats, pos_feats, neg_feats, train_params)
  init_fn = None
  if train_params.fine_tune:
    # self.vars_to_restore is supposed to be set in set_key_vars.
    print("[dm_matcher.train: info] Trying to restore variables: {}".format(
        self.vars_to_restore))
    init_fn = slim.assign_from_checkpoint_fn(train_params.custom["model_fn"],
                                             self.vars_to_restore)
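  # Start from a clean log directory unless resuming a previous run.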
  if not train_params.resume_training:
    data_manager.remove_dir(train_params.train_log_dir)
  if train_params.use_regularization:
    regularization_loss = tf.add_n(tf.losses.get_regularization_losses())
    tf.summary.scalar("losses/regularization_loss", regularization_loss)
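  # The total loss aggregates all registered losses, optionally including
  # regularization.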
  total_loss = tf.losses.get_total_loss(
      add_regularization_losses=train_params.use_regularization)
  base_model.train_model_given_loss(
      total_loss, self.vars_to_train, train_params, init_fn=init_fn)
  # TODO(jiefeng): load weights from file.
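  # Usage sketch (hypothetical; assumes commons.TrainTestParams exposes the
  # attributes read above and that `matcher` is an instance of this class):
  #   params = commons.TrainTestParams()
  #   params.fine_tune = True
  #   params.custom["model_fn"] = "/path/to/pretrained_checkpoint"
  #   matcher.train(anchor_batch, pos_batch, neg_batch, params,
  #                 preprocessed=False)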