def _get_model_fn(self, model_fn):
  """Backward compatibility way of adding class weight and IS_TRAINING.

  TODO(ipolosukhin): Remove this function after new layers are available.
  Specifically:
   * dropout and batch norm should work via update ops.
   * class weights should be retrieved from weights column or hparams.

  Args:
    model_fn: Core model function.

  Returns:
    Model function.
  """
  def _model_fn(features, targets, mode):
    """Model function."""
    ops.get_default_graph().add_to_collection('IS_TRAINING', mode == 'train')
    if self.class_weight is not None:
      constant_op.constant(self.class_weight, name='class_weight')
    predictions, loss = model_fn(features, targets)
    if isinstance(self.learning_rate, types.FunctionType):
      learning_rate = self.learning_rate(contrib_framework.get_global_step())
    else:
      learning_rate = self.learning_rate
    if isinstance(self.optimizer, types.FunctionType):
      optimizer = self.optimizer(learning_rate)
    else:
      optimizer = self.optimizer
    train_op = layers.optimize_loss(
        loss,
        contrib_framework.get_global_step(),
        learning_rate=learning_rate,
        optimizer=optimizer,
        clip_gradients=self.clip_gradients)
    return predictions, loss, train_op
  return _model_fn
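
For context, here is a minimal sketch of the (predictions, loss) contract that _get_model_fn expects from the core model function, and of how the returned wrapper might be called. The names my_model_fn and estimator are hypothetical and only illustrate the flow; the wrapper adds the IS_TRAINING collection entry, the optional class_weight constant, and the train_op built by layers.optimize_loss.

import tensorflow as tf

def my_model_fn(features, targets):
  """Hypothetical core model function: returns only (predictions, loss).

  The wrapper returned by _get_model_fn supplies the mode handling and the
  training op, so no optimizer logic appears here.
  """
  # Assumes features has shape [batch_size, 3]; a simple linear model.
  w = tf.Variable(tf.zeros([3, 1]), name='weights')
  b = tf.Variable(tf.zeros([1]), name='bias')
  predictions = tf.matmul(features, w) + b
  loss = tf.reduce_mean(tf.square(predictions - targets))
  return predictions, loss

# Assumed usage (estimator is an instance of the class defining _get_model_fn):
# wrapped_fn = estimator._get_model_fn(my_model_fn)
# predictions, loss, train_op = wrapped_fn(features, targets, mode='train')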