base.py source code (Python)


Project: lsdc    Author: febert
# Imports this method relies on (as in the surrounding tf.contrib.learn-era
# module, TensorFlow 1.x):
import types

from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib import layers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops


def _get_model_fn(self, model_fn):
    """Backward compatibility way of adding class weight and IS_TRAINING.

    TODO(ipolosukhin): Remove this function after new layers are available.
    Specifically:
     * dropout and batch norm should work via update ops.
     * class weights should be retrieved from weights column or hparams.

    Args:
      model_fn: Core model function.

    Returns:
      Model function.
    """
    def _model_fn(features, targets, mode):
      """Model function."""
      # Record training mode in a graph collection so layers such as dropout
      # and batch norm can look it up.
      ops.get_default_graph().add_to_collection('IS_TRAINING', mode == 'train')
      # Publish class weights under a well-known tensor name for the loss.
      if self.class_weight is not None:
        constant_op.constant(self.class_weight, name='class_weight')
      predictions, loss = model_fn(features, targets)
      # learning_rate and optimizer may be given as callables; resolve them.
      if isinstance(self.learning_rate, types.FunctionType):
        learning_rate = self.learning_rate(contrib_framework.get_global_step())
      else:
        learning_rate = self.learning_rate
      if isinstance(self.optimizer, types.FunctionType):
        optimizer = self.optimizer(learning_rate)
      else:
        optimizer = self.optimizer
      # Build the train op (gradients, clipping, global-step increment).
      train_op = layers.optimize_loss(
          loss,
          contrib_framework.get_global_step(),
          learning_rate=learning_rate,
          optimizer=optimizer,
          clip_gradients=self.clip_gradients)
      return predictions, loss, train_op
    return _model_fn
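
For context, here is a minimal, hypothetical usage sketch (not from the lsdc project): it shows how the wrapper turns a core model_fn that only returns (predictions, loss) into a function that also produces a train_op. The _StubEstimator class, its attribute values, and core_model_fn below are made up for illustration, and the snippet assumes the same TensorFlow 1.x contrib-era APIs imported above.

# Hypothetical usage sketch for _get_model_fn (illustrative names only).
import tensorflow as tf


def core_model_fn(features, targets):
  """Core model: a single linear layer plus mean squared error."""
  predictions = layers.fully_connected(features, 1, activation_fn=None)
  loss = tf.losses.mean_squared_error(targets, predictions)
  return predictions, loss


class _StubEstimator(object):
  """Stand-in exposing the attributes _get_model_fn reads from self."""
  class_weight = None
  learning_rate = 0.1      # may also be a callable of the global step
  optimizer = 'SGD'        # string name, optimizer instance, or callable
  clip_gradients = 5.0
  _get_model_fn = _get_model_fn  # reuse the method shown above


model_fn = _StubEstimator()._get_model_fn(core_model_fn)

with tf.Graph().as_default():
  tf.train.create_global_step()
  features = tf.placeholder(tf.float32, [None, 3])
  targets = tf.placeholder(tf.float32, [None, 1])
  predictions, loss, train_op = model_fn(features, targets, mode='train')

Accepting callables for learning_rate and optimizer is what lets callers plug in decay schedules or custom optimizers without subclassing, while optimize_loss centralizes gradient computation, clipping, and the global-step increment.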