cnn.py source code

python

Project: kaggle-review  Author: daxiongshu
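import numpy as np
import pandas as pd

def post(self):
    # Select the documents that the rows in pred_path correspond to.
    if self.flags.task == "test_cnn_stage1":
        docs = self.DB.clean_doc['test_text_filter']
    elif self.flags.task == "test_cnn_stage2":
        docs = self.DB.clean_doc['stage2_test_text']
    else:
        # Any other task: use entry 1 of the current fold's split
        # (the same indices used for the CV score below).
        self.mDB.get_split()
        docs = self.mDB.split[self.flags.fold][1]
    nrows = len(docs)

    # pred_path is read as a headerless CSV holding one block of nrows rows
    # per epoch, stacked vertically; sum the blocks, then average them.
    p = np.zeros([nrows, 9])
    for i in range(self.flags.epochs):
        skiprows = None if i == 0 else nrows * i
        p = p + pd.read_csv(self.flags.pred_path, header=None,
                            nrows=nrows, skiprows=skiprows).values
    p = p / self.flags.epochs

    # For cross-validation tasks, report the loss on the selected rows.
    if '_cv' in self.flags.task:
        from utils.np_utils.utils import cross_entropy
        y = np.argmax(self.mDB.y, axis=1)
        print("cross entropy", cross_entropy(y[self.mDB.split[self.flags.fold][1]], p))

    # Write the submission: class1..class9 probabilities plus a 1-based ID.
    s = pd.DataFrame(p, columns=['class%d' % i for i in range(1, 10)])
    s['ID'] = np.arange(nrows) + 1
    s.to_csv(self.flags.pred_path.replace(".csv", "_sub.csv"), index=False, float_format="%.5f")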
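
The averaging step in post() re-reads the prediction file once per epoch. As a minimal sketch, and not the project's own code, the same average can be computed from a single read by reshaping the stacked rows into (epochs, nrows, nclasses). The helper name average_stacked_predictions and the toy data are hypothetical, and the sketch assumes the file is a headerless CSV with at least epochs * nrows rows.

```python
import io

import numpy as np
import pandas as pd


def average_stacked_predictions(csv_source, nrows, epochs):
    """Average `epochs` vertically stacked blocks of `nrows` rows each.

    `csv_source` is a path or file-like object holding a headerless CSV.
    Hypothetical helper: an alternative to the per-epoch skiprows loop.
    """
    values = pd.read_csv(csv_source, header=None).values
    # Ignore any trailing rows beyond the blocks we were asked to average.
    values = values[: epochs * nrows]
    # Block i occupies rows [i * nrows, (i + 1) * nrows), so a C-order
    # reshape puts the epoch index on axis 0.
    return values.reshape(epochs, nrows, -1).mean(axis=0)


# Toy data: 2 epochs, 2 rows per epoch, 2 classes.
toy_csv = "0.2,0.8\n0.6,0.4\n0.4,0.6\n0.2,0.8\n"
avg = average_stacked_predictions(io.StringIO(toy_csv), nrows=2, epochs=2)
```

Reading once avoids reparsing the earlier rows on every epoch, at the cost of holding the whole file in memory; for very large prediction files the per-epoch loop may still be preferable.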