sequenceNet.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:deep-summarization 作者: harpribot 项目源码 文件源码
def store_test_predictions(self, prediction_id='_final'):
    """
    Stores the test predictions in a CSV file

    If the result file already exists, only a new
    ``generated_summary<prediction_id>`` column is added to it. Otherwise a
    fresh file is created with ``review``, ``true_summary`` and the
    generated-summary column.

    :param prediction_id: A simple id appended to the name of the summary for uniqueness
    :return: None
    """
    # prediction id is usually the step count
    print('Storing predictions on Test Data...')

    # Probe for the result file exactly once: the answer decides both which
    # columns are built and how the frame is assembled, so the two decisions
    # must agree (and the probe need not be repeated per test example).
    output_exists = self.checkpointer.is_output_file_present()
    result_location = self.checkpointer.get_result_location()

    # Only indices below this cutoff (whole test batches) are looked up in
    # predicted_test_summary; the remaining rows get an empty summary.
    predicted_count = self.test_batch_size * (self.test_size // self.test_batch_size)
    generated_summary = [
        self._index2sentence(self.predicted_test_summary[i]) if i < predicted_count else ''
        for i in range(self.test_size)
    ]

    prediction_nm = 'generated_summary' + prediction_id
    if output_exists:
        # Reviews and true summaries were written by an earlier call.
        df = pd.read_csv(result_location, header=0)
    else:
        df = pd.DataFrame()
        df['review'] = np.array(
            [self._index2sentence(self.test_review[i]) for i in range(self.test_size)])
        df['true_summary'] = np.array(
            [self._index2sentence(self.true_summary[i]) for i in range(self.test_size)])
    df[prediction_nm] = np.array(generated_summary)
    df.to_csv(result_location, index=False)

    print('Stored the predictions. Moving Forward')
    if prediction_id == '_final':
        print('All done. Exiting..')
        print('Exited')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号