def get_model_fn(self):
    """Build an Estimator-compatible ``model_fn`` closed over this object.

    The returned function reads ``self.model_params`` (forwarded to
    ``TransformerModule``) and ``self.training_params`` (keys
    ``"learning_rate"`` and ``"clip_gradients"``) each time it is called.

    Returns:
        A callable with the signature
        ``model_fn(features, labels, mode, params=None, config=None)`` that
        returns an ``EstimatorSpec``.
    """

    def model_fn(features, labels, mode, params=None, config=None):
        """Build the graph for ``mode`` and wrap it in an ``EstimatorSpec``.

        Args:
            features: Passed unchanged to the ``TransformerModule`` instance;
                its expected structure is defined by that module (not visible
                here).
            labels: Unused. NOTE(review): the loss is computed from
                ``features`` alone — confirm targets are packed there.
            mode: ``ModeKeys.TRAIN`` or ``ModeKeys.EVAL``
                (``ModeKeys.PREDICT`` is not implemented).
            params: Mapping that must contain ``"optimizer"`` in TRAIN mode;
                that value is forwarded to ``slim.optimize_loss``.
            config: Unused; accepted to match the Estimator ``model_fn``
                signature.

        Returns:
            An ``EstimatorSpec`` carrying ``loss`` and, in TRAIN mode, a
            ``train_op``. ``eval_metric_ops`` and ``predictions`` are None.

        Raises:
            NotImplementedError: If ``mode`` is ``ModeKeys.PREDICT``.
            ValueError: If ``mode`` is not recognised, or TRAIN mode is
                requested without ``params["optimizer"]``.
        """
        if mode == ModeKeys.PREDICT:
            raise NotImplementedError("PREDICT mode is not supported by this model_fn")
        if mode not in (ModeKeys.TRAIN, ModeKeys.EVAL):
            # Previously an unknown mode fell through and produced a spec
            # with loss=None; fail loudly instead.
            raise ValueError("Unsupported mode: {!r}".format(mode))

        train_op = None
        eval_metrics = None
        predictions = None

        # Both supported modes use the same module construction, so build it
        # once here rather than in each branch.
        transformer_model = TransformerModule(params=self.model_params)

        if mode == ModeKeys.TRAIN:
            # Validate up front: with the default params=None (or a mapping
            # lacking the key) indexing would raise an opaque TypeError/KeyError.
            if not params or "optimizer" not in params:
                raise ValueError('params["optimizer"] is required in TRAIN mode')
            step = slim.get_or_create_global_step()
            loss = transformer_model(features)
            train_op = slim.optimize_loss(
                loss=loss,
                global_step=step,
                learning_rate=self.training_params["learning_rate"],
                clip_gradients=self.training_params["clip_gradients"],
                optimizer=params["optimizer"],
                summaries=slim.OPTIMIZER_SUMMARIES,
            )
        else:  # ModeKeys.EVAL
            loss = transformer_model(features)

        return EstimatorSpec(
            train_op=train_op,
            loss=loss,
            eval_metric_ops=eval_metrics,
            predictions=predictions,
            mode=mode,
        )

    return model_fn
评论列表
文章目录