head.py source code

python

Project: DeepLearning_VirtualReality_BigData_Project Author: rashmitripathi
def _combine_train(self, all_model_fn_ops, train_op_fn):
    """Combines list of ModelFnOps for training.

    Args:
      all_model_fn_ops: list of ModelFnOps for the individual heads.
      train_op_fn: Function to create train op. See `create_model_fn_ops`
          documentation for more details.

    Returns:
      ModelFnOps that combines all the heads.
    """
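    # Collect the loss and the train op produced by each individual head.
    losses = []
    additional_train_ops = []
    for m in all_model_fn_ops:
      losses.append(m.loss)
      additional_train_ops.append(m.train_op)
    # Merge the per-head losses into a single loss with the configured combiner.
    loss = self._loss_combiner(losses)

    # Build the train op from the combined loss, then group it with the
    # per-head train ops so that running the result runs all of them.
    train_op = train_op_fn(loss)
    train_op = control_flow_ops.group(train_op, *additional_train_ops)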
    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.TRAIN,
        loss=loss,
        train_op=train_op)
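
The combining pattern above can be tried without TensorFlow. The following is only a rough sketch under stated assumptions: `HeadOps`, `group`, and `combine_train` are hypothetical stand-ins invented for illustration (they are not part of `head.py` or any TensorFlow API), train ops are modeled as plain Python callables, and `sum` is assumed as the loss combiner.

```python
from collections import namedtuple

# Hypothetical stand-in for ModelFnOps; it carries only the two fields used here.
HeadOps = namedtuple("HeadOps", ["loss", "train_op"])


def group(*ops):
    """Hypothetical stand-in for control_flow_ops.group.

    Returns one callable that runs every op. This sketch runs them in
    argument order; the real TensorFlow op makes no ordering promise.
    """
    def run_all():
        for op in ops:
            op()
    return run_all


def combine_train(all_head_ops, train_op_fn, loss_combiner=sum):
    """Mirrors the structure of _combine_train using plain Python values."""
    losses = [m.loss for m in all_head_ops]
    extra_ops = [m.train_op for m in all_head_ops]
    loss = loss_combiner(losses)  # assumed combiner: simple sum
    train_op = group(train_op_fn(loss), *extra_ops)
    return HeadOps(loss=loss, train_op=train_op)


# Usage: two heads whose "train ops" just record that they ran.
log = []
heads = [
    HeadOps(loss=0.5, train_op=lambda: log.append("head_a")),
    HeadOps(loss=1.5, train_op=lambda: log.append("head_b")),
]
combined = combine_train(
    heads, lambda loss: (lambda: log.append("optimize %.1f" % loss)))
combined.train_op()
print(combined.loss, log)
```

The point of the design is that a single optimizer step is built from the combined loss, while any per-head train ops still run alongside it in the same grouped op.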