def post(self):
    """Average the per-epoch predictions and write a submission CSV.

    The file at ``self.flags.pred_path`` is read as a header-less CSV that
    holds ``self.flags.epochs`` consecutive chunks of ``nrows`` rows and 9
    columns each, where ``nrows`` is the number of documents selected by
    ``self.flags.task``:

    * ``"test_cnn_stage1"`` -> ``self.DB.clean_doc['test_text_filter']``
    * ``"test_cnn_stage2"`` -> ``self.DB.clean_doc['stage2_test_text']``
    * anything else         -> ``self.mDB.split[self.flags.fold][1]`` after
      calling ``self.mDB.get_split()`` (presumably the held-out indices of
      the fold -- confirm against ``get_split``).

    The chunks are averaged element-wise.  When the task name contains
    ``"_cv"`` the cross entropy of the averaged predictions against the
    labels of the fold is printed.  The averaged predictions are written
    next to the prediction file with ``_sub`` inserted before the file
    extension, using columns ``class1`` .. ``class9`` plus a 1-based ``ID``.

    Raises:
        ValueError: if ``self.flags.epochs`` is smaller than 1, or if the
            prediction file does not contain ``epochs * nrows`` rows of
            9 columns.
    """
    import os  # function-scope import: the file's import block is not visible here

    n_classes = 9

    task = self.flags.task
    if task == "test_cnn_stage1":
        docs = self.DB.clean_doc['test_text_filter']
    elif task == "test_cnn_stage2":
        docs = self.DB.clean_doc['stage2_test_text']
    else:
        self.mDB.get_split()
        docs = self.mDB.split[self.flags.fold][1]
    nrows = len(docs)

    epochs = self.flags.epochs
    if epochs < 1:
        # Dividing the zero accumulator by 0 would silently emit NaN rows.
        raise ValueError("flags.epochs must be >= 1, got %r" % (epochs,))

    pred_path = self.flags.pred_path
    # Read the stacked predictions once instead of re-parsing the file from
    # the top for every epoch with a growing ``skiprows``.
    stacked = pd.read_csv(pred_path, header=None, nrows=nrows * epochs).values
    expected_shape = (nrows * epochs, n_classes)
    if stacked.shape != expected_shape:
        raise ValueError(
            "%s: expected predictions of shape %r (epochs=%d, rows per epoch=%d), got %r"
            % (pred_path, expected_shape, epochs, nrows, stacked.shape)
        )
    # Chunk k of the file holds rows [k * nrows, (k + 1) * nrows).
    p = stacked.reshape(epochs, nrows, n_classes).sum(axis=0) / epochs

    if '_cv' in task:
        from utils.np_utils.utils import cross_entropy
        y = np.argmax(self.mDB.y, axis=1)
        print("cross entropy", cross_entropy(y[self.mDB.split[self.flags.fold][1]], p))

    s = pd.DataFrame(p, columns=['class%d' % i for i in range(1, n_classes + 1)])
    s['ID'] = np.arange(nrows) + 1

    # Insert "_sub" before the extension.  Unlike str.replace(".csv", ...),
    # this never returns the input path unchanged (which would overwrite the
    # raw predictions) and never touches ".csv" inside directory names.
    root, ext = os.path.splitext(pred_path)
    s.to_csv(root + "_sub" + ext, index=False, float_format="%.5f")
评论列表
文章目录