SDAE.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:HyPRec 作者: mostafa-mahmoud 项目源码 文件源码
def _train(self):
        """
        Train the stacked denoising autoencoders together with the latent factors.

        For each of ``self.n_iter`` epochs this method:

        1. refreshes ``self.document_distribution`` from ``self.predict_sdae``,
        2. calls ``self.als_step`` once for the user vectors and once for the
           item vectors,
        3. feeds ``self.train_sdae`` with batches produced by ``chunks`` over
           the dense term-frequency matrix and the current item vectors.

        After the last epoch the document distribution is recomputed, the
        network is scored with ``self.evaluate_sdae`` and the keras backend
        session is cleared.

        When ``self._verbose`` is truthy, progress lines are printed; the fold
        number is included only if ``'fold'`` is present in
        ``self.hyperparameters``.

        :returns: The value returned by
            ``self.evaluate_sdae(term_freq, self.item_vecs)``.
        """
        # Fold numbers are reported 1-based; 0 means "no fold configured".
        current_fold = self.hyperparameters['fold'] + 1 if 'fold' in self.hyperparameters else 0

        term_freq = self.abstracts_preprocessor.get_term_frequency_sparse_matrix().todense()
        self.get_cnn()
        if self._verbose:
            print("CNN is constructed...")

        # NOTE(review): this initial value is never read in this method.
        train_rmse = numpy.inf
        step_count = 0
        batch_size = 2048

        # The fold is fixed for the whole run, so pick the log templates once.
        # str.format ignores unused keyword arguments, which lets both
        # variants be fed the same set of fields.
        if current_fold == 0:
            factor_template = 'Iteration:{it:05d} Epoch:{epoch:02d} Loss:{loss:1.4e} Time:{time:.3f}s'
            sdae_template = ('Iteration:{it:05d} Epoch:{epoch:02d} Loss:{loss:1.3e} Time:{tim:.3f}s')
        else:
            factor_template = ('Fold:{fold:02d} Iteration:{it:05d} Epoch:{epoch:02d} Loss:{loss:1.4e} '
                               'Time:{time:.3f}s')
            sdae_template = ('Fold:{fold:02d} Iteration:{it:05d} Epoch:{epoch:02d} Loss:{loss:1.3e} Time:{tim:.3f}s')

        for epoch in range(1, self.n_iter + 1):
            self.document_distribution = self.predict_sdae(term_freq)

            # Latent factor update: users first, then items against the
            # freshly updated user vectors.
            started = time.time()
            self.user_vecs = self.als_step(self.user_vecs, self.item_vecs, self.train_data, self._lambda, type='user')
            self.item_vecs = self.als_step(self.item_vecs, self.user_vecs, self.train_data, self._lambda, type='item')
            elapsed = time.time() - started
            step_count += 1
            if self._verbose:
                train_rmse = self.evaluator.get_rmse(self.user_vecs.dot(self.item_vecs.T), self.train_data)
                print(factor_template.format(fold=current_fold, it=step_count, epoch=epoch,
                                             loss=train_rmse, time=elapsed))

            # Network update: one train_sdae call per batch.
            for term_batch, item_batch in chunks(batch_size, term_freq, self.item_vecs):
                started = time.time()
                batch_loss = self.train_sdae(term_batch, item_batch)
                elapsed = time.time() - started
                step_count += 1
                if self._verbose:
                    print(sdae_template.format(fold=current_fold, loss=float(batch_loss), epoch=epoch,
                                               it=step_count, tim=elapsed))

            # NOTE(review): the result of this evaluation is not read anywhere
            # in this method; confirm whether the call is still needed.
            train_rmse = self.evaluator.get_rmse(self.user_vecs.dot(self.item_vecs.T), self.train_data)

        self.document_distribution = self.predict_sdae(term_freq)
        rms = self.evaluate_sdae(term_freq, self.item_vecs)
        if self._verbose:
            print(rms)

        # Garbage collection for keras
        backend.clear_session()
        if self._verbose:
            print("SDAE trained...")
        return rms
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号