def _combine_train(self, all_model_fn_ops, train_op_fn):
    """Merge the per-head training ModelFnOps into a single TRAIN ModelFnOps.

    The heads' losses are reduced to one scalar with `self._loss_combiner`.
    A train op is built from that combined loss via `train_op_fn`, and it is
    grouped with every head's own train op so that all of them run together.

    Args:
      all_model_fn_ops: list of ModelFnOps, one per individual head.
      train_op_fn: Function that takes the combined loss and returns a train
        op. See the `create_model_fn_ops` documentation for more details.

    Returns:
      ModelFnOps in TRAIN mode holding the combined loss and the grouped
      train op.
    """
    # Materialize once so the two projections below see the same sequence.
    head_ops = list(all_model_fn_ops)
    head_losses = [ops.loss for ops in head_ops]
    head_train_ops = [ops.train_op for ops in head_ops]

    combined_loss = self._loss_combiner(head_losses)
    combined_train_op = train_op_fn(combined_loss)

    # Group the combined-loss train op with each head's own train op, so that
    # running the returned op also runs every per-head op.
    grouped_train_op = control_flow_ops.group(
        combined_train_op, *head_train_ops)

    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.TRAIN,
        loss=combined_loss,
        train_op=grouped_train_op,
    )
head.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录