new_data_mlp.py source code

Project: NVDM-For-Document-Classification; author: cryanzpj

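```python
import tensorflow as tf

# `f1(pred, label)` is a helper defined elsewhere in the project (not shown here).
def thres_search(data, label, n):
    """For each of the first n samples, return the candidate threshold with the best F1."""
    res = []
    for i in range(n):
        # Number of positive labels for sample i.
        n_label = tf.cast(tf.reduce_sum(label[i]), tf.int32)
        # Zero out the scores of negative classes.
        temp = tf.multiply(data[i], label[i])
        # Top n_label + 1 masked scores (descending), laid out as a [1, 1, W, 1] NHWC tensor.
        temp = tf.reshape(tf.nn.top_k(temp, n_label + 1).values, [1, 1, -1, 1])
        # Width-2, stride-1 average pooling yields midpoints of adjacent scores: the candidate thresholds.
        thres = tf.reshape(tf.nn.avg_pool(temp, ksize=[1, 1, 2, 1], strides=[1, 1, 1, 1], padding='VALID'), [-1, 1])
        # Binarize the scores at each candidate threshold and score each binarization with F1.
        predicts = tf.map_fn(lambda x: tf.cast(tf.greater_equal(data[i], x), tf.float32), thres)
        f1_scores = tf.map_fn(lambda x: f1(x, label[i]), predicts)
        # Keep the candidate threshold with the highest F1.
        thres_opt = thres[tf.cast(tf.argmax(f1_scores, 0), tf.int32)]
        res.append(thres_opt)
        # R = tf.map_fn(lambda x: tf.contrib.metrics.streaming_recall(x,label[i])[0],predicts)
        # P = tf.map_fn(lambda x: tf.contrib.metrics.streaming_precision(x,label[i])[0],predicts)
        #thres_opt = thres[np.argsort(map(lambda x:  metrics.f1_score(x,sess.run(label[i]),average = "macro") ,predicts))[-1]]

    return tf.reshape(res, [-1])
```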
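The per-sample search in `thres_search` can be restated without TensorFlow, which makes the candidate thresholds easy to inspect by hand. The following is a minimal pure-Python sketch, not the project's code: `f1_binary`, `best_threshold`, and `thres_search_py` are hypothetical names, `f1_binary` assumes the project's `f1` helper computes the ordinary binary F1 score, scores are assumed non-negative (e.g. probabilities), and every sample is assumed to have at least one positive and one negative label.

```python
from typing import List, Sequence


def f1_binary(pred: Sequence[int], label: Sequence[int]) -> float:
    """Ordinary F1 of a 0/1 prediction vector against a 0/1 label vector.

    Hypothetical stand-in for the project's `f1` helper, which is not shown.
    """
    tp = sum(1 for p, y in zip(pred, label) if p and y)
    fp = sum(1 for p, y in zip(pred, label) if p and not y)
    fn = sum(1 for p, y in zip(pred, label) if not p and y)
    if tp == 0:
        return 0.0
    return 2 * tp / (2 * tp + fp + fn)


def best_threshold(scores: Sequence[float], label: Sequence[int]) -> float:
    """One iteration of the thres_search loop, for a single sample."""
    n_label = int(sum(label))
    if not 0 < n_label < len(scores):
        raise ValueError("need at least one positive and one negative label")
    # Zero out negative-class scores (assumes non-negative scores).
    masked = [s * y for s, y in zip(scores, label)]
    # The n_label + 1 largest masked scores, in descending order.
    top = sorted(masked, reverse=True)[: n_label + 1]
    # Midpoints of adjacent pairs, i.e. a width-2, stride-1 average pool.
    candidates = [(a + b) / 2 for a, b in zip(top, top[1:])]
    best_t, best_f1 = candidates[0], -1.0
    for t in candidates:
        pred = [1 if s >= t else 0 for s in scores]
        score = f1_binary(pred, label)
        if score > best_f1:  # strict '>' keeps the first candidate on ties
            best_t, best_f1 = t, score
    return best_t


def thres_search_py(data: List[List[float]], label: List[List[int]]) -> List[float]:
    """Best threshold for every sample (row) in data."""
    return [best_threshold(d, l) for d, l in zip(data, label)]
```

Applied to scores `[0.9, 0.2, 0.6, 0.1]` with labels `[1, 0, 1, 0]`, the candidates are 0.75 and 0.3; 0.3 is returned because it recovers both positive classes (F1 = 1.0), while 0.75 recovers only one (F1 ≈ 0.67). Restricting candidates to midpoints around the positive scores means only `n_label` thresholds are evaluated per sample.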