transformer.py source code

python

Project: attention    Author: louishenrifranc
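# Assumed TF 1.x setup (the original snippet does not show its imports):
import tensorflow as tf

slim = tf.contrib.slim
ModeKeys = tf.estimator.ModeKeys
EstimatorSpec = tf.estimator.EstimatorSpec
# TransformerModule is defined elsewhere in the attention project.


def get_model_fn(self):
    """Build the model_fn that is passed to tf.estimator.Estimator."""
    def model_fn(features, labels, mode, params=None, config=None):
        train_op = None
        loss = None
        eval_metrics = None
        predictions = None
        if mode == ModeKeys.TRAIN:
            transformer_model = TransformerModule(params=self.model_params)
            step = slim.get_or_create_global_step()
            # Calling the module on the input features returns the training loss.
            loss = transformer_model(features)
            # Build the update op: gradients, optional clipping, optimizer step, summaries.
            train_op = slim.optimize_loss(loss=loss,
                                          global_step=step,
                                          learning_rate=self.training_params["learning_rate"],
                                          clip_gradients=self.training_params["clip_gradients"],
                                          # `params` must be supplied by the Estimator and contain "optimizer".
                                          optimizer=params["optimizer"],
                                          summaries=slim.OPTIMIZER_SUMMARIES)
        elif mode == ModeKeys.PREDICT:
            # Inference is not implemented in this snippet.
            raise NotImplementedError
        elif mode == ModeKeys.EVAL:
            # Evaluation reports only the loss; no extra eval metrics are computed.
            transformer_model = TransformerModule(params=self.model_params)
            loss = transformer_model(features)

        return EstimatorSpec(train_op=train_op, loss=loss, eval_metric_ops=eval_metrics,
                             predictions=predictions, mode=mode)

    return model_fn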
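The `clip_gradients` value read from `self.training_params` is forwarded to `slim.optimize_loss`. In TF 1.x, passing a float there applies global-norm clipping: every gradient is rescaled by one common factor so that the joint L2 norm of all gradients does not exceed the threshold. The pure-Python sketch below illustrates that rule under stated assumptions: `clip_grads_by_global_norm` is a hypothetical helper written for illustration (not a TensorFlow API), and each gradient tensor is represented as a flat list of floats to keep the example dependency-free.

```python
import math
from typing import List


def clip_grads_by_global_norm(grads: List[List[float]], clip_norm: float) -> List[List[float]]:
    """Hypothetical helper (not a TensorFlow API): global-norm gradient clipping.

    Assumption: each inner list stands in for one flattened gradient tensor,
    and clip_norm is a positive threshold.
    """
    # Global norm: square root of the sum of squares over *all* gradient entries.
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # One shared scale factor; it equals 1.0 when global_norm <= clip_norm,
    # so gradients that are already small enough pass through unchanged.
    scale = clip_norm / max(global_norm, clip_norm)
    return [[g * scale for g in grad] for grad in grads]


if __name__ == "__main__":
    # Global norm here is 5.0, so everything is scaled by 1.0 / 5.0.
    print(clip_grads_by_global_norm([[3.0, 4.0], [0.0]], 1.0))
```

For example, `clip_grads_by_global_norm([[3.0, 4.0], [0.0]], 1.0)` sees a global norm of 5 and returns approximately `[[0.6, 0.8], [0.0]]`. Using a single shared factor, rather than clipping each tensor on its own, preserves the direction of the overall update while bounding its size.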