mmlwrapper.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:polara 作者: Evfro 项目源码 文件源码
def _parse_factors(self):
        """Load latent factors (and optional item biases) from a saved model file.

        ``self.saved_model_path`` is read as space-separated text. The first two
        lines are skipped, and the remaining rows are parsed into three columns
        named ``col1``, ``col2`` and ``col3``.

        Two layouts are recognized. They are told apart purely by the total
        number of parsed rows, where ``nu``/``nf`` come from the first parsed
        row and ``ni`` from the row right after the user-factor rows.

        * Without biases, there are ``(nu + ni) * nf + 2`` rows::

              nu nf                    <- header row
              <nu * nf user rows>
              ni ...                   <- header row
              <ni * nf item rows>

        * With item biases, there are ``(nu + ni) * nf + ni + 3`` rows::

              nu nf
              <nu * nf user rows>
              ni                       <- bias count
              <ni bias rows>           <- bias value in the first field
              <one skipped row>        <- presumably the item-factor header
              <ni * nf item rows>

        Factor values are taken from ``col3``. When ``self.positive_only`` is
        false they are reshaped row-major to ``(nu, nf)`` and ``(ni, nf)``. This
        assumes the rows are stored ordered by (entity, factor); TODO confirm
        against the writer.

        When ``self.positive_only`` is true:

        * The tab-separated ``self.user_mapping_file`` and
          ``self.item_mapping_file`` are read.
        * The factors are expanded via ``self._remap_factors`` to
          ``num_users``/``num_items`` rows.
        * Item biases, if present, are scattered into a zero vector of length
          ``num_items`` at the positions given by the second column of the
          item mapping.

        Sets:
            self._users_factors, self._items_factors: always.
            self._items_biases: only in the ``positive_only`` branch (``None``
                when the file has no biases).

        Raises:
            NotImplementedError: if the row count matches neither layout.
        """
        model_data_path = self.saved_model_path
        model_params = pd.read_csv(model_data_path, skiprows=2, sep=' ',
                                   header=None, names=['col1', 'col2', 'col3'])
        num_users = self.data.index.userid.training.new.max() + 1
        num_items = self.data.index.itemid.new.max() + 1

        # First parsed row holds the user-factor matrix dimensions.
        nu, nf = model_params.iloc[0, :2].astype(np.int64)
        # Index of the first row after the user-factor block; its first field is ni.
        boundary = nu*nf + 1
        ni = model_params.iloc[boundary, 0].astype(np.int64)

        users_factors = model_params.iloc[1:boundary, :]

        num_rows = model_params.shape[0]
        if num_rows == (nu + ni)*nf + 2:  # no biases
            items_biases = None
            items_factors = model_params.iloc[(boundary+1):, :]
        elif num_rows == (nu + ni)*nf + ni + 3:  # has item biases
            items_biases = model_params.iloc[(boundary+1):(boundary+1+ni), 0].values
            # Row boundary+1+ni is skipped; item factors start right after it.
            items_factors = model_params.iloc[(boundary+2+ni):, :]
        else:
            # Must raise here: otherwise items_factors is unbound below and the
            # caller gets an unrelated UnboundLocalError instead of this message.
            raise NotImplementedError(
                '{} data is not recognized.'.format(model_data_path))

        if self.positive_only:
            user_mapping = pd.read_csv(self.user_mapping_file, sep='\t', header=None)
            item_mapping = pd.read_csv(self.item_mapping_file, sep='\t', header=None)

            user_factors_full = self._remap_factors(user_mapping, users_factors, num_users, nf)
            item_factors_full = self._remap_factors(item_mapping, items_factors, num_items, nf)

            if items_biases is not None:
                # Scatter each bias (in file order) to the index given by the
                # mapping's second column; unmapped items keep a zero bias.
                bias_factors_full = np.zeros(num_items)
                np.put(bias_factors_full, item_mapping.loc[:, 1].values, items_biases)
                self._items_biases = bias_factors_full
            else:
                self._items_biases = None

            self._users_factors = user_factors_full
            self._items_factors = item_factors_full
        else:
            # NOTE(review): items_biases (if parsed) is not stored in this branch
            # and self._items_biases is left untouched -- confirm this is intended.
            self._users_factors = users_factors['col3'].values.reshape(nu, nf)
            self._items_factors = items_factors['col3'].values.reshape(ni, nf)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号