roc_auc.py source code

python

Project: deep-mil-for-whole-mammogram-classification Author: wentaozhu
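import os
import numpy as np

def on_epoch_end(self, epoch, logs=None):
    # Keras-style callback hook: every `interval` epochs, measure recall on the validation set.
    if epoch % self.interval == 0:
        y_pred = self.model.predict(self.X_val, verbose=0)
        if self.mymil:
            # MIL mode: reduce with max over axis 1 and threshold the score at 0.5.
            y_true = self.y_val.max(axis=1)
            y_score = y_pred.max(axis=1) > 0.5
        else:
            # One-hot mode: compare class indices.
            y_true = np.argmax(self.y_val, axis=1)
            y_score = np.argmax(y_pred, axis=1)
        TP = np.sum(y_true[y_score == 1] == 1) * 1.  # true positives
        FN = np.sum(y_true[y_score == 0] == 1) * 1.  # false negatives
        reca = TP / (TP + FN + 1e-6)  # recall; the epsilon avoids division by zero
        print("interval evaluation - epoch: {:d} - reca: {:.2f}".format(epoch, reca))
        if reca > self.reca:
            # New best recall: remove older 'reca' checkpoints, then save the model.
            self.reca = reca
            # Note: only the current directory is scanned, so a filepath with a directory part never matches.
            for f in os.listdir('./'):
                if f.startswith(self.filepath + 'reca'):
                    os.remove(f)
            self.model.save(self.filepath + 'reca' + str(reca) + 'ep' + str(epoch) + '.hdf5')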
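
The recall computation inside the callback can be tried on its own, without a model. The following is a minimal sketch under stated assumptions: the helper name `recall_from_labels` and the sample arrays are hypothetical and do not come from the project; only the TP / (TP + FN + eps) formula mirrors the code above.

```python
import numpy as np

def recall_from_labels(y_true, y_pred_label, eps=1e-6):
    """Recall = TP / (TP + FN), with a small eps to avoid division by zero.

    Hypothetical helper for illustration; it mirrors the callback's formula.
    """
    y_true = np.asarray(y_true)
    y_pred_label = np.asarray(y_pred_label)
    # Positives that were predicted positive.
    tp = np.sum(y_true[y_pred_label == 1] == 1)
    # Positives that were predicted negative.
    fn = np.sum(y_true[y_pred_label == 0] == 1)
    return float(tp) / (tp + fn + eps)

# Made-up labels: three positives, two of which are recovered.
y_true = np.array([1, 1, 1, 0, 0])
y_pred = np.array([1, 0, 1, 0, 1])
recall = recall_from_labels(y_true, y_pred)
print("recall: {:.2f}".format(recall))
```

Because of the epsilon term, the result is marginally below the exact ratio (here just under 2/3), and a validation set with no positive samples yields 0 instead of raising an error.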