FactorMachine.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:kaggle 作者: RankingAI 项目源码 文件源码
def train(self):
        """Fit a fastFM ALS factorization-machine regressor on ``self.TrainData``.

        Steps, as implemented below:
          1. Drop ``self._l_drop_cols`` from ``self.TrainData`` to get the feature
             frame. The regression target is the ``'logerror'`` column.
          2. Keep only feature columns whose name starts with ``'<cate>_'`` for
             some ``<cate>`` in ``self._l_cate_cols``. Store them, de-duplicated
             and in the frame's own column order, in ``self._l_train_columns``.
          3. Fit ``als.FMRegression`` on a CSR sparse matrix of those columns.
             It uses ``self._iter``, ``self._rank``, ``self._reg_w`` and
             ``self._reg_v``, with ``init_stdev=0.1``.
          4. Pickle the fitted model to
             ``<OutputDir>/<ClassName>_<YYYYmmdd-HH:MM:SS>.pkl`` and remember
             that path in ``self._f_eval_train_model``.
          5. Append ``self.ValidData`` (restricted to the training frame's
             columns) to ``self.TrainData`` with a fresh 0..n-1 index.

        Returns:
            None. Results are stored on ``self``: ``_model``,
            ``_l_train_columns``, ``_f_eval_train_model`` and ``TrainData``.
        """
        print('size before truncated outliers is %d ' % len(self.TrainData))
        # NOTE: outlier truncation is currently disabled, so the "after" size
        # printed below equals the "before" size. The disabled filter kept rows
        # with -0.4 < logerror < 0.418.
        TrainData = self.TrainData
        print('size after truncated outliers is %d ' % len(TrainData))
        print('train data size %d' % len(TrainData))

        X = TrainData.drop(self._l_drop_cols, axis=1)
        Y = TrainData['logerror']

        # Keep only columns named "<cate>_..." for a categorical column <cate>.
        # presumably these are one-hot encoded levels - verify against the encoder.
        # str.startswith accepts a tuple, so one call tests every prefix; an
        # empty prefix tuple matches nothing, exactly like the old nested loop.
        prefixes = tuple('%s_' % cc for cc in self._l_cate_cols)
        cols = [col for col in X.columns if col.startswith(prefixes)]

        # De-duplicate while preserving the frame's column order. The old
        # list(set(cols)) ordered features by string hash. That order varies
        # between interpreter runs (hash randomization), so the pickled model
        # could later be scored against a differently ordered feature matrix.
        unique_cols = list(dict.fromkeys(cols))
        if len(unique_cols) != len(cols):
            print('!!!! cols duplicated .')

        self._l_train_columns = unique_cols

        X = scipy.sparse.csr_matrix(X[self._l_train_columns])
        self._model = als.FMRegression(n_iter=self._iter, init_stdev=0.1, rank=self._rank, l2_reg_w=self._reg_w, l2_reg_V=self._reg_v)
        self._model.fit(X, Y)

        print('training done.')

        # NOTE(review): ':' in the timestamp is not a legal filename character on
        # Windows; kept unchanged because other code may depend on this name.
        self._f_eval_train_model = '{0}/{1}_{2}.pkl'.format(self.OutputDir, self.__class__.__name__, datetime.now().strftime('%Y%m%d-%H:%M:%S'))
        # Protocol -1 selects the highest pickle protocol available. The `with`
        # statement closes the file, so no explicit close() is needed.
        with open(self._f_eval_train_model, 'wb') as o_file:
            pickle.dump(self._model, o_file, -1)

        # Append validation rows (restricted to the training columns).
        # ignore_index=True renumbers the result so the two frames' index
        # values do not overlap.
        self.TrainData = pd.concat([self.TrainData, self.ValidData[self.TrainData.columns]], ignore_index=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号