def train_model(self, train_anchor_batch, train_pos_batch, train_neg_batch,
                model_params, train_params):
    """Train the matcher with a triplet loss over anchor/positive/negative batches.

    The three batches are joined into one tensor and passed through
    ``self.build_model`` once, inside the "matcher" variable scope. The
    resulting features are split back into three parts that feed
    ``dm_losses.triplet_loss``. The loss is recorded as a scalar summary and
    handed to ``base_model.train_model_given_loss``.

    Args:
      train_anchor_batch: batch of anchor inputs.
      train_pos_batch: batch of positive inputs.
      train_neg_batch: batch of negative inputs.
      model_params: forwarded unchanged to ``self.build_model``.
      train_params: forwarded unchanged to
        ``base_model.train_model_given_loss``.

    Returns:
      None.
    """
    # NOTE(review): tf.concat / tf.split are called with the dimension as the
    # first argument, which looks like the pre-1.0 TensorFlow signature --
    # confirm against the TensorFlow version this project pins.
    triplet_inputs = [train_anchor_batch, train_pos_batch, train_neg_batch]
    stacked_inputs = tf.concat(0, triplet_inputs)

    # One build_model call for all three batches; its second return value is
    # not needed here.
    with tf.variable_scope("matcher"):
        model_outputs = self.build_model(stacked_inputs, model_params)
        stacked_feats, _ = model_outputs

    # Split into three equal parts along dimension 0.
    # NOTE(review): this only recovers anchor/pos/neg features if the three
    # input batches have the same size along dimension 0 -- verify at callers.
    anchor_feats, pos_feats, neg_feats = tf.split(0, 3, stacked_feats)

    # The 0.2 below is presumably the triplet margin -- confirm against
    # dm_losses.triplet_loss.
    loss_kind = commons.LossType.TRIPLET_L2
    triplet_loss = dm_losses.triplet_loss(
        anchor_feats, pos_feats, neg_feats, 0.2, loss_type=loss_kind)
    tf.scalar_summary("losses/triplet_loss", triplet_loss)

    # Hand the loss over to the shared training routine.
    base_model.train_model_given_loss(triplet_loss, None, train_params)
    # TODO (jiefeng): use proper evaluation for matcher and test.
评论列表
文章目录