train_lstm.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:tdlstm 作者: bluemonk482 项目源码 文件源码
def eval(self, session, feed, saver, early_stopping_rounds, early_stopping_metric_list, early_stopping_metric_minimize=False, metrics='accuracy'):
        """Evaluate on the validation data, checkpoint on a new best score, and decide on early stopping.

        Runs ``self.test_loss`` through ``session.run`` with ``feed``; the fetch is
        unpacked as ``(loss, accuracy, predictions, summary)``. 3-class and 2-class
        F1 are computed from the predictions against ``self.data.dev_y`` via
        ``fscores``. The score selected by ``metrics`` is appended to
        ``early_stopping_metric_list`` (mutated in place). A checkpoint is written to
        ``self.FLAGS.checkpoint_file`` when that score is at least as good as every
        score recorded so far. No checkpoint is ever written while
        ``self.FLAGS.restore`` is set.

        Args:
            session: object whose ``run(fetches, feed)`` evaluates ``self.test_loss``
                (presumably a TF1 ``tf.Session`` -- confirm).
            feed: feed passed straight to ``session.run``.
            saver: object whose ``save(session, path)`` writes the checkpoint.
            early_stopping_rounds: early stopping triggers once more than this many
                evaluations have passed since the most recent best score.
            early_stopping_metric_list: history of the selected metric; appended to.
            early_stopping_metric_minimize: if True, lower scores are better;
                otherwise higher scores are better.
            metrics: score driving checkpointing/early stopping: ``'accuracy'``,
                ``'3classf1'`` or ``'2classf1'``.

        Returns:
            ``(loss, (acc, f1_3class, f1_2class), early_stop)`` when minimizing, or
            ``(loss, (acc, f1_3class, f1_2class), early_stop, eval_summary)``
            otherwise. The differing tuple length is historical and kept for
            caller compatibility.

        Raises:
            ValueError: if ``metrics`` is not one of the supported names.
        """
        # Order must match ``eval_scores`` below.
        metric_names = ('accuracy', '3classf1', '2classf1')
        # Validate before spending a session run; replaces an ``assert`` that was
        # stripped under -O and also rejected legitimate scores of exactly 0.
        if metrics not in metric_names:
            raise ValueError("metrics must be one of {}, got {!r}".format(metric_names, metrics))

        test_loss_value, acc_test, pred, eval_summary = session.run(self.test_loss, feed)
        f1_3class, f1_2class = fscores(self.data.dev_y, pred)
        if not self.tuning:
            print("*** Validation Loss = {:.6f}; Validation Accuracy = {:.5f}; 3-class F1 = {:.5f}; 2-class F1 = {:.5f}"
                        .format(test_loss_value, acc_test, f1_3class, f1_2class))
            print()

        eval_scores = (acc_test, f1_3class, f1_2class)
        early_stopping_score = eval_scores[metric_names.index(metrics)]
        early_stopping_metric_list.append(early_stopping_score)

        pick_best = min if early_stopping_metric_minimize else max
        best_score = pick_best(early_stopping_metric_list)

        # The history already contains the new score, so it is (tied for) the best
        # exactly when it equals the best. Never overwrite the checkpoint while
        # restoring, for either optimisation direction.
        if early_stopping_score == best_score and not self.FLAGS.restore:
            saver.save(session, self.FLAGS.checkpoint_file)

        # Position of the most recent best in the reversed history = number of
        # evaluations since the last improvement (ties count as an improvement).
        rounds_since_best = early_stopping_metric_list[::-1].index(best_score)
        early_stop = rounds_since_best > early_stopping_rounds

        if early_stopping_metric_minimize:
            # NOTE(review): the minimize path has always omitted eval_summary;
            # unify with the 4-tuple once all callers are checked.
            return (test_loss_value, eval_scores, early_stop)
        return (test_loss_value, eval_scores, early_stop, eval_summary)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号