tf-keras-skeleton.py source code

python
Project: LIE  Author: EmbraceLife
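import copy
import types

from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical


# Method of the scikit-learn wrapper base class (BaseWrapper in
# tf.keras.wrappers.scikit_learn); it relies on self.build_fn,
# self.filter_sk_params and the sk_params given to the constructor.
def fit(self, x, y, **kwargs):
    """Constructs a new model with `build_fn` and fits it to `(x, y)`.

    Arguments:
        x : array-like, shape `(n_samples, n_features)`
            Training samples, where n_samples is the number of samples
            and n_features is the number of features.
        y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for `x`.
        **kwargs: dictionary arguments
            Valid arguments are the keyword arguments of `Sequential.fit`.

    Returns:
        history : object
            Details about the training history at each epoch.
    """
    # Build the model: from this object's own __call__ when no build_fn was
    # given, from a callable object's __call__, or from a plain function/method.
    if self.build_fn is None:
        self.model = self.__call__(**self.filter_sk_params(self.__call__))
    elif (not isinstance(self.build_fn, types.FunctionType) and
          not isinstance(self.build_fn, types.MethodType)):
        self.model = self.build_fn(
            **self.filter_sk_params(self.build_fn.__call__))
    else:
        self.model = self.build_fn(**self.filter_sk_params(self.build_fn))

    # One-hot encode integer labels when the model uses categorical
    # cross-entropy (the loss may be a string or a function object).
    loss_name = self.model.loss
    if hasattr(loss_name, '__name__'):
        loss_name = loss_name.__name__
    if loss_name == 'categorical_crossentropy' and len(y.shape) != 2:
        y = to_categorical(y)

    # Start from the sk_params that match Sequential.fit's signature,
    # then let keyword arguments passed to this call override them.
    fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit))
    fit_args.update(kwargs)

    history = self.model.fit(x, y, **fit_args)

    return history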
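The `fit` method above does three things: it passes only the matching `sk_params` to the model-building function, one-hot encodes 1-D labels for categorical cross-entropy, and merges the remaining `sk_params` with explicit keyword arguments before training. The sketch below reproduces that flow in plain Python and NumPy so it runs without TensorFlow. It is an illustration under stated assumptions, not the wrapper's code: `filter_params` is a simplified stand-in for `filter_sk_params`, `to_one_hot` stands in for `to_categorical`, and `build_model` and `fake_fit` are hypothetical substitutes for a real Keras build function and `Sequential.fit`.

```python
import inspect

import numpy as np


def filter_params(fn, params):
    """Keep only the entries of `params` that `fn` accepts as named arguments.

    Simplified stand-in for the wrapper's filter_sk_params (assumption).
    """
    accepted = inspect.signature(fn).parameters
    return {k: v for k, v in params.items() if k in accepted}


def to_one_hot(y, num_classes=None):
    """Turn integer labels of shape (n,) into an (n, num_classes) 0/1 matrix."""
    y = np.asarray(y, dtype=int).ravel()
    if num_classes is None:
        num_classes = int(y.max()) + 1
    one_hot = np.zeros((y.shape[0], num_classes), dtype="float32")
    one_hot[np.arange(y.shape[0]), y] = 1.0
    return one_hot


# Hypothetical build function and fit signature, standing in for a real
# Keras model-building function and Sequential.fit.
def build_model(hidden_units=32, optimizer="adam"):
    return {"hidden_units": hidden_units, "optimizer": optimizer}


def fake_fit(x, y, epochs=1, batch_size=32, verbose=1):
    return {"epochs": epochs, "batch_size": batch_size, "verbose": verbose}


# Mixed bag of parameters, as a scikit-learn style constructor would collect.
sk_params = {"hidden_units": 64, "epochs": 5, "batch_size": 16}

# Step 1: only the parameters build_model understands reach it.
model = build_model(**filter_params(build_model, sk_params))

# Step 2: 1-D integer labels become one-hot rows (as for categorical
# cross-entropy in the wrapper).
y = np.array([0, 2, 1])
if len(y.shape) != 2:
    y = to_one_hot(y)

# Step 3: fit arguments come from sk_params, then explicit kwargs override.
fit_args = filter_params(fake_fit, sk_params)
fit_args.update({"epochs": 10})
history = fake_fit(None, y, **fit_args)
```

Because `update` runs after the filtering, the keyword passed directly at fit time (`epochs=10`) wins over the matching `sk_params` entry (`epochs=5`), the same precedence the wrapper's `fit` gives to `**kwargs`.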