a2_predict_classification.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:text_classification 作者: brightmart 项目源码 文件源码
def get_label_using_logits_batch(question_id_sublist, logits_batch, vocabulary_index2word_label, f, top_number=5):
    """Pick the highest-scoring labels for every row of a batch and write them out.

    For each entry of ``logits_batch`` the indices of the ``top_number``
    largest scores are looked up in ``vocabulary_index2word_label`` and the
    resulting labels (best score first) are handed, together with the
    matching entry of ``question_id_sublist``, to
    ``write_question_id_with_labels``. ``f`` is flushed once after the
    whole batch has been processed.

    Args:
        question_id_sublist: indexable collection; element ``i`` is paired
            with element ``i`` of ``logits_batch``.
        logits_batch: iterable of score vectors, one per question
            (presumably shaped (batch_size, num_labels) -- confirm with caller).
        vocabulary_index2word_label: mapping from label index to label.
        f: output handle forwarded to ``write_question_id_with_labels``;
            must provide ``flush()``.
        top_number: how many labels to keep per question.
    """
    print("get_label_using_logits.shape:", np.array(logits_batch).shape) # (1, 128, 2002))
    for row, scores in enumerate(logits_batch):
        # argsort is ascending, so the last `top_number` positions hold the
        # largest scores; reversing puts the best one first.
        # NOTE: with top_number == 0 the slice [-0:] selects every index.
        best_first = np.argsort(scores)[-top_number:][::-1]
        top_labels = [vocabulary_index2word_label[position] for position in best_first]
        write_question_id_with_labels(question_id_sublist[row], top_labels, f)
    f.flush()

# get label using logits
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号