def eval(self, session, feed, saver, early_stopping_rounds,
         early_stopping_metric_list, early_stopping_metric_minimize=False,
         metrics='accuracy'):
    # Fetch loss, accuracy, predictions and the merged summary in one run.
    # self.test_accuracy, self.predictions and self.eval_summary are assumed
    # attribute names here; a single self.test_loss fetch cannot be unpacked
    # into four values.
    test_loss_value, acc_test, pred, eval_summary = session.run(
        [self.test_loss, self.test_accuracy, self.predictions, self.eval_summary],
        feed)
    f1_3class, f1_2class = fscores(self.data.dev_y, pred)
    if not self.tuning:
        print("*** Validation Loss = {:.6f}; Validation Accuracy = {:.5f}; "
              "3-class F1 = {:.5f}; 2-class F1 = {:.5f}"
              .format(test_loss_value, acc_test, f1_3class, f1_2class))
        print()

    early_stop = False
    early_stopping_score = -1
    # Select which metric drives checkpointing and early stopping.
    if metrics == 'accuracy':
        early_stopping_score = acc_test
        early_stopping_metric_list.append(acc_test)
    elif metrics == '3classf1':
        early_stopping_score = f1_3class
        early_stopping_metric_list.append(f1_3class)
    elif metrics == '2classf1':
        early_stopping_score = f1_2class
        early_stopping_metric_list.append(f1_2class)
    assert early_stopping_score > 0  # fails fast on an unrecognised metrics value

    if (not self.FLAGS.restore) and early_stopping_metric_minimize:
        # Minimising the eval score: checkpoint on every new minimum.
        if all(early_stopping_score <= i for i in early_stopping_metric_list):
            saver.save(session, self.FLAGS.checkpoint_file)
            # self.best_eval_score is an assumed instance attribute; the
            # original assigned a local that was never read.
            self.best_eval_score = (acc_test, f1_3class, f1_2class)
        # history[::-1].index(min(...)) is the number of evaluations since the
        # best (lowest) score; stop once it exceeds the patience window.
        if early_stopping_metric_list[::-1].index(min(early_stopping_metric_list)) > early_stopping_rounds:
            early_stop = True
    elif (not self.FLAGS.restore) and (not early_stopping_metric_minimize):
        # Maximising the eval score: checkpoint on every new maximum.
        if all(early_stopping_score >= i for i in early_stopping_metric_list):
            saver.save(session, self.FLAGS.checkpoint_file)
            self.best_eval_score = (acc_test, f1_3class, f1_2class)
        if early_stopping_metric_list[::-1].index(max(early_stopping_metric_list)) > early_stopping_rounds:
            early_stop = True
    # Same 4-tuple in every case, including when FLAGS.restore is set.
    return (test_loss_value, (acc_test, f1_3class, f1_2class), early_stop, eval_summary)
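
The non-obvious part of this method is the patience check: every evaluation appends its score to early_stopping_metric_list, so history[::-1].index(max(history)) counts how many evaluations have passed since the best score. A minimal self-contained sketch of that logic, with illustrative history values:

# One dev score per evaluation; the best (0.72) occurred 3 evaluations ago.
history = [0.61, 0.67, 0.72, 0.70, 0.69, 0.68]
patience = 2  # plays the role of early_stopping_rounds

# Reversing the list makes .index() find the most recent occurrence of the
# best score, measured from the end of the history.
rounds_since_best = history[::-1].index(max(history))
print(rounds_since_best)             # 3
print(rounds_since_best > patience)  # True -> early_stop would be set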