def thres_search(data, label, n):
    """Pick, for each of the first ``n`` samples, the threshold that maximizes ``f1``.

    For sample ``i`` let ``n_label`` be ``sum(label[i])`` cast to int32. The
    candidate thresholds are the midpoints between consecutive values of the
    ``n_label + 1`` largest entries of ``data[i] * label[i]``. Each candidate
    ``t`` yields the prediction ``cast(data[i] >= t, float32)``, which is scored
    with ``f1(prediction, label[i])``. The candidate at the argmax of those
    scores is kept.

    Args:
        data: indexable of per-sample score tensors; presumably 1-D float
            scores such as probabilities -- TODO confirm against callers.
        label: indexable of per-sample label tensors aligned with ``data``;
            assumed to be 0/1 multi-hot vectors, since their sum is used as
            the number of positives.
        n: number of samples to process (``data[0]`` .. ``data[n - 1]``).

    Returns:
        A 1-D tensor with one selected threshold per processed sample.

    NOTE(review): a sample with no positive labels yields no candidate
    thresholds, so the argmax over an empty score vector will likely fail at
    run time. ``top_k`` also fails when ``n_label + 1`` exceeds the length of
    ``data[i]`` (every entry labeled positive). Masking by multiplication
    assumes non-negative scores: a negative score on a positive label would
    sort below the zeroed-out negatives.
    """
    res = []
    for i in range(n):
        scores = data[i]
        labels = label[i]
        n_label = tf.cast(tf.reduce_sum(labels), tf.int32)

        # Zero out negative-label scores so top_k returns the positive-label
        # scores plus one extra value that lies just below the smallest of them.
        # (tf.mul was removed in TF 1.0; tf.multiply is its replacement.)
        masked = tf.multiply(scores, labels)
        top = tf.reshape(tf.nn.top_k(masked, n_label + 1).values, [-1])

        # Midpoints of consecutive sorted values, one candidate per row.
        # Same values as an average pool with window 2 and stride 1.
        thres = tf.reshape((top[:-1] + top[1:]) / 2.0, [-1, 1])

        # Binarize the sample at every candidate threshold, then score each one.
        predicts = tf.map_fn(
            lambda t: tf.cast(tf.greater_equal(scores, t), tf.float32), thres)
        f1_scores = tf.map_fn(lambda p: f1(p, labels), predicts)

        best = tf.cast(tf.argmax(f1_scores, 0), tf.int32)
        res.append(thres[best])
    return tf.reshape(res, [-1])
new_data_mlp.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录