factorization_machine.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:polylearn 作者: scikit-learn-contrib 项目源码 文件源码
def fit(self, X, y):
        """Train the factorization machine on the samples ``X`` and targets ``y``.

        Parameters
        ----------
        X : array-like or sparse, shape = [n_samples, n_features]
            Input samples; each row is one sample and each column one
            feature.

        y : array-like, shape = [n_samples]
            Target value associated with each row of ``X``.

        Returns
        -------
        self : Estimator
            This same estimator, now fitted.

        Raises
        ------
        ValueError
            If ``self.degree`` is greater than 3, or if fresh lambdas have to
            be created and ``self.init_lambdas`` is neither ``'ones'`` nor
            ``'random_signs'``.

        Notes
        -----
        When ``warm_start`` is true, whichever of ``w_``, ``P_`` and ``lams_``
        already exist on the estimator are reused rather than re-initialized.
        Otherwise:

        - ``w_`` starts at zeros.
        - ``P_`` starts as ``0.01`` times standard-normal draws.
        - ``lams_`` starts at ones or at random signs.

        Optimization is delegated to ``_cd_direct_ho``. If it reports
        non-convergence, a warning is emitted.
        """
        if self.degree > 3:
            raise ValueError("FMs with degree >3 not yet supported.")

        X, y = self._check_X_y(X, y)
        X = self._augment(X)
        # Column count *after* augmentation; every parameter array uses it.
        n_aug = X.shape[1]
        col_sq_norms = row_norms(X.T, squared=True)
        fortran_data = get_dataset(X, order="fortran")
        random_state = check_random_state(self.random_state)
        loss_fn = self._get_loss(self.loss)

        def needs_init(attr):
            # Re-initialize unless warm-starting with the attribute present.
            return not self.warm_start or not hasattr(self, attr)

        if needs_init('w_'):
            self.w_ = np.zeros(n_aug, dtype=np.double)

        explicit_lower = self.fit_lower == 'explicit'
        # Leading axis of P_: degree - 1 slices when lower orders are fit
        # explicitly, a single slice otherwise.
        n_orders = self.degree - 1 if explicit_lower else 1

        # P_ is drawn from the RNG before lams_ (kept in this order for
        # reproducibility under a fixed random_state).
        if needs_init('P_'):
            self.P_ = 0.01 * random_state.randn(
                n_orders, self.n_components, n_aug)

        if needs_init('lams_'):
            lam_mode = self.init_lambdas
            if lam_mode == 'ones':
                self.lams_ = np.ones(self.n_components)
            elif lam_mode == 'random_signs':
                self.lams_ = np.sign(random_state.randn(self.n_components))
            else:
                raise ValueError("Lambdas must be initialized as ones "
                                 "(init_lambdas='ones') or as random "
                                 "+/- 1 (init_lambdas='random_signs').")

        # Model output for the current parameters, handed to the solver.
        # NOTE(review): the solver presumably updates this buffer in place.
        current_pred = self._get_output(X)

        did_converge, self.n_iter_ = _cd_direct_ho(
            self.P_, self.w_, fortran_data, col_sq_norms, y, current_pred,
            self.lams_, self.degree, self.alpha, self.beta, self.fit_linear,
            explicit_lower, loss_fn, self.max_iter, self.tol, self.verbose)

        if not did_converge:
            warnings.warn("Objective did not converge. Increase max_iter.")

        return self
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号